Add precipitation histogram to prognostic run report #1271

Merged 20 commits on Jun 23, 2021. Changes from all commits shown below.
9 changes: 8 additions & 1 deletion external/report/report/__init__.py
@@ -1,3 +1,10 @@
from .create_report import create_html, insert_report_figure, Metrics, Metadata, Link
from .create_report import (
create_html,
insert_report_figure,
Metrics,
Metadata,
Link,
OrderedList,
)

__version__ = "0.1.0"
11 changes: 10 additions & 1 deletion external/report/report/create_report.py
@@ -1,6 +1,6 @@
import datetime
import os
from typing import Mapping, Sequence, Union
from typing import Any, Mapping, Sequence, Union

from jinja2 import Template
from pytz import timezone
@@ -90,6 +90,15 @@ def __repr__(self) -> str:
return f'<a href="{self.url}">{self.tag}</a>'


class OrderedList:
def __init__(self, *items: Any):
self.items = items

def __repr__(self) -> str:
items_li = [f"<li>{item}</li>" for item in self.items]
return "<ol>\n" + "\n".join(items_li) + "\n</ol>"


def resolve_plot(obj):
if isinstance(obj, str):
return ImagePlot(obj)
7 changes: 6 additions & 1 deletion external/report/tests/test_report.py
@@ -1,5 +1,5 @@
import os
from report import __version__, create_html
from report import __version__, create_html, OrderedList
from report.create_report import _save_figure, insert_report_figure


@@ -46,3 +46,8 @@ def test__save_figure(tmpdir):
with open(os.path.join(output_dir, filepath_relative_to_report), "r") as f:
saved_data = f.read()
assert saved_data.replace("\n", "") == fig.content


def test_OrderedList_repr():
result = str(OrderedList("item1", "item2"))
assert result == "<ol>\n<li>item1</li>\n<li>item2</li>\n</ol>"
13 changes: 13 additions & 0 deletions external/vcm/tests/test_calc.py
@@ -12,6 +12,7 @@
)
from vcm.calc.calc import local_time, apparent_source
from vcm.cubedsphere.constants import COORD_Z_CENTER, COORD_Z_OUTER
from vcm.calc.histogram import histogram


@pytest.mark.parametrize("toa_pressure", [0, 5])
@@ -119,3 +120,15 @@ def test_apparent_source():
s_dim="forecast_time",
)
assert Q1_forecast3 == pytest.approx((2.0 / (15 * 60)) - (4.0 / 60))


def test_histogram():
data = xr.DataArray(
np.reshape(np.arange(0, 40, 2), (5, 4)), dims=["x", "y"], name="temperature"
)
coords = {"temperature_bins": [0, 30]}
expected_count = xr.DataArray([15, 5], coords=coords, dims="temperature_bins")
expected_width = xr.DataArray([30, 10], coords=coords, dims="temperature_bins")
count, width = histogram(data, bins=[0, 30, 40])
xr.testing.assert_equal(count, expected_count)
xr.testing.assert_equal(width, expected_width)
1 change: 1 addition & 0 deletions external/vcm/vcm/__init__.py
@@ -29,6 +29,7 @@
column_integrated_heating_from_isobaric_transition,
column_integrated_heating_from_isochoric_transition,
)
from .calc.histogram import histogram

from .interpolate import (
interpolate_to_pressure_levels,
28 changes: 28 additions & 0 deletions external/vcm/vcm/calc/histogram.py
@@ -0,0 +1,28 @@
from typing import Any, Hashable, Mapping, Tuple

import numpy as np
import xarray as xr


def histogram(da: xr.DataArray, **kwargs) -> Tuple[xr.DataArray, xr.DataArray]:
"""Compute histogram and return tuple of counts and bin widths.

Args:
da: input data
kwargs: optional parameters to pass on to np.histogram

Return:
counts, bin_widths tuple of xr.DataArrays. The coordinate of both arrays is
equal to the left side of the histogram bins.
"""
coord_name = f"{da.name}_bins" if da.name is not None else "bins"
count, bins = np.histogram(da, **kwargs)
coords: Mapping[Hashable, Any] = {coord_name: bins[:-1]}
width = bins[1:] - bins[:-1]
width_da = xr.DataArray(width, coords=coords, dims=[coord_name])
count_da = xr.DataArray(count, coords=coords, dims=[coord_name])
if "units" in da.attrs:
count_da[coord_name].attrs["units"] = da.units
width_da[coord_name].attrs["units"] = da.units
width_da.attrs["units"] = da.units
return count_da, width_da
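
A quick usage sketch of the new helper (not part of the diff; it assumes only numpy, xarray, and the vcm package as modified in this PR). The bin coordinate is named after the input DataArray and inherits its units:

import numpy as np
import xarray as xr
from vcm import histogram  # exported via vcm/__init__.py in this PR

# precipitation-like samples with units metadata
precip = xr.DataArray(
    np.random.default_rng(0).gamma(2.0, 5.0, size=1000),
    dims=["sample"],
    name="total_precip_to_surface",
    attrs={"units": "mm/day"},
)
count, width = histogram(precip, bins=np.logspace(-1, np.log10(500), 101), density=True)
count.dims                      # ('total_precip_to_surface_bins',)
width.attrs["units"]            # 'mm/day', copied from the input
float((count * width).sum())    # ~1.0, since density=True normalizes the counts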
Changes to the prognostic run diagnostics computation module:
@@ -12,15 +12,12 @@
grouping contains outputs from the physics routines (`sfc_dt_atmos.tile*.nc` and
`diags.zarr`).
"""
import os
import sys

import datetime
import tempfile
import intake
import numpy as np
import xarray as xr
import shutil
from dask.diagnostics import ProgressBar
import fsspec

@@ -37,6 +34,7 @@
from fv3net.diagnostics.prognostic_run import diurnal_cycle
from fv3net.diagnostics.prognostic_run import transform
from fv3net.diagnostics.prognostic_run.constants import (
HISTOGRAM_BINS,
HORIZONTAL_DIMS,
DiagArg,
GLOBAL_AVERAGE_DYCORE_VARS,
@@ -229,16 +227,6 @@ def _assign_diagnostic_time_attrs(
return diagnostics_ds


def dump_nc(ds: xr.Dataset, f):
Comment from the PR author: We switched to using the vcm.dump_nc version of this a while ago, so this func is unused.

# to_netcdf closes file, which will delete the buffer
# need to use a buffer since seek doesn't work with GCSFS file objects
with tempfile.TemporaryDirectory() as dirname:
url = os.path.join(dirname, "tmp.nc")
ds.to_netcdf(url, engine="h5netcdf")
with open(url, "rb") as tmp1:
shutil.copyfileobj(tmp1, f)


@add_to_diags("dycore")
@diag_finalizer("rms_global")
@transform.apply("resample_time", "3H", inner_join=True)
@@ -503,6 +491,22 @@ def _diurnal_func(
return _assign_diagnostic_time_attrs(diag, prognostic)


@add_to_diags("physics")
@diag_finalizer("histogram")
@transform.apply("resample_time", "3H", inner_join=True, method="mean")
@transform.apply("subset_variables", list(HISTOGRAM_BINS.keys()))
def compute_histogram(prognostic, verification, grid):
logger.info("Computing histograms for physics diagnostics")
counts = xr.Dataset()
for varname in prognostic.data_vars:
count, width = vcm.histogram(
prognostic[varname], bins=HISTOGRAM_BINS[varname], density=True
)
counts[varname] = count
counts[f"{varname}_bin_width"] = width
return _assign_diagnostic_time_attrs(counts, prognostic)
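
A standalone sketch of what this diagnostic produces (toy data; the decorators and the verification and grid arguments are omitted). The naming convention, <varname> for counts and <varname>_bin_width for widths, both on the <varname>_bins coordinate, is what the metrics and plotting code below rely on:

import numpy as np
import xarray as xr
import vcm

bins = {"total_precip_to_surface": np.logspace(-1, np.log10(500), 101)}  # mirrors HISTOGRAM_BINS
prognostic = xr.Dataset(
    {
        "total_precip_to_surface": (
            ("time", "x"),
            np.random.default_rng(1).gamma(2.0, 5.0, (8, 100)),
        )
    }
)
counts = xr.Dataset()
for varname in prognostic.data_vars:
    count, width = vcm.histogram(prognostic[varname], bins=bins[varname], density=True)
    counts[varname] = count
    counts[f"{varname}_bin_width"] = width
# counts contains "total_precip_to_surface" and "total_precip_to_surface_bin_width",
# both indexed by the "total_precip_to_surface_bins" coordinate.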


def register_parser(subparsers):
parser = subparsers.add_parser("save", help="Compute the prognostic run diags.")
parser.add_argument("url", help="Prognostic run output location.")
Changes to the diagnostics folder detection code:
@@ -274,11 +274,6 @@ def detect_folders(
bucket: str, fs: fsspec.AbstractFileSystem,
) -> Mapping[str, DiagnosticFolder]:
diag_ncs = fs.glob(os.path.join(bucket, "*", "diags.nc"))
if len(diag_ncs) < 2:
raise ValueError(
"Plots require more than 1 diagnostic directory in"
f" {bucket} for holoviews plots to display correctly."
)
return {
Path(url).parent.name: DiagnosticFolder(fs, Path(url).parent.as_posix())
for url in diag_ncs
Changes to the prognostic run constants module:
@@ -1,3 +1,4 @@
import numpy as np
import xarray as xr
from typing import Tuple

@@ -135,3 +136,7 @@
"dQu",
"dQv",
]

PRECIP_RATE = "total_precip_to_surface"
HISTOGRAM_BINS = {PRECIP_RATE: np.logspace(-1, np.log10(500), 101)}
PERCENTILES = [25, 50, 75, 90, 99, 99.9]
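
For reference, the precipitation bins defined above are 101 logarithmically spaced edges (100 bins) running from 0.1 up to 500 in the precipitation rate's units; a small sketch:

import numpy as np
bins = np.logspace(-1, np.log10(500), 101)
bins[0], bins[-1], len(bins)   # (0.1, ~500.0, 101)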
Changes to the prognostic run metrics module:
@@ -13,7 +13,7 @@
import numpy as np
import xarray as xr
from toolz import curry
from .constants import HORIZONTAL_DIMS
from .constants import HORIZONTAL_DIMS, PERCENTILES
import json

_METRICS = []
@@ -138,6 +138,49 @@ def rmse_time_mean(diags):
return rms_of_time_mean_bias


for percentile in PERCENTILES:

@add_to_metrics(f"percentile_{percentile}")
def percentile_metric(diags, percentile=percentile):
histogram = grab_diag(diags, "histogram")
percentiles = xr.Dataset()
data_vars = [v for v in histogram.data_vars if not v.endswith("bin_width")]
for varname in data_vars:
percentiles[varname] = compute_percentile(
percentile,
histogram[varname].values,
histogram[f"{varname}_bins"].values,
histogram[f"{varname}_bin_width"].values,
)
restore_units(histogram, percentiles)
return percentiles
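
One note on the loop above: the percentile=percentile default argument freezes the current loop value at definition time; a bare closure over percentile would see only the final value once the loop ends. A minimal sketch of the pattern outside the metrics framework:

funcs = []
for p in [25, 50, 75]:
    def f(p=p):  # default argument captures the current value of p
        return p
    funcs.append(f)
[g() for g in funcs]   # [25, 50, 75]; with `def f(): return p` this would be [75, 75, 75]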


def compute_percentile(
percentile: float, freq: np.ndarray, bins: np.ndarray, bin_widths: np.ndarray
) -> float:
"""Compute percentile given normalized histogram.

Args:
percentile: value between 0 and 100
freq: array of frequencies normalized by bin widths
bins: values of left sides of bins
bin_widths: values of bin widths

Returns:
value of distribution at percentile
"""
cumulative_distribution = np.cumsum(freq * bin_widths)
if np.abs(cumulative_distribution[-1] - 1) > 1e-6:
raise ValueError(
"The provided frequencies do not integrate to one. "
"Ensure that histogram is computed with density=True."
)
bin_midpoints = bins + 0.5 * bin_widths
closest_index = np.argmin(np.abs(cumulative_distribution - percentile / 100))
return bin_midpoints[closest_index]
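
A worked sketch of compute_percentile on a tiny density-normalized histogram (toy values, assuming the function above is in scope):

import numpy as np

bins = np.array([0.0, 1.0, 2.0])         # left bin edges
bin_widths = np.array([1.0, 1.0, 1.0])
freq = np.array([0.2, 0.3, 0.5])         # (freq * bin_widths).sum() == 1
# cumulative distribution at the right edges: [0.2, 0.5, 1.0]
compute_percentile(50, freq, bins, bin_widths)   # -> 1.5, midpoint of the second bin
compute_percentile(99, freq, bins, bin_widths)   # -> 2.5, midpoint of the last bin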


def restore_units(source, target):
for variable in target:
target[variable].attrs["units"] = source[variable].attrs["units"]
Changes to the prognostic run transform module:
@@ -89,7 +89,11 @@ def transform(*diag_args):

@add_to_input_transform_fns
def resample_time(
freq_label: str, arg: DiagArg, time_slice=slice(None, -1), inner_join: bool = False
freq_label: str,
arg: DiagArg,
time_slice=slice(None, -1),
inner_join: bool = False,
method: str = "nearest",
) -> DiagArg:
"""
Subset times in prognostic and verification data
@@ -102,23 +106,30 @@
time by default to work with crashed simulations.
inner_join: Subset times to the intersection of prognostic and verification
data. Defaults to False.
method: how to do resampling. Can be "nearest" or "mean".
"""
prognostic, verification, grid = arg
prognostic = _downsample_only(prognostic, freq_label)
verification = _downsample_only(verification, freq_label)
prognostic = _downsample_only(prognostic, freq_label, method)
verification = _downsample_only(verification, freq_label, method)

prognostic = prognostic.isel(time=time_slice)
if inner_join:
prognostic, verification = _inner_join_time(prognostic, verification)
return prognostic, verification, grid


def _downsample_only(ds: xr.Dataset, freq_label: str) -> xr.Dataset:
def _downsample_only(ds: xr.Dataset, freq_label: str, method: str) -> xr.Dataset:
"""Resample in time, only if given freq_label is lower frequency than time
sampling of given dataset ds"""
ds_freq = ds.time.values[1] - ds.time.values[0]
if ds_freq < pd.to_timedelta(freq_label):
return ds.resample(time=freq_label, label="right").nearest()
resampled = ds.resample(time=freq_label, label="right")
if method == "nearest":
return resampled.nearest()
elif method == "mean":
return resampled.mean()
else:
raise ValueError(f"Don't know how to resample with method={method}.")
else:
return ds
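
A small illustration of the new method option with toy hourly data ("nearest" is the pre-existing behavior; "mean" is what the histogram diagnostic requests via resample_time(..., method="mean")):

import numpy as np
import pandas as pd
import xarray as xr

hourly = xr.Dataset(
    {"total_precip_to_surface": ("time", np.arange(6.0))},
    coords={"time": pd.date_range("2020-01-01", periods=6, freq="1H")},
)
# "nearest" keeps the sample closest to each 3-hourly label;
# "mean" averages every sample that falls within the 3-hour window.
hourly.resample(time="3H", label="right").nearest()
hourly.resample(time="3H", label="right").mean()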

Changes to the report plotting code:
@@ -30,9 +30,9 @@
}


def fig_to_b64(fig, format="png"):
def fig_to_b64(fig, format="png", dpi=None):
pic_IObytes = io.BytesIO()
fig.savefig(pic_IObytes, format=format, bbox_inches="tight")
fig.savefig(pic_IObytes, format=format, bbox_inches="tight", dpi=dpi)
pic_IObytes.seek(0)
pic_hash = base64.b64encode(pic_IObytes.read())
return f"data:image/png;base64, " + pic_hash.decode()
@@ -179,6 +179,27 @@ def plot_cubed_sphere_map(
)


def plot_histogram(run_diags: RunDiagnostics, varname: str) -> raw_html:
"""Plot 1D histogram of varname overlaid across runs."""

logging.info(f"plotting {varname}")
fig, ax = plt.subplots()
bin_name = varname.replace("histogram", "bins")
for run in run_diags.runs:
v = run_diags.get_variable(run, varname)
ax.step(v[bin_name], v, label=run, where="post", linewidth=1)
ax.set_xlabel(f"{v.long_name} [{v.units}]")
ax.set_ylabel(f"Frequency [({v.units})^-1]")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim([v[bin_name].values[0], v[bin_name].values[-1]])
ax.legend()
fig.tight_layout()
data = fig_to_b64(fig, dpi=150)
plt.close(fig)
return raw_html(f'<img src="{data}" width="800px" />')
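
The step plot with where="post" draws each frequency as a horizontal segment starting at its left bin edge, which matches how the histogram counts are stored; a standalone matplotlib sketch with made-up data (not using RunDiagnostics):

import matplotlib.pyplot as plt
import numpy as np

bins = np.logspace(-1, np.log10(500), 101)
freq = np.random.default_rng(2).random(100)
freq /= (freq * np.diff(bins)).sum()     # density-normalize the toy frequencies

fig, ax = plt.subplots()
ax.step(bins[:-1], freq, where="post", linewidth=1)
ax.set_xscale("log")
ax.set_yscale("log")
fig.savefig("precip_histogram.png", dpi=150)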


def _render_map_title(
metrics: RunMetrics, variable: str, run: str, metrics_for_title: Mapping[str, str],
) -> str: