Skip to content

Commit

Permalink
Merge branch 'mahf708/eamxx/horiz-avg-diag' (PR #6788)
Browse files Browse the repository at this point in the history
Adds an online diagnostic field for area-weighted horizontal average, following #6776.
  • Loading branch information
bartgol authored Dec 12, 2024
2 parents 2a7c3fa + 82dc94d commit 8818556
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 5 deletions.
1 change: 1 addition & 0 deletions components/eamxx/src/diagnostics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(DIAGNOSTIC_SRCS
field_at_height.cpp
field_at_level.cpp
field_at_pressure_level.cpp
horiz_avg.cpp
longwave_cloud_forcing.cpp
number_path.cpp
potential_temperature.cpp
Expand Down
4 changes: 3 additions & 1 deletion components/eamxx/src/diagnostics/field_at_height.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ initialize_impl (const RunType /*run_type*/)
EKAT_REQUIRE_MSG (layout.rank()>=2 && layout.rank()<=3,
"Error! Field rank not supported by FieldAtHeight.\n"
" - field name: " + fid.name() + "\n"
" - field layout: " + layout.to_string() + "\n");
" - field layout: " + layout.to_string() + "\n"
"NOTE: if you requested something like 'field_horiz_avg_at_Y',\n"
" you can avoid this error by requesting 'fieldX_at_Y_horiz_avg' instead.\n");
const auto tag = layout.tags().back();
EKAT_REQUIRE_MSG (tag==LEV || tag==ILEV,
"Error! FieldAtHeight diagnostic expects a layout ending with 'LEV'/'ILEV' tag.\n"
Expand Down
6 changes: 4 additions & 2 deletions components/eamxx/src/diagnostics/field_at_level.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ initialize_impl (const RunType /*run_type*/)
using namespace ShortFieldTagsNames;
const auto& fid = f.get_header().get_identifier();
const auto& layout = fid.get_layout();
EKAT_REQUIRE_MSG (layout.rank()>1 && layout.rank()<=6,
EKAT_REQUIRE_MSG (layout.rank()>=2 && layout.rank()<=6,
"Error! Field rank not supported by FieldAtLevel.\n"
" - field name: " + fid.name() + "\n"
" - field layout: " + layout.to_string() + "\n");
" - field layout: " + layout.to_string() + "\n"
"NOTE: if you requested something like 'field_horiz_avg_at_Y',\n"
" you can avoid this error by requesting 'fieldX_at_Y_horiz_avg' instead.\n");
const auto tag = layout.tags().back();
EKAT_REQUIRE_MSG (tag==LEV || tag==ILEV,
"Error! FieldAtLevel diagnostic expects a layout ending with 'LEV'/'ILEV' tag.\n"
Expand Down
4 changes: 3 additions & 1 deletion components/eamxx/src/diagnostics/field_at_pressure_level.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ initialize_impl (const RunType /*run_type*/)
EKAT_REQUIRE_MSG (layout.rank()>=2 && layout.rank()<=3,
"Error! Field rank not supported by FieldAtPressureLevel.\n"
" - field name: " + fid.name() + "\n"
" - field layout: " + layout.to_string() + "\n");
" - field layout: " + layout.to_string() + "\n"
"NOTE: if you requested something like 'field_horiz_avg_at_Y',\n"
" you can avoid this error by requesting 'fieldX_at_Y_horiz_avg' instead.\n");
const auto tag = layout.tags().back();
EKAT_REQUIRE_MSG (tag==LEV || tag==ILEV,
"Error! FieldAtPressureLevel diagnostic expects a layout ending with 'LEV'/'ILEV' tag.\n"
Expand Down
65 changes: 65 additions & 0 deletions components/eamxx/src/diagnostics/horiz_avg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "diagnostics/horiz_avg.hpp"

#include "share/field/field_utils.hpp"

namespace scream {

HorizAvgDiag::HorizAvgDiag(const ekat::Comm &comm,
const ekat::ParameterList &params)
: AtmosphereDiagnostic(comm, params) {
const auto &fname = m_params.get<std::string>("field_name");
m_diag_name = fname + "_horiz_avg";
}

void HorizAvgDiag::set_grids(
const std::shared_ptr<const GridsManager> grids_manager) {
const auto &fn = m_params.get<std::string>("field_name");
const auto &gn = m_params.get<std::string>("grid_name");
const auto g = grids_manager->get_grid("Physics");

add_field<Required>(fn, gn);

// first clone the area unscaled, we will scale it later in initialize_impl
m_scaled_area = g->get_geometry_data("area").clone();
}

void HorizAvgDiag::initialize_impl(const RunType /*run_type*/) {
using namespace ShortFieldTagsNames;
const auto &f = get_fields_in().front();
const auto &fid = f.get_header().get_identifier();
const auto &layout = fid.get_layout();

EKAT_REQUIRE_MSG(layout.rank() >= 1 && layout.rank() <= 3,
"Error! Field rank not supported by HorizAvgDiag.\n"
" - field name: " +
fid.name() +
"\n"
" - field layout: " +
layout.to_string() + "\n");
EKAT_REQUIRE_MSG(layout.tags()[0] == COL,
"Error! HorizAvgDiag diagnostic expects a layout starting "
"with the 'COL' tag.\n"
" - field name : " +
fid.name() +
"\n"
" - field layout: " +
layout.to_string() + "\n");

FieldIdentifier d_fid(m_diag_name, layout.clone().strip_dim(COL),
fid.get_units(), fid.get_grid_name());
m_diagnostic_output = Field(d_fid);
m_diagnostic_output.allocate_view();

// scale the area field
auto total_area = field_sum<Real>(m_scaled_area, &m_comm);
m_scaled_area.scale(sp(1.0) / total_area);
}

void HorizAvgDiag::compute_diagnostic_impl() {
const auto &f = get_fields_in().front();
const auto &d = m_diagnostic_output;
// Call the horiz_contraction impl that will take care of everything
horiz_contraction<Real>(d, f, m_scaled_area, &m_comm);
}

} // namespace scream
43 changes: 43 additions & 0 deletions components/eamxx/src/diagnostics/horiz_avg.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef EAMXX_HORIZ_AVERAGE_HPP
#define EAMXX_HORIZ_AVERAGE_HPP

#include "share/atm_process/atmosphere_diagnostic.hpp"

namespace scream {

/*
* This diagnostic will calculate the area-weighted average of a field
* across the COL tag dimension, producing an N-1 dimensional field
* that is area-weighted average of the input field.
*/

class HorizAvgDiag : public AtmosphereDiagnostic {
public:
// Constructors
HorizAvgDiag(const ekat::Comm &comm, const ekat::ParameterList &params);

// The name of the diagnostic
std::string name() const { return m_diag_name; }

// Set the grid
void set_grids(const std::shared_ptr<const GridsManager> grids_manager);

protected:
#ifdef KOKKOS_ENABLE_CUDA
public:
#endif
void compute_diagnostic_impl();

protected:
void initialize_impl(const RunType /*run_type*/);

// Name of each field (because the diagnostic impl is generic)
std::string m_diag_name;

// Need area field, let's store it scaled by its norm
Field m_scaled_area;
};

} // namespace scream

#endif // EAMXX_HORIZ_AVERAGE_HPP
2 changes: 2 additions & 0 deletions components/eamxx/src/diagnostics/register_diagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "diagnostics/number_path.hpp"
#include "diagnostics/aerocom_cld.hpp"
#include "diagnostics/atm_backtend.hpp"
#include "diagnostics/horiz_avg.hpp"

namespace scream {

Expand Down Expand Up @@ -51,6 +52,7 @@ inline void register_diagnostics () {
diag_factory.register_product("NumberPath",&create_atmosphere_diagnostic<NumberPathDiagnostic>);
diag_factory.register_product("AeroComCld",&create_atmosphere_diagnostic<AeroComCld>);
diag_factory.register_product("AtmBackTendDiag",&create_atmosphere_diagnostic<AtmBackTendDiag>);
diag_factory.register_product("HorizAvgDiag",&create_atmosphere_diagnostic<HorizAvgDiag>);
}

} // namespace scream
Expand Down
3 changes: 3 additions & 0 deletions components/eamxx/src/diagnostics/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ CreateDiagTest(aerocom_cld "aerocom_cld_test.cpp")

# Test atm_tend
CreateDiagTest(atm_backtend "atm_backtend_test.cpp")

# Test horizontal averaging
CreateDiagTest(horiz_avg "horiz_avg_test.cpp")
170 changes: 170 additions & 0 deletions components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include "catch2/catch.hpp"
#include "diagnostics/register_diagnostics.hpp"
#include "share/field/field_utils.hpp"
#include "share/grid/mesh_free_grids_manager.hpp"
#include "share/util/scream_setup_random_test.hpp"
#include "share/util/scream_universal_constants.hpp"

namespace scream {

std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
const int nlevs) {
const int num_global_cols = ncols * comm.size();

using vos_t = std::vector<std::string>;
ekat::ParameterList gm_params;
gm_params.set("grids_names", vos_t{"Point Grid"});
auto &pl = gm_params.sublist("Point Grid");
pl.set<std::string>("type", "point_grid");
pl.set("aliases", vos_t{"Physics"});
pl.set<int>("number_of_global_columns", num_global_cols);
pl.set<int>("number_of_vertical_levels", nlevs);

auto gm = create_mesh_free_grids_manager(comm, gm_params);
gm->build_grids();

return gm;
}

TEST_CASE("horiz_avg") {
using namespace ShortFieldTagsNames;
using namespace ekat::units;
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
using TeamMember = typename TeamPolicy::member_type;
using KT = ekat::KokkosTypes<DefaultDevice>;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;

// A numerical tolerance
auto tol = std::numeric_limits<Real>::epsilon() * 100;

// A world comm
ekat::Comm comm(MPI_COMM_WORLD);

// A time stamp
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});

// Create a grids manager - single column for these tests
constexpr int nlevs = 3;
constexpr int dim3 = 4;
const int ngcols = 6 * comm.size();

auto gm = create_gm(comm, ngcols, nlevs);
auto grid = gm->get_grid("Physics");

// Input (randomized) qc
FieldLayout scalar1d_layout{{COL}, {ngcols}};
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};

FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name());
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name());

