diff --git a/swirl_dynamics/data/zarr_utils.py b/swirl_dynamics/data/zarr_utils.py index e65b606..26e53ff 100644 --- a/swirl_dynamics/data/zarr_utils.py +++ b/swirl_dynamics/data/zarr_utils.py @@ -18,6 +18,7 @@ import os from typing import Any +from absl import logging from etils import epath import xarray as xr @@ -32,12 +33,20 @@ def collected_metrics_to_ds( ) -> xr.Dataset: """Packages collected metrics as an xarray.Dataset. + The metrics defined by `data` are packaged into a Dataset as variables + sharing dimensions and coordinates. The function assumes that the collecting + axis (with name `append_dim`) is the first one in the `data`; typically the + batch dimension of the dataset being processed. If passed, the `coords` + provide coordinates for the dimensions of the `data`. + Args: data: A mapping of metric names to their collected values. append_dim: The name of the axis dimension of metric collection, enforced to allow downstream dataset appending. append_slice: Current index slice in the `append_dim` axis. coords: xarray coordinates of the label dataset used to compute the metrics. + These coordinates are used to annotate the collected metrics if the + dimension to dimension size mapping is injective. Returns: A dataset containing all collected metrics as variables, with coordinate @@ -51,15 +60,21 @@ def collected_metrics_to_ds( dims.append( list(coords.dims.keys())[list(coords.dims.values()).index(cur_size)] ) - - coord_dict = { - elem: coords[elem].data for elem in dims if elem != append_dim - } - coord_dict[append_dim] = coords[append_dim].data[append_slice] + if len(dims) == len(set(dims)): + coord_dict = { + elem: coords[elem].data for elem in dims if elem != append_dim + } + coord_dict[append_dim] = coords[append_dim].data[append_slice] + else: + logging.warning( + 'The coordinate order of the data cannot be inferred:' + 'from their shape due to same-length dimensions. ' + 'Reverting to generic dimension labels.' 
def aggregated_metrics_to_ds(
    data: Mapping[str, Any],
    coords: xr.core.coordinates.DatasetCoordinates | None = None,
) -> xr.Dataset:
  """Packages aggregated metrics as an xarray.Dataset.

  Dimension names are inferred by matching each axis length of a metric
  against the dimension sizes found in `coords`. This only works when every
  dimension in `coords` has a distinct size; otherwise all metrics are
  labeled with generic names (`dim_0`, `dim_1`, ...) and no coordinates are
  attached. A single metric whose shape cannot be matched unambiguously (an
  axis length absent from `coords`, or two axes of equal length) is labeled
  with generic names as well, while the remaining metrics keep their inferred
  names.

  Args:
    data: A mapping of metric names to their aggregated values. Each value
      must expose `ndim` and `shape`.
    coords: xarray coordinates of the label dataset used to compute the metrics.
      These coordinates are used to annotate the aggregated metrics if the
      dimension to dimension size mapping is injective.

  Returns:
    A dataset containing all aggregated metrics as variables, with coordinate
    metadata.
  """
  coord_dict = None
  size_to_name = {}
  if coords is not None:
    dim_sizes = dict(coords.dims)
    # Axis names can only be recovered from axis lengths if no two dimensions
    # share the same size.
    if len(set(dim_sizes.values())) == len(dim_sizes):
      coord_dict = {name: coords[name].data for name in dim_sizes}
      size_to_name = {size: name for name, size in dim_sizes.items()}
    else:
      logging.warning(
          'The coordinate order of the data cannot be inferred '
          'from their shape due to same-length dimensions. '
          'Reverting to generic dimension labels.'
      )

  data_vars = {}
  for key, value in data.items():
    generic_dims = [f'dim_{i}' for i in range(value.ndim)]
    if coord_dict is None:
      var_dims = generic_dims
    else:
      var_dims = [size_to_name.get(size) for size in value.shape]
      unmatched = None in var_dims
      repeated = len(set(var_dims)) != len(var_dims)
      if unmatched or repeated:
        # Avoid an opaque KeyError (unknown axis length) or a duplicated
        # dimension name (two axes of equal length) for this variable.
        logging.warning(
            'The dimensions of metric %s with shape %s cannot be inferred '
            'from the coordinates. Using generic dimension labels for it.',
            key,
            tuple(value.shape),
        )
        var_dims = generic_dims
    data_vars[key] = (var_dims, value)

  return xr.Dataset(
      data_vars=data_vars,
      coords=coord_dict,
      attrs=dict(description='Aggregated metrics.'),
  )


def aggregated_metrics_to_zarr(
    data: Mapping[str, Any],
    *,
    out_dir: epath.PathLike,
    basename: str,
    coords: xr.core.coordinates.DatasetCoordinates | None = None,
) -> None:
  """Writes aggregated metrics to zarr.

  The metrics are first packaged with `aggregated_metrics_to_ds` and the
  resulting dataset is handed to `write_to_file` without an append dimension.

  Args:
    data: A mapping of metric names to their aggregated values.
    out_dir: Output directory passed to `write_to_file`.
    basename: Base name of the output passed to `write_to_file`.
    coords: Optional xarray coordinates used to annotate the metrics.
  """
  ds = aggregated_metrics_to_ds(data, coords)
  write_to_file(ds, out_dir, basename)
["var1", "var2", "var3"], + } + coords = xr.Dataset(coords=coord_dict).coords + ds = zarr_utils.aggregated_metrics_to_ds(data, coords) + # Returned coordinates are generic when ambiguous + self.assertIsInstance(ds, xr.Dataset) + self.assertIn("foo", ds) + self.assertIn("bar", ds) + self.assertIn("dim_0", ds.dims) + self.assertEqual(ds.dims["dim_2"], 3) + self.assertEqual(ds.dims["dim_1"], 5) + + def test_aggregated_metrics_to_zarr(self): + + shape = (10, 5, 3) + data = {"foo": np.ones(shape), "bar": np.ones(shape)} + coord_dict = { + "time": pd.date_range("2012-01-01", "2012-01-08"), + "lon": range(shape[0]), + "lat": range(shape[1]), + "field": ["var1", "var2", "var3"], + } + coords = xr.Dataset(coords=coord_dict).coords + outdir = self.create_tempdir() + zarr_utils.aggregated_metrics_to_zarr( + data, + out_dir=outdir, + basename="test_metrics", + coords=coords, + ) + + self.assertTrue(os.path.exists(os.path.join(outdir, "test_metrics.zarr"))) + def test_write_to_file(self): foo = np.ones((3,)) outdir = self.create_tempdir()