diff --git a/include/lambda_wrapper.h b/include/lambda_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..9c0eda0ec39252552cfe4c2e93986d2c90d41413 --- /dev/null +++ b/include/lambda_wrapper.h @@ -0,0 +1,38 @@ +#ifndef __LAMBDA_WRAPPER_H__ +#define __LAMBDA_WRAPPER_H__ + +#include <cassert> +#include <vector> + +#include <cmath> +#include <cstdint> +#include <iostream> +#include "function.h" +#include <functional> + +namespace numerics { + + template <typename argtype, typename valtype> + class LambdaWrapper : public Function<argtype, valtype> { + public: + LambdaWrapper(std::function<valtype(argtype)> func, size_t indim, size_t outdim) : input_dim(indim), output_dim(outdim), f(func) {} + + virtual valtype operator()(argtype arg) const override { + return f(arg); + }; + + virtual size_t input_dimension() const override { + return input_dim; + } + + virtual size_t output_dimension() const override { + return output_dim; + } + protected: + size_t input_dim, output_dim; + std::function < valtype(argtype)> f; + }; + +}; // namespace numerics + +#endif