Field qc1(qc1_fid);
Field qc2(qc2_fid);
Field qc3(qc3_fid);

qc1.allocate_view();
qc2.allocate_view();
qc3.allocate_view();

// Construct random number generator stuff
using RPDF = std::uniform_real_distribution<Real>;
RPDF pdf(sp(0.0), sp(200.0));
auto engine = scream::setup_random_test();

// Construct the Diagnostics
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags;
auto &diag_factory = AtmosphereDiagnosticFactory::instance();
register_diagnostics();

ekat::ParameterList params;
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
params)); // No 'field_name' parameter

// Set time for qc and randomize its values
qc1.get_header().get_tracking().update_time_stamp(t0);
qc2.get_header().get_tracking().update_time_stamp(t0);
qc3.get_header().get_tracking().update_time_stamp(t0);
randomize(qc1, engine, pdf);
randomize(qc2, engine, pdf);
randomize(qc3, engine, pdf);

// Create and set up the diagnostic
params.set("grid_name", grid->name());
params.set<std::string>("field_name", "qc");
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params);
diag1->set_grids(gm);
diag2->set_grids(gm);
diag3->set_grids(gm);

// Clone the area field
auto area = grid->get_geometry_data("area").clone();

// Test the horiz contraction of qc1
// Get the diagnostic field
diag1->set_required_field(qc1);
diag1->initialize(t0, RunType::Initial);
diag1->compute_diagnostic();
auto diag1_f = diag1->get_diagnostic();

