Skip to content

Commit

Permalink
migrate lattice tests to Catch2v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Sajid Ali committed Apr 29, 2024
1 parent 25e38f2 commit 338b561
Show file tree
Hide file tree
Showing 6 changed files with 457 additions and 479 deletions.
2 changes: 1 addition & 1 deletion src/synergia/lattice/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ if(BUILD_PYTHON_BINDINGS)

endif()

#add_subdirectory(tests)
add_subdirectory(tests)
21 changes: 9 additions & 12 deletions src/synergia/lattice/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
if(BUILD_PYTHON_BINDINGS)

add_py_test(test_lattice.py)
add_py_test(test_dynamic_lattice.py)
add_py_test(test_parse_matrix.py)

endif()

add_executable(test_lattice test_lattice.cc)
target_link_libraries(test_lattice synergia_lattice synergia_test_main)
add_executable(test_lattice test_lattice.cc ${test_main})
target_link_libraries(test_lattice PRIVATE synergia_lattice ${testing_libs})
add_mpi_test(test_lattice 1)

add_executable(test_mx_expr test_mx_expr.cc)
target_link_libraries(test_mx_expr synergia_lattice synergia_test_main
${kokkos_libs})
add_executable(test_mx_expr test_mx_expr.cc ${test_main})
target_link_libraries(test_mx_expr PRIVATE synergia_lattice ${testing_libs})
add_mpi_test(test_mx_expr 1)

add_executable(test_madx_parser test_madx_parser.cc)
target_link_libraries(test_madx_parser synergia_lattice synergia_test_main
${kokkos_libs})
add_executable(test_madx_parser test_madx_parser.cc ${test_main})
target_link_libraries(test_madx_parser PRIVATE synergia_lattice ${testing_libs})
add_mpi_test(test_madx_parser 1)

copy_file(foo.dbx test_madx_parser)

add_executable(test_dynamic_lattice test_dynamic_lattice.cc)
target_link_libraries(test_dynamic_lattice synergia_lattice synergia_test_main)
add_executable(test_dynamic_lattice test_dynamic_lattice.cc ${test_main})
target_link_libraries(test_dynamic_lattice PRIVATE synergia_lattice
${testing_libs})
add_mpi_test(test_dynamic_lattice 1)
41 changes: 23 additions & 18 deletions src/synergia/lattice/tests/test_dynamic_lattice.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "synergia/utils/catch.hpp"

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>

#include "synergia/lattice/madx_reader.h"
#include "synergia/lattice/mx_parse.h"

#include "synergia/utils/cereal_files.h"


TEST_CASE("print")
{
#if 0
Expand Down Expand Up @@ -40,7 +40,7 @@ TEST_CASE("dynamic lattice")
d, at=0.6, from=c;
endsequence;
)";

MadX_reader reader;
reader.parse(str);

Expand All @@ -50,31 +50,36 @@ TEST_CASE("dynamic lattice")
auto& elms = lattice.get_elements();

// original a->k1
CHECK(lattice.get_elements().front().get_double_attribute("k1")
== Approx(2.0).margin(1e-12));
REQUIRE_THAT(lattice.get_elements().front().get_double_attribute("k1"),
Catch::Matchers::WithinAbs(2.0, 1e-12));

// set x
lattice.get_lattice_tree().set_variable("x", 3.0);

// a->k1 after setting a new x
CHECK(lattice.get_elements().front().get_double_attribute("k1")
== Approx(4.0).margin(1e-12));
REQUIRE_THAT(lattice.get_elements().front().get_double_attribute("k1"),
Catch::Matchers::WithinAbs(4.0, 1e-12));

// find element d
auto it = elms.end();
--it; --it;
--it;
--it;

// original values
CHECK(it->get_double_attribute("l") == Approx(0.8).margin(1e-12));
CHECK(it->get_double_attribute("k1") == Approx(1.0).margin(1e-12));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbsMatcher(0.8, 1e-12));
REQUIRE_THAT(it->get_double_attribute("k1"),
Catch::Matchers::WithinAbsMatcher(1.0, 1e-12));

// set o->l to 0.3
lattice.get_lattice_tree().set_element_attribute("o", "l", 0.3);
lattice.get_lattice_tree().print();

// updated values
CHECK(it->get_double_attribute("l") == Approx(1.2).margin(1e-12));
CHECK(it->get_double_attribute("k1") == Approx(1.5).margin(1e-12));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbsMatcher(1.2, 1e-12));
REQUIRE_THAT(it->get_double_attribute("k1"),
Catch::Matchers::WithinAbsMatcher(1.5, 1e-12));

// element c
--it;
Expand All @@ -86,8 +91,8 @@ TEST_CASE("dynamic lattice")
REQUIRE_NOTHROW(it->set_double_attribute("k1", "o->l*3"));

// check value 0.3*3 = 0.9
CHECK(it->get_double_attribute("k1") == Approx(0.9).margin(1e-12));

