diff --git a/external/report/report/__init__.py b/external/report/report/__init__.py
index 0d51642ac7..2a0760086e 100644
--- a/external/report/report/__init__.py
+++ b/external/report/report/__init__.py
@@ -1,3 +1,10 @@
-from .create_report import create_html, insert_report_figure, Metrics, Metadata, Link
+from .create_report import (
+ create_html,
+ insert_report_figure,
+ Metrics,
+ Metadata,
+ Link,
+ OrderedList,
+)
__version__ = "0.1.0"
diff --git a/external/report/report/create_report.py b/external/report/report/create_report.py
index 28f3ca378d..a7b332934a 100644
--- a/external/report/report/create_report.py
+++ b/external/report/report/create_report.py
@@ -1,6 +1,6 @@
import datetime
import os
-from typing import Mapping, Sequence, Union
+from typing import Any, Mapping, Sequence, Union
from jinja2 import Template
from pytz import timezone
@@ -90,6 +90,15 @@ def __repr__(self) -> str:
return f'{self.tag}'
+class OrderedList:
+ def __init__(self, *items: Any):
+ self.items = items
+
+ def __repr__(self) -> str:
+ items_li = [f"
{item}" for item in self.items]
+ return "\n" + "\n".join(items_li) + "\n
"
+
+
def resolve_plot(obj):
if isinstance(obj, str):
return ImagePlot(obj)
diff --git a/external/report/tests/test_report.py b/external/report/tests/test_report.py
index 22bd22cb6b..a5131478b8 100644
--- a/external/report/tests/test_report.py
+++ b/external/report/tests/test_report.py
@@ -1,5 +1,5 @@
import os
-from report import __version__, create_html
+from report import __version__, create_html, OrderedList
from report.create_report import _save_figure, insert_report_figure
@@ -46,3 +46,8 @@ def test__save_figure(tmpdir):
with open(os.path.join(output_dir, filepath_relative_to_report), "r") as f:
saved_data = f.read()
assert saved_data.replace("\n", "") == fig.content
+
+
+def test_OrderedList_repr():
+ result = str(OrderedList("item1", "item2"))
+ assert result == "\n- item1
\n- item2
\n
"
diff --git a/external/vcm/tests/test_calc.py b/external/vcm/tests/test_calc.py
index 49396e0dab..2222ccf183 100644
--- a/external/vcm/tests/test_calc.py
+++ b/external/vcm/tests/test_calc.py
@@ -12,6 +12,7 @@
)
from vcm.calc.calc import local_time, apparent_source
from vcm.cubedsphere.constants import COORD_Z_CENTER, COORD_Z_OUTER
+from vcm.calc.histogram import histogram
@pytest.mark.parametrize("toa_pressure", [0, 5])
@@ -119,3 +120,15 @@ def test_apparent_source():
s_dim="forecast_time",
)
assert Q1_forecast3 == pytest.approx((2.0 / (15 * 60)) - (4.0 / 60))
+
+
+def test_histogram():
+ data = xr.DataArray(
+ np.reshape(np.arange(0, 40, 2), (5, 4)), dims=["x", "y"], name="temperature"
+ )
+ coords = {"temperature_bins": [0, 30]}
+ expected_count = xr.DataArray([15, 5], coords=coords, dims="temperature_bins")
+ expected_width = xr.DataArray([30, 10], coords=coords, dims="temperature_bins")
+ count, width = histogram(data, bins=[0, 30, 40])
+ xr.testing.assert_equal(count, expected_count)
+ xr.testing.assert_equal(width, expected_width)
diff --git a/external/vcm/vcm/__init__.py b/external/vcm/vcm/__init__.py
index f8ba2af1f2..dba363489d 100644
--- a/external/vcm/vcm/__init__.py
+++ b/external/vcm/vcm/__init__.py
@@ -29,6 +29,7 @@
column_integrated_heating_from_isobaric_transition,
column_integrated_heating_from_isochoric_transition,
)
+from .calc.histogram import histogram
from .interpolate import (
interpolate_to_pressure_levels,
diff --git a/external/vcm/vcm/calc/histogram.py b/external/vcm/vcm/calc/histogram.py
new file mode 100644
index 0000000000..0cc34ad05a
--- /dev/null
+++ b/external/vcm/vcm/calc/histogram.py
@@ -0,0 +1,28 @@
+from typing import Any, Hashable, Mapping, Tuple
+
+import numpy as np
+import xarray as xr
+
+
+def histogram(da: xr.DataArray, **kwargs) -> Tuple[xr.DataArray, xr.DataArray]:
+ """Compute histogram and return tuple of counts and bin widths.
+
+ Args:
+ da: input data
+ kwargs: optional parameters to pass on to np.histogram
+
+ Return:
+ counts, bin_widths tuple of xr.DataArrays. The coordinate of both arrays is
+ equal to the left side of the histogram bins.
+ """
+ coord_name = f"{da.name}_bins" if da.name is not None else "bins"
+ count, bins = np.histogram(da, **kwargs)
+ coords: Mapping[Hashable, Any] = {coord_name: bins[:-1]}
+ width = bins[1:] - bins[:-1]
+ width_da = xr.DataArray(width, coords=coords, dims=[coord_name])
+ count_da = xr.DataArray(count, coords=coords, dims=[coord_name])
+ if "units" in da.attrs:
+ count_da[coord_name].attrs["units"] = da.units
+ width_da[coord_name].attrs["units"] = da.units
+ width_da.attrs["units"] = da.units
+ return count_da, width_da
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/compute.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/compute.py
index 1ed61a4338..7ac85edb88 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/compute.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/compute.py
@@ -12,15 +12,12 @@
grouping contains outputs from the physics routines (`sfc_dt_atmos.tile*.nc` and
`diags.zarr`).
"""
-import os
import sys
import datetime
-import tempfile
import intake
import numpy as np
import xarray as xr
-import shutil
from dask.diagnostics import ProgressBar
import fsspec
@@ -37,6 +34,7 @@
from fv3net.diagnostics.prognostic_run import diurnal_cycle
from fv3net.diagnostics.prognostic_run import transform
from fv3net.diagnostics.prognostic_run.constants import (
+ HISTOGRAM_BINS,
HORIZONTAL_DIMS,
DiagArg,
GLOBAL_AVERAGE_DYCORE_VARS,
@@ -229,16 +227,6 @@ def _assign_diagnostic_time_attrs(
return diagnostics_ds
-def dump_nc(ds: xr.Dataset, f):
- # to_netcdf closes file, which will delete the buffer
- # need to use a buffer since seek doesn't work with GCSFS file objects
- with tempfile.TemporaryDirectory() as dirname:
- url = os.path.join(dirname, "tmp.nc")
- ds.to_netcdf(url, engine="h5netcdf")
- with open(url, "rb") as tmp1:
- shutil.copyfileobj(tmp1, f)
-
-
@add_to_diags("dycore")
@diag_finalizer("rms_global")
@transform.apply("resample_time", "3H", inner_join=True)
@@ -503,6 +491,22 @@ def _diurnal_func(
return _assign_diagnostic_time_attrs(diag, prognostic)
+@add_to_diags("physics")
+@diag_finalizer("histogram")
+@transform.apply("resample_time", "3H", inner_join=True, method="mean")
+@transform.apply("subset_variables", list(HISTOGRAM_BINS.keys()))
+def compute_histogram(prognostic, verification, grid):
+ logger.info("Computing histograms for physics diagnostics")
+ counts = xr.Dataset()
+ for varname in prognostic.data_vars:
+ count, width = vcm.histogram(
+ prognostic[varname], bins=HISTOGRAM_BINS[varname], density=True
+ )
+ counts[varname] = count
+ counts[f"{varname}_bin_width"] = width
+ return _assign_diagnostic_time_attrs(counts, prognostic)
+
+
def register_parser(subparsers):
parser = subparsers.add_parser("save", help="Compute the prognostic run diags.")
parser.add_argument("url", help="Prognostic run output location.")
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/computed_diagnostics.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/computed_diagnostics.py
index a3e0c3cefd..cae780bc7b 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/computed_diagnostics.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/computed_diagnostics.py
@@ -274,11 +274,6 @@ def detect_folders(
bucket: str, fs: fsspec.AbstractFileSystem,
) -> Mapping[str, DiagnosticFolder]:
diag_ncs = fs.glob(os.path.join(bucket, "*", "diags.nc"))
- if len(diag_ncs) < 2:
- raise ValueError(
- "Plots require more than 1 diagnostic directory in"
- f" {bucket} for holoviews plots to display correctly."
- )
return {
Path(url).parent.name: DiagnosticFolder(fs, Path(url).parent.as_posix())
for url in diag_ncs
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/constants.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/constants.py
index 45a6471432..8f4e207c33 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/constants.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/constants.py
@@ -1,3 +1,4 @@
+import numpy as np
import xarray as xr
from typing import Tuple
@@ -135,3 +136,7 @@
"dQu",
"dQv",
]
+
+PRECIP_RATE = "total_precip_to_surface"
+HISTOGRAM_BINS = {PRECIP_RATE: np.logspace(-1, np.log10(500), 101)}
+PERCENTILES = [25, 50, 75, 90, 99, 99.9]
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/metrics.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/metrics.py
index a5b50202d6..dc7e8d638a 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/metrics.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/metrics.py
@@ -13,7 +13,7 @@
import numpy as np
import xarray as xr
from toolz import curry
-from .constants import HORIZONTAL_DIMS
+from .constants import HORIZONTAL_DIMS, PERCENTILES
import json
_METRICS = []
@@ -138,6 +138,49 @@ def rmse_time_mean(diags):
return rms_of_time_mean_bias
+for percentile in PERCENTILES:
+
+ @add_to_metrics(f"percentile_{percentile}")
+ def percentile_metric(diags, percentile=percentile):
+ histogram = grab_diag(diags, "histogram")
+ percentiles = xr.Dataset()
+ data_vars = [v for v in histogram.data_vars if not v.endswith("bin_width")]
+ for varname in data_vars:
+ percentiles[varname] = compute_percentile(
+ percentile,
+ histogram[varname].values,
+ histogram[f"{varname}_bins"].values,
+ histogram[f"{varname}_bin_width"].values,
+ )
+ restore_units(histogram, percentiles)
+ return percentiles
+
+
+def compute_percentile(
+ percentile: float, freq: np.ndarray, bins: np.ndarray, bin_widths: np.ndarray
+) -> float:
+ """Compute percentile given normalized histogram.
+
+ Args:
+ percentile: value between 0 and 100
+ freq: array of frequencies normalized by bin widths
+ bins: values of left sides of bins
+ bin_widths: values of bin widths
+
+ Returns:
+ value of distribution at percentile
+ """
+ cumulative_distribution = np.cumsum(freq * bin_widths)
+ if np.abs(cumulative_distribution[-1] - 1) > 1e-6:
+ raise ValueError(
+ "The provided frequencies do not integrate to one. "
+ "Ensure that histogram is computed with density=True."
+ )
+ bin_midpoints = bins + 0.5 * bin_widths
+ closest_index = np.argmin(np.abs(cumulative_distribution - percentile / 100))
+ return bin_midpoints[closest_index]
+
+
def restore_units(source, target):
for variable in target:
target[variable].attrs["units"] = source[variable].attrs["units"]
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/transform.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/transform.py
index 1b1ea4c343..62ddc09f78 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/transform.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/transform.py
@@ -89,7 +89,11 @@ def transform(*diag_args):
@add_to_input_transform_fns
def resample_time(
- freq_label: str, arg: DiagArg, time_slice=slice(None, -1), inner_join: bool = False
+ freq_label: str,
+ arg: DiagArg,
+ time_slice=slice(None, -1),
+ inner_join: bool = False,
+ method: str = "nearest",
) -> DiagArg:
"""
Subset times in prognostic and verification data
@@ -102,10 +106,11 @@ def resample_time(
time by default to work with crashed simulations.
inner_join: Subset times to the intersection of prognostic and verification
data. Defaults to False.
+ method: how to do resampling. Can be "nearest" or "mean".
"""
prognostic, verification, grid = arg
- prognostic = _downsample_only(prognostic, freq_label)
- verification = _downsample_only(verification, freq_label)
+ prognostic = _downsample_only(prognostic, freq_label, method)
+ verification = _downsample_only(verification, freq_label, method)
prognostic = prognostic.isel(time=time_slice)
if inner_join:
@@ -113,12 +118,18 @@ def resample_time(
return prognostic, verification, grid
-def _downsample_only(ds: xr.Dataset, freq_label: str) -> xr.Dataset:
+def _downsample_only(ds: xr.Dataset, freq_label: str, method: str) -> xr.Dataset:
"""Resample in time, only if given freq_label is lower frequency than time
sampling of given dataset ds"""
ds_freq = ds.time.values[1] - ds.time.values[0]
if ds_freq < pd.to_timedelta(freq_label):
- return ds.resample(time=freq_label, label="right").nearest()
+ resampled = ds.resample(time=freq_label, label="right")
+ if method == "nearest":
+ return resampled.nearest()
+ elif method == "mean":
+ return resampled.mean()
+ else:
+ raise ValueError(f"Don't know how to resample with method={method}.")
else:
return ds
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/matplotlib.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/matplotlib.py
index 1ac3cc3fe4..fd57255ff1 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/matplotlib.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/matplotlib.py
@@ -30,9 +30,9 @@
}
-def fig_to_b64(fig, format="png"):
+def fig_to_b64(fig, format="png", dpi=None):
pic_IObytes = io.BytesIO()
- fig.savefig(pic_IObytes, format=format, bbox_inches="tight")
+ fig.savefig(pic_IObytes, format=format, bbox_inches="tight", dpi=dpi)
pic_IObytes.seek(0)
pic_hash = base64.b64encode(pic_IObytes.read())
return f"data:image/png;base64, " + pic_hash.decode()
@@ -179,6 +179,27 @@ def plot_cubed_sphere_map(
)
+def plot_histogram(run_diags: RunDiagnostics, varname: str) -> raw_html:
+ """Plot 1D histogram of varname overlaid across runs."""
+
+ logging.info(f"plotting {varname}")
+ fig, ax = plt.subplots()
+ bin_name = varname.replace("histogram", "bins")
+ for run in run_diags.runs:
+ v = run_diags.get_variable(run, varname)
+ ax.step(v[bin_name], v, label=run, where="post", linewidth=1)
+ ax.set_xlabel(f"{v.long_name} [{v.units}]")
+ ax.set_ylabel(f"Frequency [({v.units})^-1]")
+ ax.set_xscale("log")
+ ax.set_yscale("log")
+ ax.set_xlim([v[bin_name].values[0], v[bin_name].values[-1]])
+ ax.legend()
+ fig.tight_layout()
+ data = fig_to_b64(fig, dpi=150)
+ plt.close(fig)
+ return raw_html(f'')
+
+
def _render_map_title(
metrics: RunMetrics, variable: str, run: str, metrics_for_title: Mapping[str, str],
) -> str:
diff --git a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/static_report.py b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/static_report.py
index 9696502557..c5046519eb 100644
--- a/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/static_report.py
+++ b/workflows/prognostic_run_diags/fv3net/diagnostics/prognostic_run/views/static_report.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
-from typing import Iterable
+from typing import Iterable, Mapping, Sequence
import os
import xarray as xr
import fsspec
@@ -13,9 +13,15 @@
RunMetrics,
)
-from report import create_html, Link
+from report import create_html, Link, OrderedList
from report.holoviews import HVPlot, get_html_header
-from .matplotlib import plot_2d_matplotlib, plot_cubed_sphere_map, raw_html
+from .matplotlib import (
+ plot_2d_matplotlib,
+ plot_cubed_sphere_map,
+ raw_html,
+ plot_histogram,
+)
+from ..constants import PERCENTILES, PRECIP_RATE
import logging
@@ -75,9 +81,7 @@ def make_plots(self, data) -> Iterable:
yield func(data)
-def plot_1d(
- run_diags: RunDiagnostics, varfilter: str, run_attr_name: str = "run",
-) -> HVPlot:
+def plot_1d(run_diags: RunDiagnostics, varfilter: str) -> HVPlot:
"""Plot all diagnostics whose name includes varfilter. Plot is overlaid across runs.
All matching diagnostics must be 1D."""
p = hv.Cycle("Colorblind")
@@ -95,10 +99,7 @@ def plot_1d(
def plot_1d_min_max_with_region_bar(
- run_diags: RunDiagnostics,
- varfilter_min: str,
- varfilter_max: str,
- run_attr_name: str = "run",
+ run_diags: RunDiagnostics, varfilter_min: str, varfilter_max: str,
) -> HVPlot:
"""Plot all diagnostics whose name includes varfilter. Plot is overlaid across runs.
All matching diagnostics must be 1D."""
@@ -123,9 +124,7 @@ def plot_1d_min_max_with_region_bar(
return HVPlot(_set_opts_and_overlay(hmap))
-def plot_1d_with_region_bar(
- run_diags: RunDiagnostics, varfilter: str, run_attr_name: str = "run"
-) -> HVPlot:
+def plot_1d_with_region_bar(run_diags: RunDiagnostics, varfilter: str) -> HVPlot:
"""Plot all diagnostics whose name includes varfilter. Plot is overlaid across runs.
Region will be selectable through a drop-down bar. Region is assumed to be part of
variable name after last underscore. All matching diagnostics must be 1D."""
@@ -189,6 +188,7 @@ def diurnal_component_plot(
hovmoller_plot_manager = PlotManager()
zonal_pressure_plot_manager = PlotManager()
diurnal_plot_manager = PlotManager()
+histogram_plot_manager = PlotManager()
# this will be passed the data from the metrics.json files
metrics_plot_manager = PlotManager()
@@ -291,6 +291,11 @@ def diurnal_cycle_component_plots(diagnostics: Iterable[xr.Dataset]) -> HVPlot:
return diurnal_component_plot(diagnostics)
+@histogram_plot_manager.register
+def histogram_plots(diagnostics: Iterable[xr.Dataset]) -> HVPlot:
+ return plot_histogram(diagnostics, f"{PRECIP_RATE}_histogram")
+
+
# Routines for plotting the "metrics"
# New plotting routines can be registered here.
@metrics_plot_manager.register
@@ -325,12 +330,30 @@ def generic_metric_plot(metrics: RunMetrics, metric_type: str) -> hv.HoloMap:
return HVPlot(hmap.opts(**bar_opts))
-navigation = [
+def get_metrics_table(
+ metrics: RunMetrics, metric_types: Sequence[str], variable_names: Sequence[str]
+) -> Mapping[str, Mapping[str, float]]:
+ """Structure a set of metrics in format suitable for reports.create_html"""
+ table = {}
+ for metric_type in metric_types:
+ for name in variable_names:
+ units = metrics.get_metric_units(metric_type, name, metrics.runs[0])
+ type_label = f"{name} {metric_type} [{units}]"
+ table[type_label] = {
+ run: f"{metrics.get_metric_value(metric_type, name, run):.2f}"
+ for run in metrics.runs
+ }
+ return table
+
+
+navigation = OrderedList(
Link("Home", "index.html"),
+ Link("Process diagnostics", "process_diagnostics.html"),
Link("Latitude versus time hovmoller", "hovmoller.html"),
Link("Time-mean maps", "maps.html"),
Link("Time-mean zonal-pressure profiles", "zonal_pressure.html"),
-]
+)
+navigation = [navigation] # must be iterable for create_html template
def render_index(metadata, diagnostics, metrics, movie_links):
@@ -338,7 +361,6 @@ def render_index(metadata, diagnostics, metrics, movie_links):
"Links": navigation,
"Timeseries": list(timeseries_plot_manager.make_plots(diagnostics)),
"Zonal mean": list(zonal_mean_plot_manager.make_plots(diagnostics)),
- "Diurnal cycle": list(diurnal_plot_manager.make_plots(diagnostics)),
}
if not metrics.empty:
@@ -398,6 +420,23 @@ def render_zonal_pressures(metadata, diagnostics):
)
+def render_process_diagnostics(metadata, diagnostics, metrics):
+ sections = {
+ "Links": navigation,
+ "Diurnal cycle": list(diurnal_plot_manager.make_plots(diagnostics)),
+ "Precipitation histogram": list(histogram_plot_manager.make_plots(diagnostics)),
+ }
+ metric_types = [f"percentile_{p}" for p in PERCENTILES]
+ metrics_table = get_metrics_table(metrics, metric_types, [PRECIP_RATE])
+ return create_html(
+ title="Process diagnostics",
+ metadata=metadata,
+ metrics=metrics_table,
+ sections=sections,
+ html_header=get_html_header(),
+ )
+
+
def _html_link(url, tag):
return f"{tag}"
@@ -427,6 +466,9 @@ def make_report(computed_diagnostics: ComputedDiagnosticsList, output):
"hovmoller.html": render_hovmollers(metadata, diagnostics),
"maps.html": render_maps(metadata, diagnostics, metrics),
"zonal_pressure.html": render_zonal_pressures(metadata, diagnostics),
+ "process_diagnostics.html": render_process_diagnostics(
+ metadata, diagnostics, metrics
+ ),
}
for filename, html in pages.items():
diff --git a/workflows/prognostic_run_diags/tests/test_computed_diagnostics.py b/workflows/prognostic_run_diags/tests/test_computed_diagnostics.py
index 87f452f1eb..6db0c37e81 100644
--- a/workflows/prognostic_run_diags/tests/test_computed_diagnostics.py
+++ b/workflows/prognostic_run_diags/tests/test_computed_diagnostics.py
@@ -49,16 +49,6 @@ def test_ComputedDiagnosticsList_from_urls():
assert isinstance(result.folders["1"], DiagnosticFolder)
-def test_detect_folders_fail_less_than_2(tmpdir):
-
- fs = fsspec.filesystem("file")
-
- tmpdir.mkdir("rundir1").join("diags.nc").write("foobar")
-
- with pytest.raises(ValueError):
- detect_folders(tmpdir, fs)
-
-
def test_get_movie_links(tmpdir):
domain = "http://www.domain.com"
rdirs = ["rundir1", "rundir2"]
diff --git a/workflows/prognostic_run_diags/tests/test_integration.sh b/workflows/prognostic_run_diags/tests/test_integration.sh
index 84edfc4035..41d137de5e 100644
--- a/workflows/prognostic_run_diags/tests/test_integration.sh
+++ b/workflows/prognostic_run_diags/tests/test_integration.sh
@@ -21,9 +21,6 @@ gsutil cp /tmp/$random/metrics.json $OUTPUT/run1/metrics.json
# generate movies for short sample prognostic run
prognostic_run_diags movie --n_jobs 1 --n_timesteps 2 $RUN $OUTPUT/run1
-# make a second copy of diags/metrics since generate_report.py needs at least two runs
-gsutil -m cp -r $OUTPUT/run1 $OUTPUT/run2
-
# generate report based on diagnostics computed above
prognostic_run_diags report $OUTPUT $OUTPUT
diff --git a/workflows/prognostic_run_diags/tests/test_metrics.py b/workflows/prognostic_run_diags/tests/test_metrics.py
new file mode 100644
index 0000000000..dd5a55148e
--- /dev/null
+++ b/workflows/prognostic_run_diags/tests/test_metrics.py
@@ -0,0 +1,23 @@
+import numpy as np
+import pytest
+from fv3net.diagnostics.prognostic_run.metrics import compute_percentile
+
+
+@pytest.mark.parametrize(
+ ["percentile", "expected_value"],
+ [(0, 0.5), (5, 0.5), (10, 0.5), (20, 1.5), (45, 2.5), (99, 3.5)],
+)
+def test_compute_percentile(percentile, expected_value):
+ bins = np.array([0, 1, 2, 3])
+ bin_widths = np.array([1, 1, 1, 1])
+ frequency = np.array([0.1, 0.1, 0.4, 0.4])
+ value = compute_percentile(percentile, frequency, bins, bin_widths)
+ assert value == expected_value
+
+
+def test_compute_percentile_raise_value_error():
+ bins = np.array([0, 1, 2, 3])
+ bin_widths = np.array([1, 1, 1, 0.5])
+ frequency = np.array([0.1, 0.1, 0.4, 0.4])
+ with pytest.raises(ValueError):
+ compute_percentile(0, frequency, bins, bin_widths)
diff --git a/workflows/prognostic_run_diags/tests/test_savediags.py b/workflows/prognostic_run_diags/tests/test_savediags.py
index f6fa505dd1..5c70b990f1 100644
--- a/workflows/prognostic_run_diags/tests/test_savediags.py
+++ b/workflows/prognostic_run_diags/tests/test_savediags.py
@@ -2,8 +2,6 @@
import cftime
import numpy as np
import xarray as xr
-import fsspec
-from unittest.mock import Mock
import pytest
@@ -29,33 +27,6 @@ def grid():
return xr.open_dataset("grid.nc").load()
-def test_dump_nc(tmpdir):
- ds = xr.Dataset({"a": (["x"], [1.0])})
-
- path = str(tmpdir.join("data.nc"))
- with fsspec.open(path, "wb") as f:
- savediags.dump_nc(ds, f)
-
- ds_compare = xr.open_dataset(path)
- xr.testing.assert_equal(ds, ds_compare)
-
-
-def test_dump_nc_no_seek():
- """
- GCSFS file objects raise an error when seek is called in write mode::
-
- if not self.mode == "rb":
- raise ValueError("Seek only available in read mode")
- ValueError: Seek only available in read mode
-
- """
- ds = xr.Dataset({"a": (["x"], [1.0])})
- m = Mock()
-
- savediags.dump_nc(ds, m)
- m.seek.assert_not_called()
-
-
@pytest.mark.parametrize("func", savediags._DIAG_FNS)
def test_compute_diags_succeeds(func, resampled, verification, grid):
func(resampled, verification, grid)