Skip to content

Commit

Permalink
EAMxx: improve horiz_avg testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mahf708 committed Dec 3, 2024
1 parent c243f18 commit c6fa3aa
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ TEST_CASE("horiz_avg") {
using TeamMember = typename TeamPolicy::member_type;
using KT = ekat::KokkosTypes<DefaultDevice>;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;

// A numerical tolerance
auto tol = std::numeric_limits<Real>::epsilon() * 100;

// A world comm
ekat::Comm comm(MPI_COMM_WORLD);

Expand All @@ -44,19 +48,17 @@ TEST_CASE("horiz_avg") {
constexpr int dim3 = 4;
const int ngcols = 6 * comm.size();

auto gm1 = create_gm(comm, ngcols, 1);
auto gm2 = create_gm(comm, ngcols, nlevs);
auto grid1 = gm1->get_grid("Physics");
auto grid2 = gm2->get_grid("Physics");
auto gm = create_gm(comm, ngcols, nlevs);
auto grid = gm->get_grid("Physics");

// Input (randomized) qc
FieldLayout scalar1d_layout{{COL}, {ngcols}};
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};

FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid1->name());
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid2->name());
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid2->name());
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name());
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name());

Field qc1(qc1_fid);
Field qc2(qc2_fid);
Expand All @@ -78,8 +80,8 @@ TEST_CASE("horiz_avg") {
register_diagnostics();

ekat::ParameterList params;
// REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
// params)); // No 'field_name' parameter
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
params)); // No 'field_name' parameter

// Set time for qc and randomize its values
qc1.get_header().get_tracking().update_time_stamp(t0);
Expand All @@ -90,16 +92,16 @@ TEST_CASE("horiz_avg") {
randomize(qc3, engine, pdf);

// Create and set up the diagnostic
params.set("grid_name", grid1->name());
params.set("grid_name", grid->name());
params.set<std::string>("field_name", "qc");
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params);
diag1->set_grids(gm1);
diag2->set_grids(gm2);
diag3->set_grids(gm2);
diag1->set_grids(gm);
diag2->set_grids(gm);
diag3->set_grids(gm);

auto area = grid1->get_geometry_data("area");
auto area = grid->get_geometry_data("area");

diag1->set_required_field(qc1);
diag1->initialize(t0, RunType::Initial);
Expand All @@ -109,7 +111,7 @@ TEST_CASE("horiz_avg") {

FieldIdentifier diag0_fid("qc_horiz_avg_manual",
scalar1d_layout.clone().strip_dim(COL), kg / kg,
grid1->name());
grid->name());
Field diag0(diag0_fid);
diag0.allocate_view();
auto diag0_v = diag0.get_view<Real>();
Expand Down Expand Up @@ -139,7 +141,9 @@ TEST_CASE("horiz_avg") {
Kokkos::deep_copy(qc1_v, wavg);
diag1->compute_diagnostic();
auto diag1_v2_host = diag1_f.get_view<Real, Host>();
REQUIRE(std::abs(diag1_v2_host() - wavg) < sp(1e-6));
REQUIRE_THAT(diag1_v2_host(),
Catch::Matchers::WithinRel(
wavg, tol)); // Catch2's floating point comparison

// other diags
// Set qc2_v to 5.0 to get weighted average of 5.0
Expand All @@ -155,13 +159,13 @@ TEST_CASE("horiz_avg") {
auto diag2_v_host = diag2_f.get_view<Real *, Host>();

for(int i = 0; i < nlevs; ++i) {
REQUIRE(std::abs(diag2_v_host(i) - wavg) < sp(1e-6));
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol));
}

auto qc3_v = qc3.get_view<Real ***>();
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual",
scalar3d_layout.clone().strip_dim(COL),
kg / kg, grid2->name());
kg / kg, grid->name());
Field diag3_manual(diag3_manual_fid);
diag3_manual.allocate_view();
auto diag3_manual_v = diag3_manual.get_view<Real **>();
Expand Down

0 comments on commit c6fa3aa

Please sign in to comment.