// Manual calculation
FieldIdentifier diag0_fid("qc_horiz_avg_manual",
scalar1d_layout.clone().strip_dim(COL), kg / kg,
grid->name());
Field diag0(diag0_fid);
diag0.allocate_view();

// calculate total area
Real atot = field_sum<Real>(area, &comm);
// scale the area field
area.scale(1 / atot);

// calculate weighted avg
horiz_contraction<Real>(diag0, qc1, area, &comm);
// Compare
REQUIRE(views_are_equal(diag1_f, diag0));

// Try other known cases
// Set qc1_v to 1.0 to get weighted average of 1.0
Real wavg = 1;
qc1.deep_copy(wavg);
diag1->compute_diagnostic();
auto diag1_v2_host = diag1_f.get_view<Real, Host>();
REQUIRE_THAT(diag1_v2_host(),
Catch::Matchers::WithinRel(
wavg, tol)); // Catch2's floating point comparison

// other diags
// Set qc2_v to 5.0 to get weighted average of 5.0
wavg = sp(5.0);
qc2.deep_copy(wavg);
diag2->set_required_field(qc2);
diag2->initialize(t0, RunType::Initial);
diag2->compute_diagnostic();
auto diag2_f = diag2->get_diagnostic();

auto diag2_v_host = diag2_f.get_view<Real *, Host>();

for(int i = 0; i < nlevs; ++i) {
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol));
}

// Try a random case with qc3
auto qc3_v = qc3.get_view<Real ***>();
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual",
scalar3d_layout.clone().strip_dim(COL),
kg / kg, grid->name());
Field diag3_manual(diag3_manual_fid);
diag3_manual.allocate_view();
horiz_contraction<Real>(diag3_manual, qc3, area, &comm);
diag3->set_required_field(qc3);
diag3->initialize(t0, RunType::Initial);
diag3->compute_diagnostic();
auto diag3_f = diag3->get_diagnostic();
REQUIRE(views_are_equal(diag3_f, diag3_manual));
}

} // namespace scream
11 changes: 10 additions & 1 deletion components/eamxx/src/share/io/scream_io_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ create_diagnostic (const std::string& diag_field_name,
std::regex backtend ("([A-Za-z0-9_]+)_atm_backtend$");
std::regex pot_temp ("(Liq)?PotentialTemperature$");
std::regex vert_layer ("(z|geopotential|height)_(mid|int)$");
std::regex horiz_avg ("([A-Za-z0-9_]+)_horiz_avg$");

std::string diag_name;
std::smatch matches;
Expand Down Expand Up @@ -191,7 +192,15 @@ create_diagnostic (const std::string& diag_field_name,
diag_name = "VerticalLayer";
params.set<std::string>("diag_name","dz");
params.set<std::string>("vert_location","mid");
} else {
}
else if (std::regex_search(diag_field_name,matches,horiz_avg)) {
diag_name = "HorizAvgDiag";
// Set the grid_name
params.set("grid_name",grid->name());
params.set<std::string>("field_name",matches[1].str());
}
else
{
// No existing special regex matches, so we assume that the diag field name IS the diag name.
diag_name = diag_field_name;
}
Expand Down

0 comments on commit 8818556

Please sign in to comment.