Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617564740
  • Loading branch information
ilopezgp authored and The swirl_dynamics Authors committed Mar 20, 2024
1 parent 9e8d199 commit c09b9bc
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 6 deletions.
84 changes: 78 additions & 6 deletions swirl_dynamics/data/zarr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import Any

from absl import logging
from etils import epath
import xarray as xr

Expand All @@ -32,12 +33,20 @@ def collected_metrics_to_ds(
) -> xr.Dataset:
"""Packages collected metrics as an xarray.Dataset.
The metrics defined by `data` are packaged into a Dataset as variables
sharing dimensions and coordinates. The function assumes that the collecting
axis (with name `append_dim`) is the first one in the `data`; typically the
batch dimension of the dataset being processed. If passed, the `coords`
provide coordinates for the dimensions of the `data`.
Args:
data: A mapping of metric names to their collected values.
append_dim: The name of the axis dimension of metric collection, enforced to
allow downstream dataset appending.
append_slice: Current index slice in the `append_dim` axis.
coords: xarray coordinates of the label dataset used to compute the metrics.
These coordinates are used to annotate the collected metrics if the
dimension to dimension size mapping is injective.
Returns:
A dataset containing all collected metrics as variables, with coordinate
Expand All @@ -51,15 +60,21 @@ def collected_metrics_to_ds(
dims.append(
list(coords.dims.keys())[list(coords.dims.values()).index(cur_size)]
)

coord_dict = {
elem: coords[elem].data for elem in dims if elem != append_dim
}
coord_dict[append_dim] = coords[append_dim].data[append_slice]
if len(dims) == len(set(dims)):
coord_dict = {
elem: coords[elem].data for elem in dims if elem != append_dim
}
coord_dict[append_dim] = coords[append_dim].data[append_slice]
else:
logging.warning(
'The coordinate order of the data cannot be inferred:'
'from their shape due to same-length dimensions. '
'Reverting to generic dimension labels.'
)

data_vars = {}
for key, value in data.items():
if coords is None:
if coord_dict is None:
dims.extend([f'dim_{i}' for i in range(value.ndim - 1)])
data_vars[key] = (dims, value)

Expand Down Expand Up @@ -89,6 +104,63 @@ def collected_metrics_to_zarr(
write_to_file(ds, out_dir, basename, append_dim)


def aggregated_metrics_to_ds(
    data: Mapping[str, Any],
    coords: xr.core.coordinates.DatasetCoordinates | None = None,
) -> xr.Dataset:
  """Packages aggregated metrics as an xarray.Dataset.

  The axes of every metric are matched to the dimensions of `coords` by their
  length. This is only possible when no two dimensions of `coords` share the
  same length; otherwise all metrics are labeled with generic dimension names
  (`dim_0`, `dim_1`, ...) and no coordinates are attached.

  A single metric also falls back to generic dimension names when its own shape
  cannot be matched unambiguously, i.e. when one of its axes has a length that
  corresponds to no dimension of `coords`, or when two of its axes have the
  same length.

  Args:
    data: A mapping of metric names to their aggregated values. Values must
      expose `shape` and `ndim` attributes.
    coords: xarray coordinates of the label dataset used to compute the metrics.
      These coordinates are used to annotate the aggregated metrics if the
      dimension to dimension size mapping is injective.

  Returns:
    A dataset containing all aggregated metrics as variables, with coordinate
    metadata.
  """
  coord_dict = None
  name_by_length = {}
  if coords is not None:
    dim_sizes = coords.dims
    lengths = list(dim_sizes.values())
    # Axis names can only be recovered from axis lengths if every dimension
    # has a distinct length.
    if len(lengths) == len(set(lengths)):
      coord_dict = {name: coords[name].data for name in dim_sizes.keys()}
      name_by_length = {length: name for name, length in dim_sizes.items()}
    else:
      logging.warning(
          'The coordinate order of the data cannot be inferred '
          'from their shape due to same-length dimensions. '
          'Reverting to generic dimension labels.'
      )

  data_vars = {}
  for key, value in data.items():
    shape = tuple(value.shape)
    # A metric can be labeled only if each of its axes maps to exactly one
    # known dimension and no dimension would be used twice.
    can_label = (
        coord_dict is not None
        and len(set(shape)) == len(shape)
        and all(length in name_by_length for length in shape)
    )
    if can_label:
      dims = [name_by_length[length] for length in shape]
    else:
      if coord_dict is not None:
        logging.warning(
            'The axes of metric %r cannot be matched to coordinate '
            'dimensions by length. Reverting to generic dimension labels.',
            key,
        )
      dims = [f'dim_{i}' for i in range(value.ndim)]
    data_vars[key] = (dims, value)

  return xr.Dataset(
      data_vars=data_vars,
      coords=coord_dict,
      attrs=dict(description='Aggregated metrics.'),
  )


def aggregated_metrics_to_zarr(
    data: Mapping[str, Any],
    *,
    out_dir: epath.PathLike,
    basename: str,
    coords: xr.core.coordinates.DatasetCoordinates | None = None,
) -> None:
  """Packages aggregated metrics into a dataset and writes it to zarr.

  Args:
    data: A mapping of metric names to their aggregated values.
    out_dir: Output directory, forwarded to `write_to_file`.
    basename: Base name of the output, forwarded to `write_to_file`.
    coords: Optional xarray coordinates used by `aggregated_metrics_to_ds` to
      annotate the metrics.
  """
  write_to_file(
      aggregated_metrics_to_ds(data, coords),
      out_dir,
      basename,
  )


def write_to_file(
ds, out_dir: epath.PathLike, basename: str, append_dim: str | None = None
) -> None:
Expand Down
61 changes: 61 additions & 0 deletions swirl_dynamics/data/zarr_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,67 @@ def test_collected_metrics_to_zarr(self):

self.assertTrue(os.path.exists(os.path.join(outdir, "test_metrics.zarr")))

def test_aggregated_metrics_to_ds(self):
  """Checks that metric axes are labeled with the matching coordinates."""
  shape = (10, 5, 3)
  data = {"foo": np.ones(shape), "bar": np.ones(shape)}
  # `time` has 8 entries and is not an axis of the metrics; all dimension
  # lengths are distinct, so axis names can be inferred from the shape.
  coord_dict = {
      "time": pd.date_range("2012-01-01", "2012-01-08"),
      "lon": range(shape[0]),
      "lat": range(shape[1]),
      "field": ["var1", "var2", "var3"],
  }
  coords = xr.Dataset(coords=coord_dict).coords
  ds = zarr_utils.aggregated_metrics_to_ds(data, coords)

  self.assertIsInstance(ds, xr.Dataset)
  self.assertIn("foo", ds)
  self.assertIn("bar", ds)
  self.assertIn("field", ds.dims)
  # Use `sizes` for name -> length lookups; mapping-style access on
  # `Dataset.dims` is deprecated in xarray.
  self.assertEqual(ds.sizes["field"], 3)
  self.assertEqual(ds.sizes["lon"], 10)
  self.assertEqual(ds.sizes["lat"], 5)
  self.assertEqual(ds["foo"].dims, ("lon", "lat", "field"))

def test_aggregated_metrics_to_ds_ambiguous_shape(self):
  """Checks the fallback to generic labels for same-length dimensions."""
  shape = (5, 5, 3)
  data = {"foo": np.ones(shape), "bar": np.ones(shape)}
  # `lon` and `lat` share the same length, so axis names cannot be inferred
  # from the shape of the metrics.
  coord_dict = {
      "time": pd.date_range("2012-01-01", "2012-01-08"),
      "lon": range(shape[0]),
      "lat": range(shape[1]),
      "field": ["var1", "var2", "var3"],
  }
  coords = xr.Dataset(coords=coord_dict).coords
  ds = zarr_utils.aggregated_metrics_to_ds(data, coords)
  # Returned coordinates are generic when ambiguous
  self.assertIsInstance(ds, xr.Dataset)
  self.assertIn("foo", ds)
  self.assertIn("bar", ds)
  self.assertIn("dim_0", ds.dims)
  self.assertNotIn("lon", ds.dims)
  # Use `sizes` for name -> length lookups; mapping-style access on
  # `Dataset.dims` is deprecated in xarray.
  self.assertEqual(ds.sizes["dim_2"], 3)
  self.assertEqual(ds.sizes["dim_1"], 5)
  self.assertEqual(ds.sizes["dim_0"], 5)

def test_aggregated_metrics_to_zarr(self):
  """Checks that writing aggregated metrics creates the expected output."""
  n_lon, n_lat, n_field = 10, 5, 3
  metrics = {
      name: np.ones((n_lon, n_lat, n_field)) for name in ("foo", "bar")
  }
  label_coords = xr.Dataset(
      coords={
          "time": pd.date_range("2012-01-01", "2012-01-08"),
          "lon": range(n_lon),
          "lat": range(n_lat),
          "field": ["var1", "var2", "var3"],
      }
  ).coords
  tmp_dir = self.create_tempdir()

  zarr_utils.aggregated_metrics_to_zarr(
      metrics, out_dir=tmp_dir, basename="test_metrics", coords=label_coords
  )

  expected_store = os.path.join(tmp_dir, "test_metrics.zarr")
  self.assertTrue(os.path.exists(expected_store))

def test_write_to_file(self):
foo = np.ones((3,))
outdir = self.create_tempdir()
Expand Down

0 comments on commit c09b9bc

Please sign in to comment.