From 0798d572089043deb9840bb463df39ecd8d72e87 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 16 Aug 2023 12:04:49 -0700 Subject: [PATCH] Change metric compute() return type to jnp.ndarray to match upcoming pytype changes in CLU. PiperOrigin-RevId: 557561097 --- t5x/metrics.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/t5x/metrics.py b/t5x/metrics.py index 429adafc6..15ea6c7e2 100644 --- a/t5x/metrics.py +++ b/t5x/metrics.py @@ -25,10 +25,9 @@ import flax # Only used for flax.struct.dataclass. import jax import jax.numpy as jnp -import numpy as np MetricsMap = MutableMapping[str, clu_metrics.Metric] -Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray] +Scalar = jnp.ndarray def _check_param(value, *, ndim=None, dtype=jnp.float32): @@ -91,7 +90,7 @@ class Step(clu_metrics.Metric): """ steps: Optional[int] = 1 - def replace_steps(self, steps) -> "Step": + def replace_steps(self, steps: Union[Scalar, int]) -> "Step": return self.replace(steps=steps) def compute(self) -> Scalar: @@ -99,7 +98,7 @@ def compute(self) -> Scalar: raise ValueError( "`steps` must be set by calling `replace_steps` before computing metric." ) - return self.steps + return jnp.array(self.steps) @flax.struct.dataclass @@ -176,7 +175,7 @@ def replace_duration(self, duration: Scalar) -> "Time": Returns: A new Time object. """ - return self.replace(duration=duration) + return self.replace(duration=jnp.array(duration)) @flax.struct.dataclass