diff --git a/t5x/metrics.py b/t5x/metrics.py index 429adafc6..f2ecd1e3f 100644 --- a/t5x/metrics.py +++ b/t5x/metrics.py @@ -76,8 +76,8 @@ def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric: def merge(self, other: "Sum") -> "Sum": return type(self)(total=self.total + other.total) - def compute(self) -> Scalar: - return self.total + def compute(self) -> jnp.ndarray: + return jnp.array(self.total) @flax.struct.dataclass @@ -91,15 +91,15 @@ class Step(clu_metrics.Metric): """ steps: Optional[int] = 1 - def replace_steps(self, steps) -> "Step": + def replace_steps(self, steps: int) -> "Step": return self.replace(steps=steps) - def compute(self) -> Scalar: + def compute(self) -> jnp.ndarray: if self.steps is None: raise ValueError( "`steps` must be set by calling `replace_steps` before computing metric." ) - return self.steps + return jnp.array(self.steps) @flax.struct.dataclass @@ -134,7 +134,7 @@ def merge(self, other: "AveragePerStep") -> "AveragePerStep": return type(self)( total=self.total + other.total, steps=self.steps + other.steps) - def compute(self) -> Scalar: + def compute(self) -> jnp.ndarray: steps = super().compute() if self.total is None: raise ValueError("`AveragePerStep` `total` cannot be None.") @@ -157,12 +157,12 @@ class Time(clu_metrics.Metric): def merge(self, other: "Time") -> "Time": return self - def compute(self) -> Scalar: + def compute(self) -> jnp.ndarray: if self.duration is None: raise ValueError( "`Time` `duration` must be set by calling `replace_duration` before computing." ) - return self.duration + return jnp.array(self.duration) def replace_duration(self, duration: Scalar) -> "Time": """Replaces duration with the given value. @@ -210,7 +210,7 @@ def merge(self, other: "TimeRate") -> "TimeRate": assert self.duration is None and other.duration is None, assert_msg return type(self)(numerator=self.numerator + other.numerator) - def compute(self) -> Scalar: + def compute(self) -> jnp.ndarray: duration = super().compute() return self.numerator / duration @@ -240,7 +240,7 @@ def merge(self, other: "StepsPerTime") -> "StepsPerTime": assert type(self) is type(other) return type(self)(steps=self.steps + other.steps) - def compute(self) -> Scalar: + def compute(self) -> jnp.ndarray: steps = Step.compute(self) duration = Time.compute(self) return steps / duration