diff --git a/include/differential.h b/include/differential.h index ca2e952e5abd6e1cffaf26063679f52358a4829e..d1cf91fce671ddc32e1deb49fc0f6bf32588f492 100644 --- a/include/differential.h +++ b/include/differential.h @@ -7,14 +7,15 @@ namespace numerics { template <typename argtype, typename valtype> -class ScalarDifferential : public Function<argtype, valtype> { +class ScalarDifferential + : public Function<Coordinate<argtype>, Coordinate<valtype>> { public: ScalarDifferential(Function<Vector<argtype>, valtype>& _func, Differentiator<argtype, valtype>& _diff, argtype _h = 1e-8) : func(_func), diff(_diff), h(_h) {} - Vector<valtype> operator()(numerics::Coordinate<argtype> x) const override { + Vector<valtype> operator()(Coordinate<argtype> x) const override { assert(x.dimension() == input_dimension()); return diff(func, x, Vector<valtype>(x.dimension(), h)); } @@ -30,7 +31,7 @@ class ScalarDifferential : public Function<argtype, valtype> { }; template <typename argtype, typename valtype> -Function<Vector<argtype>, Vector<valtype>> +ScalarDifferential<argtype, valtype> nabla(Function<Vector<argtype>, valtype>& func, Differentiator<argtype, valtype>& diff) { return ScalarDifferential<argtype, valtype>(func, diff);