REQUIRE_THAT(it->get_double_attribute("k1"),
Catch::Matchers::WithinAbsMatcher(0.9, 1e-12));
}

TEST_CASE("serialization")
Expand All @@ -113,7 +118,7 @@ TEST_CASE("serialization")
d, at=0.6, from=c;
endsequence;
)";

MadX_reader reader;
reader.parse(str);

Expand All @@ -131,8 +136,8 @@ TEST_CASE("serialization")
json_load(lattice, "dyn_lattice.json");

lattice.get_lattice_tree().set_variable("x", 5.0);
CHECK(lattice.get_elements().front().get_double_attribute("k1")
== Approx(6.0).margin(1e-12));
REQUIRE_THAT(lattice.get_elements().front().get_double_attribute("k1"),
Catch::Matchers::WithinAbs(6.0, 1e-12));

std::cout << lattice.get_lattice_tree().mx.to_madx() << "\n";
std::cout << lattice.as_string() << "\n";
Expand Down
114 changes: 60 additions & 54 deletions src/synergia/lattice/tests/test_lattice.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "synergia/utils/catch.hpp"
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>

#include "catch2/matchers/catch_matchers.hpp"
#include "synergia/lattice/lattice.h"
#include "synergia/lattice/madx_reader.h"

Expand Down Expand Up @@ -61,26 +63,26 @@ TEST_CASE("append_fodo")
CHECK(it != lattice.get_elements().end());
CHECK(it->get_name() == "f");
CHECK(it->get_type_name() == "quadrupole");
CHECK(it->get_double_attribute("l") ==
Approx(quad_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(quad_length, tolerance));
++it;

CHECK(it->get_name() == "o");
CHECK(it->get_type_name() == "drift");
CHECK(it->get_double_attribute("l") ==
Approx(drift_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(drift_length, tolerance));
++it;

CHECK(it->get_name() == "d");
CHECK(it->get_type_name() == "quadrupole");
CHECK(it->get_double_attribute("l") ==
Approx(quad_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(quad_length, tolerance));
++it;

CHECK(it->get_name() == "o");
CHECK(it->get_type_name() == "drift");
CHECK(it->get_double_attribute("l") ==
Approx(drift_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(drift_length, tolerance));
}

#if 0
Expand Down Expand Up @@ -117,8 +119,9 @@ TEST_CASE("get_length")
lattice.append(d);
lattice.append(o);

CHECK(lattice.get_length() ==
Approx(2 * quad_length + 2 * drift_length).margin(tolerance));
REQUIRE_THAT(lattice.get_length(),
Catch::Matchers::WithinAbs(2 * quad_length + 2 * drift_length,
tolerance));
}

TEST_CASE("get_total_angle1")
Expand All @@ -136,7 +139,8 @@ TEST_CASE("get_total_angle1")
lattice.append(d);
lattice.append(o);

CHECK(lattice.get_total_angle() == Approx(0.0).margin(tolerance));
REQUIRE_THAT(lattice.get_total_angle(),
Catch::Matchers::WithinAbs(0.0, tolerance));
}

TEST_CASE("get_total_angle2")
Expand Down Expand Up @@ -167,7 +171,8 @@ TEST_CASE("get_total_angle2")
lattice.append(o);
}

CHECK(lattice.get_total_angle() == Approx(2 * pi).margin(tolerance));
REQUIRE_THAT(lattice.get_total_angle(),
Catch::Matchers::WithinAbs(2 * pi, tolerance));
}

TEST_CASE("lattice_from_lattice_element")
Expand Down Expand Up @@ -200,10 +205,10 @@ TEST_CASE("copy_lattice")
CHECK(lattice.get_elements().begin()->get_name() ==
copied_lattice.get_elements().begin()->get_name());

CHECK(lattice.get_elements().begin()->get_length() ==
Approx(foo_length).margin(tolerance));
CHECK(copied_lattice.get_elements().begin()->get_length() ==
Approx(foo_length).margin(tolerance));
REQUIRE_THAT(lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(foo_length, tolerance));
REQUIRE_THAT(copied_lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(foo_length, tolerance));

for (auto it = copied_lattice.get_elements().begin();
it != copied_lattice.get_elements().end();
Expand All @@ -216,18 +221,18 @@ TEST_CASE("copy_lattice")
copied_lattice.get_elements().begin()->set_double_attribute("l",
new_length);

CHECK(lattice.get_elements().begin()->get_length() ==
Approx(foo_length).margin(tolerance));
CHECK(copied_lattice.get_elements().begin()->get_length() ==
Approx(new_length).margin(tolerance));
REQUIRE_THAT(lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(foo_length, tolerance));
REQUIRE_THAT(copied_lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(new_length, tolerance));

const double new_energy = 2 * total_energy;
Reference_particle new_reference_particle(charge, mass, new_energy);
copied_lattice.set_reference_particle(new_reference_particle);
CHECK(lattice.get_reference_particle().get_total_energy() ==
Approx(total_energy).margin(tolerance));
CHECK(copied_lattice.get_reference_particle().get_total_energy() ==
Approx(new_energy).margin(tolerance));
REQUIRE_THAT(lattice.get_reference_particle().get_total_energy(),
Catch::Matchers::WithinAbs(total_energy, tolerance));
REQUIRE_THAT(copied_lattice.get_reference_particle().get_total_energy(),
Catch::Matchers::WithinAbs(new_energy, tolerance));
}

