Skip to content

Commit

Permalink
Merge pull request #625 from mlcommons/random_utils_fixes
Browse files: browse the repository at this point in the history
Changes to random utils
  • Loading branch information
priyakasimbeg authored Feb 7, 2024
2 parents ff3c9b0 + 1fc4ce2 commit 3421276
Show file tree
Hide file tree
Showing 15 changed files with 44 additions and 23 deletions.
35 changes: 28 additions & 7 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API."""

from typing import Any, List, Union
from typing import Union

from absl import flags
from absl import logging
Expand All @@ -21,6 +21,12 @@
# Bounds of the signed 32-bit range. MAX_INT32 is one past the largest int32
# value; the draws below pass it as the upper bound of `randint`, which is
# exclusive.
MIN_INT32 = -(2**31)
MAX_INT32 = 2**31

# SALT constants: one int32 each, drawn from a RandomState with a hard-coded
# seed (5 and 6), so they are reproducible. `_fold_in` XORs them into its
# inputs before reseeding.
_SALT1, _SALT2 = (
    np.random.RandomState(seed=salt_seed).randint(
        MIN_INT32, MAX_INT32, dtype=np.int32) for salt_seed in (5, 6))

# A seed may be a plain int, a list, or a NumPy array.
SeedType = Union[int, list, np.ndarray]


Expand All @@ -33,15 +39,19 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()])


def _fold_in(seed: SeedType, data: int) -> SeedType:
  """Derive a new int32 seed from `seed` and the integer `data`.

  Each input is XOR-ed with its own fixed salt and used to seed a throwaway
  `RandomState`, from which a single int32 is drawn. The XOR of those two
  draws seeds a third generator, and its single draw is the result.

  Args:
    seed: The seed to fold `data` into. It must support `^` with an int32.
    data: Integer to mix into `seed`.

  Returns:
    One int32 drawn from `[MIN_INT32, MAX_INT32)`.

  NOTE(review): `seed ^ _SALT1` and `data ^ _SALT2` are NumPy integer scalars
  when the inputs are Python ints -- confirm that `_signed_to_unsigned`
  handles NumPy scalar input and not only `int`/`list`/`ndarray`.
  """

  def _draw(value):
    # One int32 from a fresh generator seeded with `value`.
    return np.random.RandomState(seed=_signed_to_unsigned(value)).randint(
        MIN_INT32, MAX_INT32, dtype=np.int32)

  salted_seed_draw = _draw(seed ^ _SALT1)
  salted_data_draw = _draw(data ^ _SALT2)
  return _draw(salted_seed_draw ^ salted_data_draw)


def _split(seed: SeedType, num: int = 2) -> SeedType:
  """Draw `num` new int32 seeds from a generator seeded with `seed`.

  Args:
    seed: Seed for the `RandomState` the new seeds are drawn from.
    num: How many new seeds to produce.

  Returns:
    An int32 array of shape `[num]` with values in `[MIN_INT32, MAX_INT32)`.
  """
  rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
  # One scalar per sub-seed: call sites consume seeds directly, without
  # indexing (e.g. `torch.random.manual_seed(rng)`), so no trailing axis.
  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand All @@ -58,7 +68,11 @@ def _check_jax_install() -> None:
'--framework=pytorch to use the Numpy version instead.')


def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
def _bits(seed: SeedType) -> int:
  """Return `seed` unchanged; NumPy-path counterpart of `jax_rng.bits`."""
  return seed


def fold_in(seed: SeedType, data: int) -> SeedType:
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.fold_in(seed, data)
Expand All @@ -77,3 +91,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
_check_jax_install()
return jax_rng.PRNGKey(seed)
return _PRNGKey(seed)


def bits(seed: SeedType) -> int:
  """Turn `seed` into an integer using the backend chosen by `--framework`.

  With `--framework=jax` the call is forwarded to `jax_rng.bits` after
  checking that JAX is installed; otherwise the NumPy-path `_bits` is used.

  Args:
    seed: The seed or key to convert.

  Returns:
    Whatever the selected backend's `bits` implementation returns for `seed`.
  """
  if FLAGS.framework != 'jax':
    return _bits(seed)
  _check_jax_install()
  return jax_rng.bits(seed)
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _build_dataset(
}
if split == 'eval_train':
train_indices = indices_split['train']
random.Random(data_rng[0]).shuffle(train_indices)
random.Random(data_rng).shuffle(train_indices)
indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
if split in indices_split:
dataset = torch.utils.data.Subset(dataset, indices_split[split])
Expand Down Expand Up @@ -111,7 +111,7 @@ def init_model_fn(
self._model.reset_parameters()
return self._model, None

torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
self._model = resnet18(num_classes=self._num_classes)
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_model_fn(
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
if self.use_resnet:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
del aux_dropout_rate
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _build_dataset(

if split == 'eval_train':
indices = list(range(self.num_train_examples))
random.Random(data_rng[0]).shuffle(indices)
random.Random(data_rng).shuffle(indices)
dataset = torch.utils.data.Subset(dataset,
indices[:self.num_eval_train_examples])

Expand Down Expand Up @@ -147,7 +147,7 @@ def init_model_fn(
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)

if self.use_silu and self.use_gelu:
raise RuntimeError('Cannot use both GELU and SiLU activations.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
del aux_dropout_rate
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def init_model_fn(
Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as
input_dropout_rate.
"""
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
# Configure torch backends to avoid OOM errors.
torch.backends.cudnn.benchmark = False
torch.backends.cuda.enable_flash_sdp(False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def init_model_fn(
Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
as input_dropout_rate.
"""
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
model = DeepspeechEncoderDecoder(
DeepspeechConfig(
feed_forward_dropout_rate=dropout_rate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def init_model_fn(
self._model.reset_parameters()
return self._model, None

torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
self._model = _Model()
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/mnist/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _build_mnist_dataset(

if shuffle:
ds = ds.repeat()
ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng))
ds = ds.batch(global_batch_size, drop_remainder=is_train)

if repeat_final_dataset:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def init_model_fn(
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
del aux_dropout_rate
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)
model = GNN(
num_outputs=self._num_outputs,
dropout_rate=dropout_rate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""aux_dropout_rate is used as attention_dropout_rate."""
torch.random.manual_seed(rng[0])
torch.random.manual_seed(rng)

if self.activation == 'relu':
activation = F.relu
Expand Down
2 changes: 1 addition & 1 deletion scoring/run_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def main(_):
# For each runnable workload, check whether any containers are running and, if not, launch the next container command.
for workload in workloads:
run_key = prng.fold_in(rng_subkey, hash(workload))
run_seed = run_key[0] # arbitrary
run_seed = prng.bits(run_key)
base_workload_name = get_base_workload_name(workload)
wait_until_container_not_running()
os.system(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_param_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_workload(workload):
else:
raise ValueError(f'Workload {workload} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
_ = pytorch_workload.init_model_fn([0])
_ = pytorch_workload.init_model_fn(0)
return jax_workload, pytorch_workload


Expand Down
2 changes: 1 addition & 1 deletion tests/test_param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def get_workload(workload_name):
else:
raise ValueError(f'Workload {workload_name} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
_ = pytorch_workload.init_model_fn([0])
_ = pytorch_workload.init_model_fn(0)
return jax_workload, pytorch_workload


Expand Down

0 comments on commit 3421276

Please sign in to comment.