Skip to content

Commit

Permalink
Merge pull request #525 from mlcommons/juhan/db_cri
Browse files Browse the repository at this point in the history
Criteo Jax loss float64 cast
  • Loading branch information
priyakasimbeg authored Oct 16, 2023
2 parents 2f5774b + a1844c7 commit 45a7730
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
Expand Down Expand Up @@ -147,7 +148,8 @@ def _eval_batch(self,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return self._eval_batch_pmapped(params, batch).sum()
return np.array(
self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
Expand Down

0 comments on commit 45a7730

Please sign in to comment.