Skip to content

Commit

Permalink
Update calls to clu metrics to pass jnp.ndarrays instead of ints.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557868914
  • Loading branch information
adarob authored and t5-copybara committed Aug 17, 2023
1 parent 456e56b commit bb71132
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions t5x/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,8 @@ def compute_base_metrics(
Returns:
Dict of metrics.
"""
num_examples = targets.shape[0]
num_tokens = targets.size
num_examples = jnp.array(targets.shape[0])
num_tokens = jnp.array(targets.size)
num_devices = jax.device_count()
assert num_devices, 'JAX is reporting no devices, but it should.'
# Note: apply mask again even though mask has already been applied to loss.
Expand Down Expand Up @@ -1389,8 +1389,8 @@ def compute_base_metrics(
})

if segment_ids is not None:
total_tokens = 0
total_non_padding_tokens = 0
total_tokens = jnp.array(0)
total_non_padding_tokens = jnp.array(0)
for feature, feature_segment_ids in segment_ids.items():
if feature_segment_ids is None or feature_segment_ids.shape[1] == 0:
continue
Expand All @@ -1401,7 +1401,7 @@ def compute_base_metrics(
)
# 0s is padding
feature_non_padding = jnp.sum(feature_segment_ids != 0)
feature_size = feature_segment_ids.size
feature_size = jnp.array(feature_segment_ids.size)
total_tokens += feature_size
total_non_padding_tokens += feature_non_padding
metrics[f'non_padding_fraction/{feature}'] = clu_metrics.Average(
Expand Down

0 comments on commit bb71132

Please sign in to comment.