diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index db97735cb..b4a26b5b0 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -36,14 +36,14 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: int) -> SeedType: rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data)) + rng_2 = np.random.RandomState(seed=(_signed_to_unsigned(data) & 0xffffffff)) new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return new_seed_1 + new_seed_2 def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name @@ -60,6 +60,11 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') +def _randint(seed: SeedType) -> int: + rng = np.random.RandomState(_signed_to_unsigned(seed)) + return rng.randint(MAX_INT32) + + def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() @@ -79,3 +84,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name _check_jax_install() return jax_rng.PRNGKey(seed) return _PRNGKey(seed) + + +def randint(seed:SeedType) -> int: + if FLAGS.framework == 'jax': + _check_jax_install() + return jax_rng.randint(seed, ) + return _randint(seed) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 077ce8d4f..a464da341 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -148,7 +148,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: run_key = prng.fold_in(rng_subkey, hash(workload)) - run_seed = run_key[0] # arbitrary + run_seed = prng.randint(run_key) # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system(