diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 37fc08e5f..68e9a9cfe 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -1,6 +1,6 @@ """Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API.""" -from typing import Union +from typing import Any, List, Union from absl import flags from absl import logging @@ -21,12 +21,6 @@ MAX_INT32 = 2**31 MIN_INT32 = -MAX_INT32 -# SALT constants -_SALT1 = np.random.RandomState(seed=5).randint( - MIN_INT32, MAX_INT32, dtype=np.int32) -_SALT2 = np.random.RandomState(seed=6).randint( - MIN_INT32, MAX_INT32, dtype=np.int32) - SeedType = Union[int, list, np.ndarray] @@ -39,19 +33,15 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: int) -> SeedType: - a = np.random.RandomState(seed=_signed_to_unsigned(seed ^ _SALT1)).randint( - MIN_INT32, MAX_INT32, dtype=np.int32) - b = np.random.RandomState(seed=_signed_to_unsigned(data ^ _SALT2)).randint( - MIN_INT32, MAX_INT32, dtype=np.int32) - c = np.random.RandomState(seed=_signed_to_unsigned(a ^ b)).randint( - MIN_INT32, MAX_INT32, dtype=np.int32) - return c +def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: + rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name @@ -68,11 +58,7 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def _bits(seed: SeedType) -> int: - return seed - - -def fold_in(seed: SeedType, data: int) -> SeedType: +def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.fold_in(seed, data) @@ -91,10 +77,3 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name _check_jax_install() return jax_rng.PRNGKey(seed) return _PRNGKey(seed) - - -def bits(seed: SeedType) -> int: - if FLAGS.framework == 'jax': - _check_jax_install() - return jax_rng.bits(seed) - return _bits(seed) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index e2d655e9b..af86c212e 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -65,7 +65,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng).shuffle(train_indices) + random.Random(data_rng[0]).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) @@ -111,7 +111,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) self._model = resnet18(num_classes=self._num_classes) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index bff5fa837..85bb602d1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -72,7 +72,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False if self.use_resnet: diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 0ad1b3eeb..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -113,7 +113,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ba2012644..6727054c9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -103,7 +103,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng).shuffle(indices) + random.Random(data_rng[0]).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) @@ -147,7 +147,7 @@ def init_model_fn( """Dropout is unused.""" del dropout_rate del aux_dropout_rate - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: raise RuntimeError('Cannot use both GELU and SiLU activations.') diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index aec3f1aaf..e672e8d22 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -30,7 +30,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 9f0a6f841..20f27b150 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -58,7 +58,7 @@ def init_model_fn( Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False torch.backends.cuda.enable_flash_sdp(False) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index c968b528d..bcdd78fb5 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -32,7 +32,7 @@ def init_model_fn( Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index a60e6040e..e638df078 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -133,7 +133,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) self._model = _Model() self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 2b593948c..dcc195170 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -55,7 +55,7 @@ def _build_mnist_dataset( if shuffle: ds = ds.repeat() - ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng)) + ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0]) ds = ds.batch(global_batch_size, drop_remainder=is_train) if repeat_final_dataset: diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 84a445c4b..d4817226d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -143,7 +143,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is unused.""" del aux_dropout_rate - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 9ee959a4f..9f6d817f4 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -171,7 +171,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - torch.random.manual_seed(rng) + torch.random.manual_seed(rng[0]) if self.activation == 'relu': activation = F.relu diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 83d7a5f65..077ce8d4f 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -148,7 +148,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: run_key = prng.fold_in(rng_subkey, hash(workload)) - run_seed = prng.bits(run_key) + run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index afa752cb5..b67625213 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -113,7 +113,7 @@ def get_workload(workload): else: raise ValueError(f'Workload {workload} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn(0) + _ = pytorch_workload.init_model_fn([0]) return jax_workload, pytorch_workload diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 639c7372d..7cf8f63c3 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -221,7 +221,7 @@ def get_workload(workload_name): else: raise ValueError(f'Workload {workload_name} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn(0) + _ = pytorch_workload.init_model_fn([0]) return jax_workload, pytorch_workload