From cc0c4ac1d49208a37ef7a48858cb66d332f71cba Mon Sep 17 00:00:00 2001 From: Ignacio Lopez-Gomez Date: Wed, 20 Mar 2024 16:46:37 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 617660990 --- swirl_dynamics/templates/evaluate.py | 135 +++++++++++++++++++++- swirl_dynamics/templates/evaluate_test.py | 45 ++++++++ 2 files changed, 179 insertions(+), 1 deletion(-) diff --git a/swirl_dynamics/templates/evaluate.py b/swirl_dynamics/templates/evaluate.py index ad56afe..8a7bc16 100644 --- a/swirl_dynamics/templates/evaluate.py +++ b/swirl_dynamics/templates/evaluate.py @@ -17,7 +17,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence import functools import time -from typing import Any, Protocol +from typing import Any, Protocol, Self from absl import logging from clu import metric_writers @@ -40,6 +40,27 @@ PredType = Any +# Forked from clu.metrics +def _broadcast_masks(values: jax.Array, mask: jax.Array | None): + """Checks and broadcasts mask for aggregating values.""" + if values.ndim == 0: + values = values[None] + if mask is None: + mask = jnp.ones_like(values) + # Leading dimensions of mask and values must match. + if mask.shape[0] != values.shape[0]: + raise ValueError( + "Argument `mask` must have the same leading dimension as `values`. " + f"Received mask of dimension {mask.shape} " + f"and values of dimension {values.shape}." + ) + # Broadcast mask to the same number of dimensions as values. + if mask.ndim < values.ndim: + mask = jnp.expand_dims(mask, axis=tuple(np.arange(mask.ndim, values.ndim))) + mask = mask.astype(bool) + return values, mask + + class Benchmark(Protocol): """The abstract benchmark task interface. @@ -129,6 +150,118 @@ def compute(self) -> jax.Array: return _TensorAverage +def TensorRatio( # pylint: disable=invalid-name + axis: int | tuple[int, ...] | None = None +): + """Computes the ratio between two aggregated metrics. + + The ratio is performed after full aggregation of numerator and denominator. 
+ For numerator and denominator, only entries on the selected axes are + aggregated, while the remaining axes stay untouched (which means their + dimensions will be part of the final result obtained after the aggregating + `.compute()` call). + + Args: + axis: The axis or axes along which numerator and denominator are aggregated. + If `None`, the sum is taken across all axes. + + Returns: + The metric class. + """ + + @flax.struct.dataclass + class _TensorRatio(clu_metrics.Metric): + """Ratio of metrics class.""" + + numerator: jax.Array + denominator: jax.Array + + @classmethod + def from_model_output( + cls, + *, + numerator: jax.Array, + denominator: jax.Array, + mask: jax.Array | None = None, + **_, + ) -> Self: + """Construct a metric instance given model output values.""" + numerator, mask = _broadcast_masks(numerator, mask) + return cls( + numerator=jnp.where(mask, numerator, jnp.zeros_like(numerator)).sum( + axis=axis + ), + denominator=jnp.where( + mask, denominator, jnp.zeros_like(denominator) + ).sum(axis=axis), + ) + + def merge(self, other): + """Merges with another metric instance of the same type.""" + return type(self)( + numerator=self.numerator + other.numerator, + denominator=self.denominator + other.denominator, + ) + + def compute(self) -> jax.Array: + ratio = self.numerator / self.denominator + return ratio + + @classmethod + def empty(cls) -> Self: + return cls( + numerator=jnp.array(0, jnp.float32), + denominator=jnp.array(0, jnp.float32), + ) + + @classmethod + def from_outputs(cls, numerator_name: str, denominator_name: str): + """Calls `cls.from_model_output` with model output names. + + Synopsis: + + @flax.struct.dataclass + class Metrics(Collection): + loss: TensorRatio(axis=None).from_outputs("foo", "bar") + + Args: + numerator_name: Name of the model output that should be passed as the + `numerator` keyword argument to `cls.from_model_output()`. 
+ denominator_name: Name of the model output that should be passed as the + `denominator` keyword argument to `cls.from_model_output()`. + + Returns: + A `Metric` derived from `cls` that calls `.from_model_output()` with the + specified numerator and denominator arguments. + """ + + @flax.struct.dataclass + class FromOutputs(cls): + """Wrapper Metric class that collects numerator and denominator.""" + + @classmethod + def from_model_output(cls, **model_output): + numerator = jnp.array(model_output[numerator_name]) + denominator = jnp.array(model_output[denominator_name]) + mask = model_output.get("mask") + if mask is not None and (numerator.shape or [0])[0] != mask.shape[0]: + logging.warning( + "Ignoring mask for model output '%s' because of shape mismatch:" + " output.shape=%s vs. mask.shape=%s", + numerator_name, + numerator.shape, + mask.shape, + ) + mask = None + return super().from_model_output( + numerator=numerator, denominator=denominator, mask=mask + ) + + return FromOutputs + + return _TensorRatio + + class CollectingResults: """Object that holds collected evaluation results. 
diff --git a/swirl_dynamics/templates/evaluate_test.py b/swirl_dynamics/templates/evaluate_test.py index d300147..f8840a7 100644 --- a/swirl_dynamics/templates/evaluate_test.py +++ b/swirl_dynamics/templates/evaluate_test.py @@ -92,6 +92,51 @@ def test_rms(self): np.testing.assert_allclose(metric_compute, expected, atol=1e-5) +@parameterized.parameters( + ((1, 2),), + (None,), + (3,), +) +class TensorRatioTest(parameterized.TestCase): + + def test_from_model_output(self, agg_axis): + test_shape = (1, 2, 3, 4, 5) + metric_cls = evaluate.TensorRatio(agg_axis).from_outputs("abc", "efg") + rng = np.random.default_rng(123) + test_values = rng.random(test_shape) + expected = 2 * np.ones_like(np.mean(test_values, axis=agg_axis)) + metric = metric_cls.from_model_output(abc=2 * test_values, efg=test_values) + metric_compute = metric.compute() + np.testing.assert_allclose(metric_compute, expected, atol=1e-4) + + def test_merge(self, agg_axis): + test_shape = (1, 2, 3, 4, 5) + metric_cls = evaluate.TensorRatio(agg_axis).from_outputs("abc", "efg") + metric = metric_cls.empty() + rng = np.random.default_rng(321) + test_values1 = rng.random(test_shape) + test_values2 = rng.random(test_shape) + metric = metric.merge( + metric_cls.from_model_output(abc=test_values1, efg=test_values2) + ) + test_values3 = rng.random(test_shape) + test_values4 = rng.random(test_shape) + metric = metric.merge( + metric_cls.from_model_output(abc=test_values3, efg=test_values4) + ) + metric_compute = metric.compute() + # The expected result is the ratio of the sums of numerators and + # denominators, which does not commute with the sum of ratios. 
+ expected_numerator = np.add( + np.sum(test_values1, axis=agg_axis), np.sum(test_values3, axis=agg_axis) + ) + expected_denominator = np.add( + np.sum(test_values2, axis=agg_axis), np.sum(test_values4, axis=agg_axis) + ) + expected = np.divide(expected_numerator, expected_denominator) + np.testing.assert_allclose(metric_compute, expected, atol=1e-5) + + class CollectingResultsTest(parameterized.TestCase): def test_collect_batches_and_compute(self):