diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 68e9a9cfe..db97735cb 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -33,10 +33,12 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: - rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - return [new_seed, data] +def _fold_in(seed: SeedType, data: int) -> SeedType: + rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) + new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data)) + new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + return new_seed_1 + new_seed_2 def _split(seed: SeedType, num: int = 2) -> SeedType: @@ -58,7 +60,7 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: +def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.fold_in(seed, data)