From 759f11585287ece327b3d98b7af6ca940c78e24e Mon Sep 17 00:00:00 2001 From: T5X Team Date: Wed, 16 Aug 2023 17:19:52 -0700 Subject: [PATCH] Change metric compute() return type to jnp.ndarray to match upcoming pytype changes in CLU. PiperOrigin-RevId: 557649671 --- t5x/metrics.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/t5x/metrics.py b/t5x/metrics.py index f2ecd1e3f..429adafc6 100644 --- a/t5x/metrics.py +++ b/t5x/metrics.py @@ -76,8 +76,8 @@ def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric: def merge(self, other: "Sum") -> "Sum": return type(self)(total=self.total + other.total) - def compute(self) -> jnp.ndarray: - return jnp.array(self.total) + def compute(self) -> Scalar: + return self.total @flax.struct.dataclass @@ -91,15 +91,15 @@ class Step(clu_metrics.Metric): """ steps: Optional[int] = 1 - def replace_steps(self, steps: int) -> "Step": + def replace_steps(self, steps) -> "Step": return self.replace(steps=steps) - def compute(self) -> jnp.ndarray: + def compute(self) -> Scalar: if self.steps is None: raise ValueError( "`steps` must be set by calling `replace_steps` before computing metric." ) - return jnp.array(self.steps) + return self.steps @flax.struct.dataclass @@ -134,7 +134,7 @@ def merge(self, other: "AveragePerStep") -> "AveragePerStep": return type(self)( total=self.total + other.total, steps=self.steps + other.steps) - def compute(self) -> jnp.ndarray: + def compute(self) -> Scalar: steps = super().compute() if self.total is None: raise ValueError("`AveragePerStep` `total` cannot be None.") @@ -157,12 +157,12 @@ class Time(clu_metrics.Metric): def merge(self, other: "Time") -> "Time": return self - def compute(self) -> jnp.ndarray: + def compute(self) -> Scalar: if self.duration is None: raise ValueError( "`Time` `duration` must be set by calling `replace_duration` before computing." ) - return jnp.array(self.duration) + return self.duration def replace_duration(self, duration: Scalar) -> "Time": """Replaces duration with the given value. @@ -210,7 +210,7 @@ def merge(self, other: "TimeRate") -> "TimeRate": assert self.duration is None and other.duration is None, assert_msg return type(self)(numerator=self.numerator + other.numerator) - def compute(self) -> jnp.ndarray: + def compute(self) -> Scalar: duration = super().compute() return self.numerator / duration @@ -240,7 +240,7 @@ def merge(self, other: "StepsPerTime") -> "StepsPerTime": assert type(self) is type(other) return type(self)(steps=self.steps + other.steps) - def compute(self) -> jnp.ndarray: + def compute(self) -> Scalar: steps = Step.compute(self) duration = Time.compute(self) return steps / duration