diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ba8db9ced..a76a70289 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -6,6 +6,7 @@ from flax import jax_utils import jax import jax.numpy as jnp +import numpy as np from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -147,7 +148,8 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return self._eval_batch_pmapped(params, batch).sum() + return np.array( + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):