diff --git a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
index cb091b3a5..dc6cd7aef 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
@@ -111,7 +111,7 @@ def get_criteo1tb_dataset(split: str,
   is_training = split == 'train'
   shuffle = is_training or split == 'eval_train'
   ds = tf.data.Dataset.list_files(
-      file_paths, shuffle=shuffle, seed=shuffle_rng[0])
+      file_paths, shuffle=shuffle, seed=shuffle_rng)
 
   if shuffle:
     ds = ds.shuffle(buffer_size=1024)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 9f0a6f841..9bb8232e9 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -152,7 +152,7 @@ def _build_input_queue(
     ds = LibriSpeechDataset(split=ds_split, data_dir=data_dir)
     if split == 'eval_train':
       indices = list(range(len(ds)))
-      random.Random(data_rng[0]).shuffle(indices)
+      random.Random(data_rng).shuffle(indices)
       ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples])
 
     sampler = None
diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py
index a301d677a..3ccad6b79 100644
--- a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py
@@ -10,6 +10,8 @@
 import tensorflow_datasets as tfds
 import torch
 
+from algorithmic_efficiency import random_utils
+
 AVG_NODES_PER_GRAPH = 26
 AVG_EDGES_PER_GRAPH = 56
 
@@ -24,9 +26,8 @@
 def _load_dataset(split, should_shuffle, data_rng, data_dir):
   """Loads a dataset split from TFDS."""
   if should_shuffle:
-    file_data_rng, dataset_data_rng = jax.random.split(data_rng)
-    file_data_rng = file_data_rng[0]
-    dataset_data_rng = dataset_data_rng[0]
+    file_data_rng = random_utils.bits(data_rng)
+    dataset_data_rng = random_utils.bits(file_data_rng)
   else:
     file_data_rng = None
     dataset_data_rng = None
diff --git a/algorithmic_efficiency/workloads/wmt/input_pipeline.py b/algorithmic_efficiency/workloads/wmt/input_pipeline.py
index af1c54994..97cd75d7e 100644
--- a/algorithmic_efficiency/workloads/wmt/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/wmt/input_pipeline.py
@@ -234,7 +234,7 @@ def filter_fn(x):
   dataset = dataset.filter(length_filter(max_length))
 
   if shuffle:
-    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng[0])
+    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng)
 
   if train:
     dataset = dataset.repeat()
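
Note on the `random_utils.bits` helper referenced in the ogbg hunk: its
implementation is not part of this patch, and the sketch below is an
assumption, not the library's actual code. The assumed behavior is folding
an RNG key down to a plain Python int (the kind of value TensorFlow's
`seed=` argument and `random.Random` accept), chainable so its output can
be fed back in as `_load_dataset` does above:

    import numpy as np

    def bits(rng) -> int:
      """Derives a deterministic 32-bit integer seed from `rng`.

      `rng` may be a PRNG key array (e.g. a legacy uint32 JAX key) or an
      int returned by a previous call, so calls can be chained as in
      _load_dataset. Hypothetical sketch; the real helper may differ.
      """
      return int(np.sum(np.asarray(rng, dtype=np.uint64)) % (2**32))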