Skip to content
Snippets Groups Projects
Select Git revision
  • gpu_fft
  • main default protected
  • im2col
  • int8_quant
  • fix_warnings
  • graph
  • Pruning
  • generic
  • gpu
  • cuda
10 results

cudaMain.cc

Blame
  • user avatar
    Jakob Spahn authored
    f0902911
    History
    cudaMain.cc 699 B
    // clang++ --std=c++20 lantern/tensor/*.cc lantern/tensor/accel/cuda/*.cc -I . -o cuda
    #include "include/lantern.h"
    #include "lantern/tensor/accel/cuda/CUDABackend.h"
    
    #include <iostream>
    
    // Driver that runs the same conv2d twice on the CUDA backend: once with
    // the backend's `conv_fft` flag enabled and once with it disabled, then
    // prints both results (presumably so the FFT and non-FFT convolution
    // paths can be compared by eye — TODO confirm intent).
    int main() {
        // Make CUDATensor the default tensor implementation for everything below.
        lt::manage::setDefaultGate<lt::CUDATensor>();
    
        // Random inputs, created in this fixed order (input, kernel, bias).
        // NOTE(review): shapes look like (N, C, H, W) for x/y — unverified here.
        auto x(lt::Tensor::randn<float>({1, 1, 28, 28}));
        auto y(lt::Tensor::randn<float>({1, 1, 5, 5}));
        auto b(lt::Tensor::randn<float>({1}));
    
        // Select the convolution path on the backend singleton, then convolve.
        // The result is returned by value so each one outlives this helper.
        const auto convolveWith = [&](bool useFft) {
            lt::CUDABackend::getInstance().conv_fft = useFft;
            return lt::conv2d(x, y, b);
        };
    
        auto fftResult = convolveWith(true);
        std::cout << "res: " << fftResult << std::endl;  // flush before next GPU call
    
        auto directResult = convolveWith(false);
        std::cout << "res2: " << directResult << std::endl;
    
        return 0;
    }