Skip to content

Commit

Permalink
random utils fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Feb 2, 2024
1 parent c794339 commit 247dcb0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
16 changes: 14 additions & 2 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
def _fold_in(seed: SeedType, data: int) -> SeedType:
rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data))
rng_2 = np.random.RandomState(seed=(_signed_to_unsigned(data) & 0xffffffff))
new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
return new_seed_1 + new_seed_2


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand All @@ -60,6 +60,11 @@ def _check_jax_install() -> None:
'--framework=pytorch to use the Numpy version instead.')


def _randint(seed: SeedType) -> int:
rng = np.random.RandomState(_signed_to_unsigned(seed))
return rng.randint(MAX_INT32)


def fold_in(seed: SeedType, data: int) -> SeedType:
if FLAGS.framework == 'jax':
_check_jax_install()
Expand All @@ -79,3 +84,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
_check_jax_install()
return jax_rng.PRNGKey(seed)
return _PRNGKey(seed)


def randint(seed:SeedType) -> int:
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.randint(seed, )
return _randint(seed)
2 changes: 1 addition & 1 deletion scoring/run_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def main(_):
# For each runnable workload check if there are any containers running and if not launch next container command
for workload in workloads:
run_key = prng.fold_in(rng_subkey, hash(workload))
run_seed = run_key[0] # arbitrary
run_seed = prng.randint(run_key) # arbitrary
base_workload_name = get_base_workload_name(workload)
wait_until_container_not_running()
os.system(
Expand Down

0 comments on commit 247dcb0

Please sign in to comment.