Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617660990
  • Loading branch information
ilopezgp authored and The swirl_dynamics Authors committed Mar 20, 2024
1 parent c09b9bc commit cc0c4ac
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 1 deletion.
135 changes: 134 additions & 1 deletion swirl_dynamics/templates/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
import functools
import time
from typing import Any, Protocol
from typing import Any, Protocol, Self

from absl import logging
from clu import metric_writers
Expand All @@ -40,6 +40,27 @@
PredType = Any


# Forked from clu.metrics
def _broadcast_masks(values: jax.Array, mask: jax.Array | None):
"""Checks and broadcasts mask for aggregating values."""
if values.ndim == 0:
values = values[None]
if mask is None:
mask = jnp.ones_like(values)
# Leading dimensions of mask and values must match.
if mask.shape[0] != values.shape[0]:
raise ValueError(
"Argument `mask` must have the same leading dimension as `values`. "
f"Received mask of dimension {mask.shape} "
f"and values of dimension {values.shape}."
)
# Broadcast mask to the same number of dimensions as values.
if mask.ndim < values.ndim:
mask = jnp.expand_dims(mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
mask = mask.astype(bool)
return values, mask


class Benchmark(Protocol):
"""The abstract benchmark task interface.
Expand Down Expand Up @@ -129,6 +150,118 @@ def compute(self) -> jax.Array:
return _TensorAverage


def TensorRatio( # pylint: disable=invalid-name
axis: int | tuple[int, ...] | None = None
):
"""Computes the ratio between two aggregated metrics.
The ratio is performed after full aggregation of numerator and denominator.
For numerator and denominator, only entries on the selected axes are
aggregated, while the remaining axes stay untouched (which means their
dimensions will be part of the final result obtained after the aggregating
`.compute()` call).
Args:
axis: The axis or axes along which numerator and denominator are aggregated.
If `None`, average is taken across all axes.
Returns:
The metric class.
"""

@flax.struct.dataclass
class _TensorRatio(clu_metrics.Metric):
"""Ratio of metrics class."""

numerator: jax.Array
denominator: jax.Array

@classmethod
def from_model_output(
cls,
*,
numerator: jax.Array,
denominator: jax.Array,
mask: jax.Array | None = None,
**_,
) -> Self:
"""Construct a metric instance given model output values."""
numerator, mask = _broadcast_masks(numerator, mask)
return cls(
numerator=jnp.where(mask, numerator, jnp.zeros_like(numerator)).sum(
axis=axis
),
denominator=jnp.where(
mask, denominator, jnp.zeros_like(denominator)
).sum(axis=axis),
)

def merge(self, other):
"""Merges with another metric instance of the same type."""
return type(self)(
numerator=self.numerator + other.numerator,
denominator=self.denominator + other.denominator,
)

def compute(self) -> jax.Array:
ratio = self.numerator / self.denominator
return ratio

@classmethod
def empty(cls) -> Self:
return cls(
numerator=jnp.array(0, jnp.float32),
denominator=jnp.array(0, jnp.float32),
)

@classmethod
def from_outputs(cls, numerator_name: str, denominator_name: str):
"""Calls `cls.from_model_output` with model output names.
Synopsis:
@flax.struct.dataclass
class Metrics(Collection):
loss: TensorRatio(axis=None).from_outputs("foo", "bar")
Args:
numerator_name: Name of the model output that should be passed as the
`numerator` keyword argument to `cls.from_model_output()`.
denominator_name: Name of the model output that should be passed as the
`denominator` keyword argument to `cls.from_model_output()`.
Returns:
A `Metric` derived from `cls` that calls `.from_model_output()` with the
the specified numerator and denominator arguments.
"""

@flax.struct.dataclass
class FromOutputs(cls):
"""Wrapper Metric class that collects numerator and denominator."""

@classmethod
def from_model_output(cls, **model_output):
numerator = jnp.array(model_output[numerator_name])
denominator = jnp.array(model_output[denominator_name])
mask = model_output.get("mask")
if mask is not None and (numerator.shape or [0])[0] != mask.shape[0]:
logging.warning(
"Ignoring mask for model output '%s' because of shape mismatch:"
" output.shape=%s vs. mask.shape=%s",
numerator_name,
numerator.shape,
mask.shape,
)
mask = None
return super().from_model_output(
numerator=numerator, denominator=denominator, mask=mask
)

return FromOutputs

return _TensorRatio


class CollectingResults:
"""Object that holds collected evaluation results.
Expand Down
45 changes: 45 additions & 0 deletions swirl_dynamics/templates/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,51 @@ def test_rms(self):
np.testing.assert_allclose(metric_compute, expected, atol=1e-5)


@parameterized.parameters(
((1, 2),),
(None,),
(3,),
)
class TensorRatioTest(parameterized.TestCase):

def test_from_model_output(self, agg_axis):
test_shape = (1, 2, 3, 4, 5)
metric_cls = evaluate.TensorRatio(agg_axis).from_outputs("abc", "efg")
rng = np.random.default_rng(123)
test_values = rng.random(test_shape)
expected = 2 * np.ones_like(np.mean(test_values, axis=agg_axis))
metric = metric_cls.from_model_output(abc=2 * test_values, efg=test_values)
metric_compute = metric.compute()
np.testing.assert_allclose(metric_compute, expected, atol=1e-4)

def test_merge(self, agg_axis):
test_shape = (1, 2, 3, 4, 5)
metric_cls = evaluate.TensorRatio(agg_axis).from_outputs("abc", "efg")
metric = metric_cls.empty()
rng = np.random.default_rng(321)
test_values1 = rng.random(test_shape)
test_values2 = rng.random(test_shape)
metric = metric.merge(
metric_cls.from_model_output(abc=test_values1, efg=test_values2)
)
test_values3 = rng.random(test_shape)
test_values4 = rng.random(test_shape)
metric = metric.merge(
metric_cls.from_model_output(abc=test_values3, efg=test_values4)
)
metric_compute = metric.compute()
# The expected result is the ratio of the sums of numerators and
# denominators, which does not commute with the sum of ratios.
expected_numerator = np.add(
np.sum(test_values1, axis=agg_axis), np.sum(test_values3, axis=agg_axis)
)
expected_denominator = np.add(
np.sum(test_values2, axis=agg_axis), np.sum(test_values4, axis=agg_axis)
)
expected = np.divide(expected_numerator, expected_denominator)
np.testing.assert_allclose(metric_compute, expected, atol=1e-5)


class CollectingResultsTest(parameterized.TestCase):

def test_collect_batches_and_compute(self):
Expand Down

0 comments on commit cc0c4ac

Please sign in to comment.