Skip to content

Commit

Permalink
Change metric compute() return type to jnp.ndarray to match upcoming …
Browse files Browse the repository at this point in the history
…pytype changes in CLU.

PiperOrigin-RevId: 557796316
  • Loading branch information
T5X Team authored and t5-copybara committed Aug 17, 2023
1 parent 77d2624 commit 94a4f05
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions t5x/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric:
def merge(self, other: "Sum") -> "Sum":
return type(self)(total=self.total + other.total)

def compute(self) -> jnp.ndarray:
return jnp.array(self.total)
def compute(self) -> Scalar:
return self.total


@flax.struct.dataclass
Expand All @@ -91,15 +91,15 @@ class Step(clu_metrics.Metric):
"""
steps: Optional[int] = 1

def replace_steps(self, steps: int) -> "Step":
def replace_steps(self, steps) -> "Step":
return self.replace(steps=steps)

def compute(self) -> jnp.ndarray:
def compute(self) -> Scalar:
if self.steps is None:
raise ValueError(
"`steps` must be set by calling `replace_steps` before computing metric."
)
return jnp.array(self.steps)
return self.steps


@flax.struct.dataclass
Expand Down Expand Up @@ -134,7 +134,7 @@ def merge(self, other: "AveragePerStep") -> "AveragePerStep":
return type(self)(
total=self.total + other.total, steps=self.steps + other.steps)

def compute(self) -> jnp.ndarray:
def compute(self) -> Scalar:
steps = super().compute()
if self.total is None:
raise ValueError("`AveragePerStep` `total` cannot be None.")
Expand All @@ -157,12 +157,12 @@ class Time(clu_metrics.Metric):
def merge(self, other: "Time") -> "Time":
return self

def compute(self) -> jnp.ndarray:
def compute(self) -> Scalar:
if self.duration is None:
raise ValueError(
"`Time` `duration` must be set by calling `replace_duration` before computing."
)
return jnp.array(self.duration)
return self.duration

def replace_duration(self, duration: Scalar) -> "Time":
"""Replaces duration with the given value.
Expand Down Expand Up @@ -210,7 +210,7 @@ def merge(self, other: "TimeRate") -> "TimeRate":
assert self.duration is None and other.duration is None, assert_msg
return type(self)(numerator=self.numerator + other.numerator)

def compute(self) -> jnp.ndarray:
def compute(self) -> Scalar:
duration = super().compute()
return self.numerator / duration

Expand Down Expand Up @@ -240,7 +240,7 @@ def merge(self, other: "StepsPerTime") -> "StepsPerTime":
assert type(self) is type(other)
return type(self)(steps=self.steps + other.steps)

def compute(self) -> jnp.ndarray:
def compute(self) -> Scalar:
steps = Step.compute(self)
duration = Time.compute(self)
return steps / duration
Expand Down

0 comments on commit 94a4f05

Please sign in to comment.