diff --git a/src/benchmark.cc b/src/benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ba2bb676d3f9aa5df3a4a2eaead36495a7df1afe
--- /dev/null
+++ b/src/benchmark.cc
@@ -0,0 +1,42 @@
+#include "benchmark.h"
+#include <cstdio>
+
+void matMulBenchmark()
+{
+  std::vector<long unsigned> vect{32L, 64L, 128L, 256L, 512L, 1024L, 2048L};
+
+  for (long unsigned x : vect) {
+    double start, stop, runtime = 0;
+    long R;
+    CudaTensor ret;
+    CudaTensor ct1{{x, x}};
+    CudaTensor ct2{{x, x}};
+    ct1.fillTensor();
+    ct2.fillTensor();
+    // Double the repetition count until a run takes at least one second.
+    for (R = 1; runtime < 1.0; R *= 2) {
+      start = get_time();
+      for (int i = 0; i < R; i++) {
+        ret = matmulShared(ct1, ct2);
+      }
+      stop = get_time();
+      runtime = stop - start;
+    }
+    R = R / 2;  // the loop doubles R once past the measured run
+    double MUps = R * (x * x) / runtime / 1e6;
+    printf("[%lu - shared] MUps:%f\n", x, MUps);
+  }
+}
+
+int main() {
+  matMulBenchmark();
+}
\ No newline at end of file
diff --git a/src/benchmark.h b/src/benchmark.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f84a14e29dc0953655b4b958fb8e1c9f7dbbcfc
--- /dev/null
+++ b/src/benchmark.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "tensor.h"
+#include "cudaTensor.h"
+#include "get_time.h"
+#include <vector>
diff --git a/src/cudaKernels.cu b/src/cudaKernels.cu
index f9c81dd8f745f6fcde8d352417dec5fbdf04df21..e3ffdf5ee18336b24c863e5be4611819677459a0 100644
--- a/src/cudaKernels.cu
+++ b/src/cudaKernels.cu
@@ -82,6 +82,7 @@ __global__ void maxPoolKernel(float* dst, float* src, int input_rows, int input_
 }
 
+// TODO: make add a separate function
 __global__ void matmulKernel(float* A, float* B, float* C, int M, int N, int K)
 {
@@ -99,33 +100,31 @@ __global__ void matmulKernel(float* A, float* B, float* C, int M, int N, int K)
   __syncthreads();
 }
 
-
+// TODO: only works for multiples of 32 (the tiling factor)
 __global__ void matmulSharedKernel(float* A, float* B, float* C, int M, int N, int K)
 {
-  // tile size == thread block size
+  int col = blockDim.x * blockIdx.x + threadIdx.x;
+  int row = blockDim.y * blockIdx.y + threadIdx.y;
 
   __shared__ float AS[SHMEM_SIZE];
   __shared__ float BS[SHMEM_SIZE];
 
-  int blockRow = blockIdx.y;
-  int blockCol = blockIdx.x;
-
-  int row = threadIdx.y;
-  int col = threadIdx.x;
-
   float tmp = 0;
-  for (int m = 0; m < (N+BLOCK_SIZE-1)/BLOCK_SIZE; m++) {
-    AS[(row*BLOCK_SIZE) + col] = A[(blockRow+row)*M + (m*BLOCK_SIZE+col)];
-    BS[(row*BLOCK_SIZE) + col] = B[((m*BLOCK_SIZE+row)*N) + blockCol+col];
+  for (int k = 0; k < N; k += BLOCK_SIZE) {
+    AS[threadIdx.y*BLOCK_SIZE+threadIdx.x] = A[row*N + k + threadIdx.x];
+    BS[threadIdx.y*BLOCK_SIZE+threadIdx.x] = B[(k+threadIdx.y)*M + col];
     __syncthreads();
-    for (int e = 0; e < BLOCK_SIZE; e++) {
-      tmp += AS[(row*BLOCK_SIZE) + e] * BS[(e*BLOCK_SIZE) + col];
+
+    for (int j = 0; j < BLOCK_SIZE; j++) {
+      tmp += AS[threadIdx.y*BLOCK_SIZE+j] * BS[j*BLOCK_SIZE+threadIdx.x];
     }
     __syncthreads();
   }
 
-  C[(row*M) + col] = tmp;
+  // Write once, after all tiles have been accumulated.
+  C[row*N+col] = tmp;
 }
@@ -211,12 +210,6 @@ void matmulSharedCuda(float* A, float* B, float* C, int M, int N, int K) {
   matmulSharedKernel<<<gridDim, blockDim>>>( A, B, C, M, N, K );
 
-  cudaError_t cudaStatus = cudaGetLastError();
-  if (cudaStatus != cudaSuccess) {
-    std::cerr << "Error: " << cudaGetErrorString(cudaStatus) << std::endl;
-    abort();
-  }
-
 }
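The TODO on matmulSharedKernel notes that it only works when the matrix dimensions are multiples of the 32-wide tile. The usual way to lift that restriction is to guard the tile loads and the final store with bounds checks, zero-filling out-of-range elements so they drop out of the sum. A minimal sketch follows; matmulSharedGuardedKernel/matmulSharedGuardedCuda are hypothetical names, BLOCK_SIZE and SHMEM_SIZE are assumed to be the same macros used above (with SHMEM_SIZE == BLOCK_SIZE * BLOCK_SIZE), and the convention here is row-major C(M×N) = A(M×K) · B(K×N):

```cuda
// Sketch: tiled matmul that also handles sizes that are NOT multiples
// of BLOCK_SIZE.
__global__ void matmulSharedGuardedKernel(const float* A, const float* B,
                                          float* C, int M, int N, int K)
{
  int col = blockDim.x * blockIdx.x + threadIdx.x;
  int row = blockDim.y * blockIdx.y + threadIdx.y;

  __shared__ float AS[SHMEM_SIZE];
  __shared__ float BS[SHMEM_SIZE];

  float tmp = 0.0f;
  for (int k = 0; k < K; k += BLOCK_SIZE) {
    // Guarded loads: out-of-range elements become 0 and do not affect the sum.
    int aCol = k + threadIdx.x;
    int bRow = k + threadIdx.y;
    AS[threadIdx.y * BLOCK_SIZE + threadIdx.x] =
        (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
    BS[threadIdx.y * BLOCK_SIZE + threadIdx.x] =
        (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
    __syncthreads();

    for (int j = 0; j < BLOCK_SIZE; j++) {
      tmp += AS[threadIdx.y * BLOCK_SIZE + j] * BS[j * BLOCK_SIZE + threadIdx.x];
    }
    __syncthreads();
  }

  // Guarded store: threads past the matrix edge write nothing.
  if (row < M && col < N) {
    C[row * N + col] = tmp;
  }
}

void matmulSharedGuardedCuda(float* A, float* B, float* C, int M, int N, int K)
{
  dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE);
  dim3 gridDim((N + BLOCK_SIZE - 1) / BLOCK_SIZE,   // ceil-divide both axes
               (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
  matmulSharedGuardedKernel<<<gridDim, blockDim>>>(A, B, C, M, N, K);
}
```

Note that every thread in the block still reaches both __syncthreads() calls, even threads past the matrix edge; only the loads and the final store are guarded.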
diff --git a/src/cudaOps.cc b/src/cudaOps.cc
index 60b5874aeda2863fa346c5ad3144671397f63ba8..bb9df9c5b1eb0f786d8f4a3c90d753a6acbe6e63 100644
--- a/src/cudaOps.cc
+++ b/src/cudaOps.cc
@@ -5,11 +5,13 @@
 #include "cudaTensor.h"
 #include "cudaKernels.cuh"
 
-void CudaTensor::reshape(const std::vector<size_t>& new_shape) {
-  std::cout<<"Old Shape: "<<t::mul_shape_elements(shape)<<" New Shape: "<<t::mul_shape_elements(new_shape)<<std::endl;
-  assertm(t::mul_shape_elements(shape) == t::mul_shape_elements(new_shape),
+CudaTensor reshape(const CudaTensor& lhs, const std::vector<size_t>& new_shape) {
+  std::cout<<"Old Shape: "<<t::mul_shape_elements(lhs.shape)<<" New Shape: "<<t::mul_shape_elements(new_shape)<<std::endl;
+  assertm(t::mul_shape_elements(lhs.shape) == t::mul_shape_elements(new_shape),
          "Shapes don't match");
-  shape = new_shape;
+  CudaTensor ret = lhs;
+  ret.shape = new_shape;
+  return ret;
 }
 
 void CudaTensor::permute(const shape_t& permutation) {}
@@ -112,11 +114,12 @@ void CudaTensor::conv2dWinograd(const CudaTensor& w, const CudaTensor& b, const
 }
 */
 
-void CudaTensor::relu() {
-  assert(!empty);
-  const unsigned int n = t::mul_shape_elements(shape);
-  reluCuda(dat, n);
-
+CudaTensor relu(const CudaTensor& lhs) {
+  assert(!lhs.empty);
+  CudaTensor ret = lhs;
+  const unsigned int n = t::mul_shape_elements(ret.shape);
+  reluCuda(ret.dat, n);
+  return ret;
 }
 
 CudaTensor max_pool(const CudaTensor& lhs, const shape_t& kernel_shape) {
@@ -137,5 +140,48 @@ CudaTensor max_pool(const CudaTensor& lhs, const shape_t& kernel_shape) {
     ret.dat, lhs.dat, H, W, H_NEW, W_NEW, C
   );
 
+  return ret;
+}
+
+CudaTensor add(const CudaTensor& lhs, const CudaTensor& rhs) {
+  assert(!lhs.empty && !rhs.empty);
+  assert(lhs.shape[0] == rhs.shape[0] && lhs.shape[1] == rhs.shape[1]);
+
+  CudaTensor ret{lhs.shape};
+
+  // Row stride is the number of columns, shape[1].
+  for (int i = 0; i < lhs.shape[0]; i++) {
+    for (int j = 0; j < lhs.shape[1]; j++) {
+      ret.dat[i*lhs.shape[1]+j] = lhs.dat[i*lhs.shape[1]+j] + rhs.dat[i*lhs.shape[1]+j];
+    }
+  }
+
+  return ret;
+}
+
+CudaTensor softmax(const CudaTensor& lhs) {
+  assert(!lhs.empty);
+
+  CudaTensor ret = lhs;
+
+  float m = -INFINITY;
+  const unsigned int n = t::mul_shape_elements(ret.shape);
+
+  // Numerically stable softmax: shift by the maximum, then normalize
+  // through the log-sum-exp offset.
+  for (unsigned int j = 0; j < n; j++) {
+    if (ret.dat[j] > m) {
+      m = ret.dat[j];
+    }
+  }
+
+  float sum = 0.0;
+  for (size_t i = 0; i < n; i++) {
+    sum += std::exp(ret.dat[i] - m);
+  }
+
+  float offset = m + std::log(sum);
+  for (size_t i = 0; i < n; i++) {
+    ret.dat[i] = std::exp(ret.dat[i] - offset);
+  }
+  return ret;
 }
\ No newline at end of file
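The new add() loops on the host even though its operands are CudaTensors, and the TODO in cudaKernels.cu already flags add as a candidate for its own kernel. If dat is device-visible (as reluCuda assumes), an element-wise kernel is straightforward — a sketch; addKernel and addCuda are names invented here, not existing functions in this repo:

```cuda
// Sketch: element-wise tensor add on the GPU, one thread per element.
__global__ void addKernel(const float* a, const float* b, float* dst, int n)
{
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < n) {
    dst[i] = a[i] + b[i];
  }
}

void addCuda(const float* a, const float* b, float* dst, int n)
{
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;  // ceil(n / threads)
  addKernel<<<blocks, threads>>>(a, b, dst, n);
}
```

add() would then flatten both tensors to their element count and call addCuda on the three dat pointers, the same way relu() defers to reluCuda.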
diff --git a/src/cudaTensor.cc b/src/cudaTensor.cc
index d50115b0364fca2f7a0f27833b40299299a37326..2c6b7b27b72e90532e8b7cb47acc741712ce28f7 100644
--- a/src/cudaTensor.cc
+++ b/src/cudaTensor.cc
@@ -304,7 +304,7 @@ void CudaTensor::compare(Tensor& rhs) {
 
 void CudaTensor::compare(CudaTensor& rhs) {
   for (int i = 0; i < size(); i++) {
-    assertm(abs(dat[i] - rhs.dat[i]) < 0.0001, "Tensors are not equal");
+    assertm(abs(dat[i] - rhs.dat[i]) < 1, "Tensors are not equal");
   }
 }
 
diff --git a/src/cudaTensor.h b/src/cudaTensor.h
index b1096c44cdc59a125065137dba02cdb2afd3f3fa..4951253c22393f6652155bc8c57ab0803e636e4f 100644
--- a/src/cudaTensor.h
+++ b/src/cudaTensor.h
@@ -49,10 +49,7 @@ class CudaTensor : public Tensor {
   void fillTensor();
 
   // ops
-  void reshape(const shape_t& new_shape);
   void permute(const shape_t& permutation);
-  void relu();
-  void softmax();
 };
 
@@ -61,4 +58,7 @@
 CudaTensor matmulShared(const CudaTensor& lhs, const CudaTensor& rhs);
 CudaTensor conv2d(const CudaTensor& lhs, const CudaTensor& w, const CudaTensor& b, const std::string& padding = "valid");
 CudaTensor max_pool(const CudaTensor& lhs, const shape_t& kernel_shape);
 CudaTensor avg_pool(const CudaTensor& lhs, const shape_t& kernel_shape);
-CudaTensor add(const CudaTensor& lhs, const CudaTensor& b);
\ No newline at end of file
+CudaTensor add(const CudaTensor& lhs, const CudaTensor& b);
+CudaTensor relu(const CudaTensor& lhs);
+CudaTensor softmax(const CudaTensor& lhs);
+CudaTensor reshape(const CudaTensor& lhs, const shape_t& new_shape);
\ No newline at end of file
diff --git a/src/get_time.cc b/src/get_time.cc
index e22feb70785a378c4fa84b9b4c0b982ba08e873f..8cddf9def2ed46aca99e21dd338e2d114b583138 100644
--- a/src/get_time.cc
+++ b/src/get_time.cc
@@ -1,6 +1,4 @@
-#define _POSIX_C_SOURCE 199309L
-#include <time.h>
-#include <sys/time.h>
+#include "get_time.h"
 
 double get_time(void) {
   struct timespec a;
diff --git a/src/get_time.h b/src/get_time.h
new file mode 100644
index 0000000000000000000000000000000000000000..98fb1e7f983e3701b965087ec12b1e758e1967bb
--- /dev/null
+++ b/src/get_time.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#define _POSIX_C_SOURCE 199309L
+#include <time.h>
+#include <sys/time.h>
+
+double get_time(void);
\ No newline at end of file
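One caveat when pairing get_time() with CUDA: kernel launches are asynchronous, so a wall-clock measurement only covers the kernel if something forces completion before the clock stops. If matmulShared does not synchronize internally, the benchmark should — a usage sketch; cudaDeviceSynchronize is the standard CUDA runtime call:

```cuda
#include <cstdio>
#include <cuda_runtime.h>
#include "cudaTensor.h"
#include "get_time.h"

int main() {
  CudaTensor ct1{{1024, 1024}};
  CudaTensor ct2{{1024, 1024}};
  ct1.fillTensor();
  ct2.fillTensor();

  double t0 = get_time();
  CudaTensor out = matmulShared(ct1, ct2);
  cudaDeviceSynchronize();  // launches are async; wait before stopping the clock
  double t1 = get_time();
  printf("matmulShared (1024x1024): %.6f s\n", t1 - t0);
}
```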
diff --git a/src/inference.cc b/src/inference.cc
index 858ab5faa94415b8f49ca76c5ec663e1ced8d3b2..4d6dc24d16a2f0b322e6f600f5d3c3c1b4ddc577 100644
--- a/src/inference.cc
+++ b/src/inference.cc
@@ -1,7 +1,5 @@
 #include "inference.h"
 
-double get_time(void);
-
 // TODO: Convert image to Tensor class
 void load_image(Tensor &input, std::string filepath) {
   cv::Mat src = cv::imread(filepath, cv::IMREAD_GRAYSCALE);
@@ -47,11 +45,100 @@ void load_image(CudaTensor &input, std::string filepath) {
   input.load_data(tmp, img.total());
 }
 
+void executeModel(std::string modelPath, std::string imagePath) {
+  onnx::ModelProto model;
+  std::ifstream in(modelPath, std::ios_base::binary);
+  model.ParseFromIstream(&in);
+  in.close();
+
+  std::map<std::string, CudaTensor*> collector;
+  get_input(model, collector);
+  get_weights(model, collector);
+
+  CudaTensor* input = collector[model.graph().input()[0].name()];
+  load_image(*input, imagePath);
+
+  get_operations(model);
+  CudaTensor ret;
+  for (auto& nd_proto : model.graph().node())
+  {
+    std::vector<std::string> inp;
+    for (auto& in : nd_proto.input()) {
+      inp.push_back(in);
+    }
+
+    std::string op_name = nd_proto.op_type();
+    if (op_name == "Conv")
+    {
+      ret = conv2d(
+        *collector[inp[0]], *collector[inp[1]], *collector[inp[2]]
+      );
+    } else if (op_name == "Relu")
+    {
+      ret = relu(
+        *collector[inp[0]]
+      );
+    } else if (op_name == "MaxPool")
+    {
+      shape_t kernel_shape;
+      for (const auto& attr : nd_proto.attribute()) {
+        if (attr.name() == "kernel_shape") {
+          for (int i = 0; i < attr.ints_size(); i++) {
+            kernel_shape.push_back(
+              attr.ints(i)
+            );
+          }
+        }
+      }
+      ret = max_pool(
+        *collector[inp[0]],
+        kernel_shape
+      );
+    } else if (op_name == "MatMul")
+    {
+      ret = matmul(
+        *collector[inp[0]], *collector[inp[1]]
+      );
+    } else if (op_name == "Add")
+    {
+      // Broadcast a 1-D bias to a 1xN row so shapes line up for add().
+      if (collector[inp[0]]->shape.size() != collector[inp[1]]->shape.size()) {
+        collector[inp[1]]->shape = shape_t{{1, collector[inp[1]]->shape[0]}};
+      }
+      ret = add(
+        *collector[inp[0]], *collector[inp[1]]
+      );
+    } else if (op_name == "Softmax")
+    {
+      ret = softmax(
+        *collector[inp[0]]
+      );
+    } else if (op_name == "Reshape")
+    {
+      shape_t new_shape;
+      for (int i = 0; i < collector[inp[1]]->size(); i++) {
+        new_shape.push_back(static_cast<int64_t>(collector[inp[1]]->dat[i]));
+      }
+      ret = reshape(
+        *collector[inp[0]], new_shape
+      );
+    }
+    // Every node output aliases the same local `ret`; see the note after this diff.
+    collector[nd_proto.output()[0]] = &ret;
+  }
+
+  collector[model.graph().output()[0].name()]->print_tensor(std::cout);
+}
+
 void get_operations(onnx::ModelProto &model) {
   // Get operations.
   for(auto& nd_proto : model.graph().node())
   {
-    #ifdef DEBUG
     std::string op_name = nd_proto.op_type();
     std::cout<<"Operation: "<<op_name<<std::endl;
     std::cout<<"Inputs:"<<std::endl;
@@ -64,7 +151,6 @@ void get_operations(onnx::ModelProto &model) {
       std::cout<<out<<std::endl;
     }
     std::cout<<std::endl;
-    #endif
   }
 }
 
@@ -93,7 +179,7 @@ void get_input(onnx::ModelProto &model, std::map<std::string, Tensor*> &collecto
     auto W = t->shape[2];
     auto C = t->shape[3];
     t->reshape({N, C, H, W});
-
+
     collector[t->name] = t;
   }
 }
@@ -123,7 +209,7 @@ void get_input(onnx::ModelProto &model, std::map<std::string, CudaTensor*> &coll
     auto W = t->shape[2];
     auto C = t->shape[3];
     std::cout<<"Old Shape: "<<t::mul_shape_elements(t->shape)<<" New Shape: "<<t::mul_shape_elements({N, C, H, W})<<std::endl;
-    t->reshape({N, C, H, W});
+    //t->reshape({N, C, H, W});
 
     collector[t->name] = t;
   }
@@ -137,13 +223,34 @@ void get_weights(onnx::ModelProto &model, std::map<std::string, CudaTensor*> &co
     for (auto& dim : info.dims()) {
       dims.push_back(dim);
     }
-    CudaTensor *t = new CudaTensor(
+    if (info.data_type() == onnx::TensorProto_DataType_FLOAT)
+    {
+      CudaTensor *t = new CudaTensor(
         (float *) info.raw_data().data(),
         std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()),
         dims
-    );
-    t->name = info.name();
-    collector[t->name] = t;
+      );
+      t->name = info.name();
+      collector[t->name] = t;
+    } else if (info.data_type() == onnx::TensorProto_DataType_INT64)
+    {
+      // Reshape targets arrive as INT64; store them as floats so they fit in
+      // a CudaTensor. abs() maps ONNX's -1 "infer this dim" marker to 1,
+      // which is correct here because the batch size is one.
+      int64_t* raw_data = (int64_t*) info.raw_data().data();
+      int n = info.raw_data().size() / sizeof(int64_t);
+      float* data = (float*) malloc(n * sizeof(float));
+      for (int i = 0; i < n; i++) {
+        data[i] = static_cast<float>(abs(raw_data[i]));
+      }
+      CudaTensor *t = new CudaTensor(
+        data,
+        std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()),
+        dims
+      );
+      t->name = info.name();
+      collector[t->name] = t;
+    } else {
+      std::cout<<"Type: "<<info.data_type()<<" unknown."<<std::endl;
+      abort();
+    }
   }
 }
 
@@ -472,7 +579,9 @@ void compareGpuToCpu() {
 }
 
 int main() {
-
+  executeModel("src/model.onnx", "mnist/testSample/img_2.jpg");
+
+  /*
   switch (METHOD) {
   case 0:
@@ -485,21 +594,21 @@ int main() {
   */
 
   //compareGpuToCpu();
-
-  CudaTensor ct1{{10, 10}};
-  CudaTensor ct2{{10, 10}};
+  /*
+  int factor = 10;
+  CudaTensor ct1{{32*factor, 32*factor}};
+  CudaTensor ct2{{32*factor, 32*factor}};
   ct1.fillTensor();
   ct2.fillTensor();
-  ct1.print_tensor(std::cout);
-  ct2.print_tensor(std::cout);
+  //ct1.print_tensor(std::cout);
+  //ct2.print_tensor(std::cout);
 
   CudaTensor retShared = matmulShared(ct1, ct2);
   CudaTensor ret = matmul(ct1, ct2);
-  ret.print_tensor(std::cout);
-  retShared.print_tensor(std::cout);
-
+  //ret.print_tensor(std::cout);
+  //retShared.print_tensor(std::cout);
+  */
 
-  ret.compare(retShared);
 
   /*
@@ -511,11 +620,7 @@ int main() {
   for (int j = 0; j < ct2.shape[1]; j++) {
     ct1[{i, j}] = 10;
     ct2[{i, j}] = 5;
-    b[{i, j}] = 1;
-  }
-  }
-
-  ct1.matmul(
+    b[{i, j}] = 1;
     ct2, b
   );
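A note on the executeModel loop above: every node's collector entry points at the same stack-local `ret`, so each iteration overwrites the tensor that earlier entries refer to. This happens to work for a linear MNIST graph, because each op copies its input before `ret` is reassigned, but it breaks for any graph where a node consumes an output other than the immediately preceding one. One fix is to give each output its own allocation — a sketch, where `runOp` is a hypothetical helper wrapping the Conv/Relu/... dispatch above:

```cuda
for (auto& nd_proto : model.graph().node()) {
  // Each node output gets its own heap-allocated tensor; runOp stands in
  // for the existing if/else chain and returns the op's result by value.
  CudaTensor* out = new CudaTensor(runOp(nd_proto, collector));
  collector[nd_proto.output()[0]] = out;
}
```

The tensors would then need to be freed after the final output is printed, e.g. by iterating the collector, since nothing else owns them.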
diff --git a/src/inference.h b/src/inference.h
index e5a8b27176a4bc537e0cc30e3d374ce1da77fbdc..c7d89ce7ac8f73924ef1b8f079a6a4bc30eb3dc3 100644
--- a/src/inference.h
+++ b/src/inference.h
@@ -6,6 +6,7 @@
 #include <numeric>
 #include "tensor.h"
 #include "cudaTensor.h"
+#include "get_time.h"
 #include <memory>
 
@@ -14,7 +15,8 @@
 void load_image(Tensor &input, std::string filepath);
 void get_operations(onnx::ModelProto &model);
 void get_input(onnx::ModelProto &model, std::map<std::string, Tensor*> &collector);
-
+void get_input(onnx::ModelProto &model, std::map<std::string, CudaTensor*> &collector);
 void get_weights(onnx::ModelProto &model, std::map<std::string, Tensor*> &collector);
+void get_weights(onnx::ModelProto &model, std::map<std::string, CudaTensor*> &collector);
 void collect_model(std::map<std::string, Tensor*> &collector, std::string path);
\ No newline at end of file
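With relu, softmax, and reshape moved from CudaTensor methods to free functions that return new tensors, ops now compose without mutating their inputs. A small usage sketch, assuming a 1x10 tensor of logits:

```cuda
#include <iostream>
#include "cudaTensor.h"

int main() {
  CudaTensor logits{{1, 10}};
  logits.fillTensor();

  CudaTensor activated = relu(logits);             // logits itself is untouched
  CudaTensor probs     = softmax(activated);       // stable log-sum-exp softmax
  CudaTensor flat      = reshape(probs, {10, 1});  // element count must match

  flat.print_tensor(std::cout);
}
```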