From c79433927a4f4d6024aba43cf00f78f2101a1ef5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 01:46:51 +0000 Subject: [PATCH] fix fold_in in pytorch --- algorithmic_efficiency/random_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 68e9a9cfe..db97735cb 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -33,10 +33,12 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: - rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - return [new_seed, data] +def _fold_in(seed: SeedType, data: int) -> SeedType: + rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) + new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data)) + new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + return new_seed_1 + new_seed_2 def _split(seed: SeedType, num: int = 2) -> SeedType: @@ -58,7 +60,7 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: +def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.fold_in(seed, data)