TEST_CASE("copy_lattice2")
Expand Down Expand Up @@ -258,10 +263,10 @@ TEST_CASE("copy_lattice_from_lattice_sptr")
CHECK(lattice.get_elements().begin()->get_name() ==
copied_lattice.get_elements().begin()->get_name());

CHECK(lattice.get_elements().begin()->get_length() ==
Approx(foo_length).margin(tolerance));
CHECK(copied_lattice.get_elements().begin()->get_length() ==
Approx(foo_length).margin(tolerance));
REQUIRE_THAT(lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(foo_length, tolerance));
REQUIRE_THAT(copied_lattice.get_elements().begin()->get_length(),
Catch::Matchers::WithinAbs(foo_length, tolerance));

for (auto it = copied_lattice.get_elements().begin();
it != copied_lattice.get_elements().end();
Expand All @@ -273,29 +278,30 @@ TEST_CASE("copy_lattice_from_lattice_sptr")

TEST_CASE("lattice_reference_particle")
{
Lattice lattice(name);
Reference_particle reference_particle(1, 1.0, 1.25);
lattice.set_reference_particle(reference_particle);
CHECK(lattice.get_reference_particle().get_mass() == 1.0);
CHECK(lattice.get_reference_particle().get_gamma() == 1.25);
CHECK(lattice.get_reference_particle().get_total_energy() == 1.25);
// set new energy, beta=13/84, gamma=85/84
lattice.get_reference_particle().set_total_energy(85.0/84.0);
CHECK(lattice.get_reference_particle().get_gamma() == Approx(85.0/84.0).margin(tolerance));
CHECK(lattice.get_reference_particle().get_beta() == Approx(13.0/85.0).margin(tolerance));
Lattice lattice(name);
Reference_particle reference_particle(1, 1.0, 1.25);
lattice.set_reference_particle(reference_particle);
CHECK(lattice.get_reference_particle().get_mass() == 1.0);
CHECK(lattice.get_reference_particle().get_gamma() == 1.25);
CHECK(lattice.get_reference_particle().get_total_energy() == 1.25);
// set new energy, beta=13/84, gamma=85/84
lattice.get_reference_particle().set_total_energy(85.0 / 84.0);
REQUIRE_THAT(lattice.get_reference_particle().get_gamma(),
Catch::Matchers::WithinAbs(85.0 / 84.0, tolerance));
REQUIRE_THAT(lattice.get_reference_particle().get_beta(),
Catch::Matchers::WithinAbs(13.0 / 85.0, tolerance));
}

TEST_CASE("lattice_energy")
{
Lattice lattice(name);
Reference_particle reference_particle(1, 1.0, 1.25);
lattice.set_reference_particle(reference_particle);
CHECK(lattice.get_lattice_energy() == 1.25);
lattice.set_lattice_energy(7.0);
CHECK(lattice.get_lattice_energy() == 7.0);
Lattice lattice(name);
Reference_particle reference_particle(1, 1.0, 1.25);
lattice.set_reference_particle(reference_particle);
CHECK(lattice.get_lattice_energy() == 1.25);
lattice.set_lattice_energy(7.0);
CHECK(lattice.get_lattice_energy() == 7.0);
}


TEST_CASE("test_lsexpr")
{
#if 0
Expand Down Expand Up @@ -363,26 +369,26 @@ TEST_CASE("test_serialize1")
CHECK(it != loaded.get_elements().end());
CHECK(it->get_name() == "f");
CHECK(it->get_type_name() == "quadrupole");
CHECK(it->get_double_attribute("l") ==
Approx(quad_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(quad_length, tolerance));
++it;

CHECK(it->get_name() == "o");
CHECK(it->get_type_name() == "drift");
CHECK(it->get_double_attribute("l") ==
Approx(drift_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(drift_length, tolerance));
++it;

CHECK(it->get_name() == "d");
CHECK(it->get_type_name() == "quadrupole");
CHECK(it->get_double_attribute("l") ==
Approx(quad_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(quad_length, tolerance));
++it;

CHECK(it->get_name() == "o");
CHECK(it->get_type_name() == "drift");
CHECK(it->get_double_attribute("l") ==
Approx(drift_length).margin(tolerance));
REQUIRE_THAT(it->get_double_attribute("l"),
Catch::Matchers::WithinAbs(drift_length, tolerance));
}

TEST_CASE("test_serialize2")
Expand Down
Loading

0 comments on commit 338b561

Please sign in to comment.