diff --git a/src/test_gradient_descent.cpp b/src/test_gradient_descent.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76564b2868a4f3a436a5d2eabb628579fbf73952 --- /dev/null +++ b/src/test_gradient_descent.cpp @@ -0,0 +1,120 @@ +#define CATCH_CONFIG_MAIN // This tells Catch to provide a main() - only do this + // in one cpp file +#include "catch.hpp" +#include "differentiator.h" +#include "optimizer.h" +#include "sample_functions.h" +#define _USE_MATH_DEFINES +#include <cmath> +#include <cstdint> +#include <fstream> +#include <iomanip> +#include <string> +#include <vector> + +TEST_CASE("Gradient Descent on functions with one dimensional argument", + "[gradient_descent]") { + using namespace numerics; + using std::cerr; + using std::cout; + using std::endl; + FivePointDifferentiator fp; + double step = 1e-4; + GradientDescent optimizer(step); + + Coordinate<double> position1(1), position2(2), ref2(2), res1(1), res2(2); + + double precision = 1e-9; + double error = 1e-5; + + using funct = Function<Coordinate<double>, double>; + + SECTION("Testing 1D parabola") { + funct& parabola = functions::parabola_1d; + cerr << "Tested function: " << functions::parabola_1d << endl; + position1[0] = 10; + res1 = optimizer.optimize(parabola, position1, precision); + REQUIRE(res1.l2Norm() < error); + REQUIRE(std::abs(parabola(res1)) < error); + } + SECTION("Testing 1D higher-order parabola") { + double ref_pos = sqrt(4.6 / 4); + + funct& higher_parabola = functions::higher_parabola_1d; + cerr << "Tested function: " << functions::higher_parabola_1d << endl; + position1[0] = 10; + res1 = optimizer.optimize(higher_parabola, position1, precision); + REQUIRE(std::abs(res1[0] - ref_pos) < error); + + position1[0] = -ref_pos + 0.5; + res1 = optimizer.optimize(higher_parabola, position1, precision); + REQUIRE(std::abs(res1[0] - ref_pos) < error); + } + SECTION("Testing 1D reference Polynomial") { + funct& reference_polynomial = functions::ref_polynomial_1d; + cerr << "Tested function: " << functions::ref_polynomial_1d << endl; + position1[0] = 20; + res1 = optimizer.optimize(reference_polynomial, position1, precision); + REQUIRE(std::abs(reference_polynomial(res1)) < error); + } + + SECTION("Testing narrow 1d Lennard Jones potential") { + double rm = functions::narrow_LJ.get_rm(); + + funct& narrow_LJ = functions::narrow_LJ; + position1[0] = rm * 3; + res1 = optimizer.optimize(narrow_LJ, position1, precision); + REQUIRE(std::abs(res1[0] - rm) < error); + + position1[0] = rm / 2.0; + res1 = optimizer.optimize(narrow_LJ, position1, precision); + REQUIRE(std::abs(res1[0] - rm) < error); + } + + SECTION("Testing wide 1d Lennard Jones potential") { + double rm = functions::wide_LJ.get_rm(); + funct& wide_LJ = functions::wide_LJ; + position1[0] = rm * 3; + res1 = optimizer.optimize(wide_LJ, position1, precision); + REQUIRE(std::abs(res1[0] - rm) < error); + + position1[0] = rm / 2.0; + res1 = optimizer.optimize(wide_LJ, position1, precision); + REQUIRE(std::abs(res1[0] - rm) < error); + } + + SECTION("Testing 2D parabola") { + funct& parabola2d = functions::parabola_2d; + position2[0] = 10; + position2[1] = -123.0; + res2 = optimizer.optimize(parabola2d, position2, precision); + REQUIRE(res2.l2Norm() < error); + REQUIRE(std::abs(parabola2d(res2)) < error); + position2[0] = -133; + position2[1] = -0.1; + res2 = optimizer.optimize(parabola2d, position2, precision); + REQUIRE(res2.l2Norm() < error); + REQUIRE(std::abs(parabola2d(res2)) < error); + } + SECTION("Testing 2D parabola") { + funct& periodic_wave = functions::periodic_wave_2d; + uint32_t k = 13; + uint32_t l = 31; + + ref2[0] = (2 * k - 0.5) * M_PI; + ref2[1] = (2 * l + 1.0) * M_PI; + position2[0] = ref2[0] + 0.5; + position2[1] = ref2[1] - 0.2; + res2 = optimizer.optimize(periodic_wave, position2, precision); + REQUIRE((res2 - ref2).l2Norm() < error); + k = -4; + l = 5; + + ref2[0] = (2 * k - 0.5) * M_PI; + ref2[1] = (2 * l + 1.0) * M_PI; + position2[0] = ref2[0] - 0.3; + position2[1] = ref2[1] + 0.4; + res2 = optimizer.optimize(periodic_wave, position2, precision); + REQUIRE((res2 - ref2).l2Norm() < error); + } +}