Skip to content

Commit

Permalink
fix fold_in in pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Feb 2, 2024
1 parent ff3c9b0 commit c794339
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
return [new_seed, data]
def _fold_in(seed: SeedType, data: int) -> SeedType:
rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data))
new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
return new_seed_1 + new_seed_2


def _split(seed: SeedType, num: int = 2) -> SeedType:
Expand All @@ -58,7 +60,7 @@ def _check_jax_install() -> None:
'--framework=pytorch to use the Numpy version instead.')


def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
def fold_in(seed: SeedType, data: int) -> SeedType:
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.fold_in(seed, data)
Expand Down

0 comments on commit c794339

Please sign in to comment.