Skip to content

Commit

Permalink
Change metric compute() return type to jnp.ndarray to match upcoming …
Browse files Browse the repository at this point in the history
…pytype changes in CLU.

PiperOrigin-RevId: 557561097
  • Loading branch information
adarob authored and t5-copybara committed Aug 16, 2023
1 parent dd1b2a8 commit 0798d57
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions t5x/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
import flax # Only used for flax.struct.dataclass.
import jax
import jax.numpy as jnp
import numpy as np

MetricsMap = MutableMapping[str, clu_metrics.Metric]
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
Scalar = jnp.ndarray


def _check_param(value, *, ndim=None, dtype=jnp.float32):
Expand Down Expand Up @@ -91,15 +90,15 @@ class Step(clu_metrics.Metric):
"""
steps: Optional[int] = 1

def replace_steps(self, steps) -> "Step":
def replace_steps(self, steps: Union[Scalar, int]) -> "Step":
return self.replace(steps=steps)

def compute(self) -> Scalar:
if self.steps is None:
raise ValueError(
"`steps` must be set by calling `replace_steps` before computing metric."
)
return self.steps
return jnp.array(self.steps)


@flax.struct.dataclass
Expand Down Expand Up @@ -176,7 +175,7 @@ def replace_duration(self, duration: Scalar) -> "Time":
Returns:
A new Time object.
"""
return self.replace(duration=duration)
return self.replace(duration=jnp.array(duration))


@flax.struct.dataclass
Expand Down

0 comments on commit 0798d57

Please sign in to comment.