-
Notifications
You must be signed in to change notification settings - Fork 372
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
EAMxx: add horizontal average diagnostic field
- Loading branch information
Showing
7 changed files
with
313 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#include "diagnostics/horiz_avg.hpp" | ||
|
||
#include "share/field/field_utils.hpp" | ||
|
||
namespace scream { | ||
|
||
HorizAvgDiag::HorizAvgDiag(const ekat::Comm &comm, | ||
const ekat::ParameterList ¶ms) | ||
: AtmosphereDiagnostic(comm, params) { | ||
const auto &fname = m_params.get<std::string>("field_name"); | ||
m_diag_name = fname + "_horiz_avg"; | ||
} | ||
|
||
void HorizAvgDiag::set_grids( | ||
const std::shared_ptr<const GridsManager> grids_manager) { | ||
const auto &fn = m_params.get<std::string>("field_name"); | ||
const auto &gn = m_params.get<std::string>("grid_name"); | ||
const auto g = grids_manager->get_grid("Physics"); | ||
|
||
add_field<Required>(fn, gn); | ||
|
||
// first clone the area unscaled, we will scale it later in initialize_impl | ||
m_scaled_area = g->get_geometry_data("area").clone(); | ||
} | ||
|
||
void HorizAvgDiag::initialize_impl(const RunType /*run_type*/) { | ||
using namespace ShortFieldTagsNames; | ||
const auto &f = get_fields_in().front(); | ||
const auto &fid = f.get_header().get_identifier(); | ||
const auto &layout = fid.get_layout(); | ||
|
||
EKAT_REQUIRE_MSG(layout.rank() >= 1 && layout.rank() <= 3, | ||
"Error! Field rank not supported by HorizAvgDiag.\n" | ||
" - field name: " + | ||
fid.name() + | ||
"\n" | ||
" - field layout: " + | ||
layout.to_string() + "\n"); | ||
EKAT_REQUIRE_MSG(layout.tags()[0] == COL, | ||
"Error! HorizAvgDiag diagnostic expects a layout starting " | ||
"with the 'COL' tag.\n" | ||
" - field name : " + | ||
fid.name() + | ||
"\n" | ||
" - field layout: " + | ||
layout.to_string() + "\n"); | ||
|
||
FieldIdentifier d_fid(m_diag_name, layout.clone().strip_dim(COL), | ||
fid.get_units(), fid.get_grid_name()); | ||
m_diagnostic_output = Field(d_fid); | ||
m_diagnostic_output.allocate_view(); | ||
|
||
// scale the area field | ||
auto total_area = field_sum<Real>(m_scaled_area, &m_comm); | ||
m_scaled_area.scale(sp(1.0) / total_area); | ||
} | ||
|
||
void HorizAvgDiag::compute_diagnostic_impl() { | ||
const auto &f = get_fields_in().front(); | ||
const auto &d = m_diagnostic_output; | ||
// Call the horiz_contraction impl that will take care of everything | ||
horiz_contraction<Real>(d, f, m_scaled_area, &m_comm); | ||
} | ||
|
||
} // namespace scream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#ifndef EAMXX_HORIZ_AVERAGE_HPP | ||
#define EAMXX_HORIZ_AVERAGE_HPP | ||
|
||
#include "share/atm_process/atmosphere_diagnostic.hpp" | ||
|
||
namespace scream { | ||
|
||
/* | ||
* This diagnostic will calculate the area-weighted average of a field | ||
* across the COL tag dimension, producing an N-1 dimensional field | ||
* that is area-weighted average of the input field. | ||
*/ | ||
|
||
class HorizAvgDiag : public AtmosphereDiagnostic { | ||
public: | ||
// Constructors | ||
HorizAvgDiag(const ekat::Comm &comm, const ekat::ParameterList ¶ms); | ||
|
||
// The name of the diagnostic | ||
std::string name() const { return m_diag_name; } | ||
|
||
// Set the grid | ||
void set_grids(const std::shared_ptr<const GridsManager> grids_manager); | ||
|
||
protected: | ||
#ifdef KOKKOS_ENABLE_CUDA | ||
public: | ||
#endif | ||
void compute_diagnostic_impl(); | ||
|
||
protected: | ||
void initialize_impl(const RunType /*run_type*/); | ||
|
||
// Name of each field (because the diagnostic impl is generic) | ||
std::string m_diag_name; | ||
|
||
// Need area field, let's store it scaled by its norm | ||
Field m_scaled_area; | ||
}; | ||
|
||
} // namespace scream | ||
|
||
#endif // EAMXX_HORIZ_AVERAGE_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
193 changes: 193 additions & 0 deletions
193
components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
#include "catch2/catch.hpp" | ||
#include "diagnostics/register_diagnostics.hpp" | ||
#include "share/field/field_utils.hpp" | ||
#include "share/grid/mesh_free_grids_manager.hpp" | ||
#include "share/util/scream_setup_random_test.hpp" | ||
#include "share/util/scream_universal_constants.hpp" | ||
|
||
namespace scream { | ||
|
||
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols, | ||
const int nlevs) { | ||
const int num_global_cols = ncols * comm.size(); | ||
|
||
using vos_t = std::vector<std::string>; | ||
ekat::ParameterList gm_params; | ||
gm_params.set("grids_names", vos_t{"Point Grid"}); | ||
auto &pl = gm_params.sublist("Point Grid"); | ||
pl.set<std::string>("type", "point_grid"); | ||
pl.set("aliases", vos_t{"Physics"}); | ||
pl.set<int>("number_of_global_columns", num_global_cols); | ||
pl.set<int>("number_of_vertical_levels", nlevs); | ||
|
||
auto gm = create_mesh_free_grids_manager(comm, gm_params); | ||
gm->build_grids(); | ||
|
||
return gm; | ||
} | ||
|
||
TEST_CASE("horiz_avg") { | ||
using namespace ShortFieldTagsNames; | ||
using namespace ekat::units; | ||
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>; | ||
using TeamMember = typename TeamPolicy::member_type; | ||
using KT = ekat::KokkosTypes<DefaultDevice>; | ||
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>; | ||
// A world comm | ||
ekat::Comm comm(MPI_COMM_WORLD); | ||
|
||
// A time stamp | ||
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0}); | ||
|
||
// Create a grids manager - single column for these tests | ||
constexpr int nlevs = 3; | ||
constexpr int dim3 = 4; | ||
const int ngcols = 6 * comm.size(); | ||
|
||
auto gm1 = create_gm(comm, ngcols, 1); | ||
auto gm2 = create_gm(comm, ngcols, nlevs); | ||
auto grid1 = gm1->get_grid("Physics"); | ||
auto grid2 = gm2->get_grid("Physics"); | ||
|
||
// Input (randomized) qc | ||
FieldLayout scalar1d_layout{{COL}, {ngcols}}; | ||
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}}; | ||
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}}; | ||
|
||
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid1->name()); | ||
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid2->name()); | ||
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid2->name()); | ||
|
||
Field qc1(qc1_fid); | ||
Field qc2(qc2_fid); | ||
Field qc3(qc3_fid); | ||
|
||
qc1.allocate_view(); | ||
qc2.allocate_view(); | ||
qc3.allocate_view(); | ||
|
||
// Construct random number generator stuff | ||
using RPDF = std::uniform_real_distribution<Real>; | ||
RPDF pdf(sp(0.0), sp(200.0)); | ||
|
||
auto engine = scream::setup_random_test(); | ||
|
||
// Construct the Diagnostics | ||
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags; | ||
auto &diag_factory = AtmosphereDiagnosticFactory::instance(); | ||
register_diagnostics(); | ||
|
||
ekat::ParameterList params; | ||
// REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm, | ||
// params)); // No 'field_name' parameter | ||
|
||
// Set time for qc and randomize its values | ||
qc1.get_header().get_tracking().update_time_stamp(t0); | ||
qc2.get_header().get_tracking().update_time_stamp(t0); | ||
qc3.get_header().get_tracking().update_time_stamp(t0); | ||
randomize(qc1, engine, pdf); | ||
randomize(qc2, engine, pdf); | ||
randomize(qc3, engine, pdf); | ||
|
||
// Create and set up the diagnostic | ||
params.set("grid_name", grid1->name()); | ||
params.set<std::string>("field_name", "qc"); | ||
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params); | ||
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params); | ||
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params); | ||
diag1->set_grids(gm1); | ||
diag2->set_grids(gm2); | ||
diag3->set_grids(gm2); | ||
|
||
auto area = grid1->get_geometry_data("area"); | ||
|
||
diag1->set_required_field(qc1); | ||
diag1->initialize(t0, RunType::Initial); | ||
|
||
diag1->compute_diagnostic(); | ||
auto diag1_f = diag1->get_diagnostic(); | ||
|
||
FieldIdentifier diag0_fid("qc_horiz_avg_manual", | ||
scalar1d_layout.clone().strip_dim(COL), kg / kg, | ||
grid1->name()); | ||
Field diag0(diag0_fid); | ||
diag0.allocate_view(); | ||
auto diag0_v = diag0.get_view<Real>(); | ||
|
||
auto qc1_v = qc1.get_view<Real *>(); | ||
auto area_v = area.get_view<const Real *>(); | ||
|
||
// calculate total area | ||
Real atot = field_sum<Real>(area, &comm); | ||
// calculate weighted avg | ||
Real wavg = sp(0.0); | ||
Kokkos::parallel_reduce( | ||
"HorizAvgDiag::compute_diagnostic_impl::weighted_sum", ngcols, | ||
KOKKOS_LAMBDA(const int icol, Real &local_wavg) { | ||
local_wavg += (area_v[icol] / atot) * qc1_v[icol]; | ||
}, | ||
wavg); | ||
Kokkos::deep_copy(diag0_v, wavg); | ||
|
||
diag1_f.sync_to_host(); | ||
auto diag1_v_h = diag1_f.get_view<Real, Host>(); | ||
REQUIRE(diag1_v_h() == wavg); | ||
|
||
// Try known cases | ||
// Set qc1_v to 1.0 to get weighted average of 1.0 | ||
wavg = sp(1.0); | ||
Kokkos::deep_copy(qc1_v, wavg); | ||
diag1->compute_diagnostic(); | ||
auto diag1_v2_host = diag1_f.get_view<Real, Host>(); | ||
REQUIRE(std::abs(diag1_v2_host() - wavg) < sp(1e-6)); | ||
|
||
// other diags | ||
// Set qc2_v to 5.0 to get weighted average of 5.0 | ||
wavg = sp(5.0); | ||
auto qc2_v = qc2.get_view<Real **>(); | ||
Kokkos::deep_copy(qc2_v, wavg); | ||
|
||
diag2->set_required_field(qc2); | ||
diag2->initialize(t0, RunType::Initial); | ||
diag2->compute_diagnostic(); | ||
auto diag2_f = diag2->get_diagnostic(); | ||
|
||
auto diag2_v_host = diag2_f.get_view<Real *, Host>(); | ||
|
||
for(int i = 0; i < nlevs; ++i) { | ||
REQUIRE(std::abs(diag2_v_host(i) - wavg) < sp(1e-6)); | ||
} | ||
|
||
auto qc3_v = qc3.get_view<Real ***>(); | ||
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual", | ||
scalar3d_layout.clone().strip_dim(COL), | ||
kg / kg, grid2->name()); | ||
Field diag3_manual(diag3_manual_fid); | ||
diag3_manual.allocate_view(); | ||
auto diag3_manual_v = diag3_manual.get_view<Real **>(); | ||
// calculate diag3_manual by hand | ||
auto p = ESU::get_default_team_policy(dim3 * nlevs, ngcols); | ||
Kokkos::parallel_for( | ||
"HorizAvgDiag::compute_diagnostic_impl::manual_diag3", p, | ||
KOKKOS_LAMBDA(const TeamMember &m) { | ||
const int idx = m.league_rank(); | ||
const int j = idx / nlevs; | ||
const int k = idx % nlevs; | ||
Real sum = sp(0.0); | ||
Kokkos::parallel_reduce( | ||
Kokkos::TeamThreadRange(m, ngcols), | ||
[&](const int icol, Real &accum) { | ||
accum += (area_v(icol) / atot) * qc3_v(icol, j, k); | ||
}, | ||
sum); | ||
Kokkos::single(Kokkos::PerTeam(m), | ||
[&]() { diag3_manual_v(j, k) = sum; }); | ||
}); | ||
diag3->set_required_field(qc3); | ||
diag3->initialize(t0, RunType::Initial); | ||
diag3->compute_diagnostic(); | ||
auto diag3_f = diag3->get_diagnostic(); | ||
REQUIRE(views_are_equal(diag3_f, diag3_manual)); | ||
} | ||
|
||
} // namespace scream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters