
Commit 6fd50e7: rng fixes
priyakasimbeg committed Feb 13, 2024
1 parent 1fc4ce2
Showing 4 changed files with 7 additions and 6 deletions.
@@ -111,7 +111,7 @@ def get_criteo1tb_dataset(split: str,
   is_training = split == 'train'
   shuffle = is_training or split == 'eval_train'
   ds = tf.data.Dataset.list_files(
-      file_paths, shuffle=shuffle, seed=shuffle_rng[0])
+      file_paths, shuffle=shuffle, seed=shuffle_rng)

   if shuffle:
     ds = ds.shuffle(buffer_size=1024)
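The common pattern across these four fixes: the pipelines previously indexed a JAX PRNGKey array with [0] to extract a scalar seed, and after this commit they pass the rng through unchanged, which is only valid if callers now supply a plain int. A minimal sketch of the fixed criteo call under that assumption (the wrapper name list_shards is hypothetical):

import tensorflow as tf

def list_shards(file_paths, shuffle, shuffle_rng):
  # tf.data expects a scalar int for `seed`; a 2-element PRNGKey
  # array is not a valid scalar seed, so dropping the [0] only
  # makes sense once shuffle_rng is already an int.
  return tf.data.Dataset.list_files(
      file_paths, shuffle=shuffle, seed=shuffle_rng)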
@@ -152,7 +152,7 @@ def _build_input_queue(
     ds = LibriSpeechDataset(split=ds_split, data_dir=data_dir)
     if split == 'eval_train':
       indices = list(range(len(ds)))
-      random.Random(data_rng[0]).shuffle(indices)
+      random.Random(data_rng).shuffle(indices)
       ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples])

     sampler = None
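For context, the eval_train path shuffles indices with a seeded random.Random so the chosen subset of examples is reproducible. A self-contained sketch of that pattern, assuming data_rng is a plain int seed (the function name eval_train_subset is hypothetical):

import random

def eval_train_subset(num_examples, num_eval, seed):
  # The same int seed always produces the same permutation, so the
  # eval_train subset is stable across runs and processes.
  indices = list(range(num_examples))
  random.Random(seed).shuffle(indices)
  return indices[:num_eval]

# e.g. subset = eval_train_subset(len(ds), num_eval_train_examples, data_rng)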
7 changes: 4 additions & 3 deletions algorithmic_efficiency/workloads/ogbg/input_pipeline.py
@@ -10,6 +10,8 @@
 import tensorflow_datasets as tfds
 import torch

+from algorithmic_efficiency import random_utils
+
 AVG_NODES_PER_GRAPH = 26
 AVG_EDGES_PER_GRAPH = 56

@@ -24,9 +26,8 @@
 def _load_dataset(split, should_shuffle, data_rng, data_dir):
   """Loads a dataset split from TFDS."""
   if should_shuffle:
-    file_data_rng, dataset_data_rng = jax.random.split(data_rng)
-    file_data_rng = file_data_rng[0]
-    dataset_data_rng = dataset_data_rng[0]
+    file_data_rng = random_utils.bits(data_rng)
+    dataset_data_rng = random_utils.bits(file_data_rng)
   else:
     file_data_rng = None
     dataset_data_rng = None
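random_utils.bits is the repo helper these call sites now lean on; its implementation is not shown in this diff. A rough sketch of the contract the code above appears to assume, namely that bits accepts either an rng key or a previously returned int and yields a fresh Python int (the body below is a guess, not the repo's actual code):

import numpy as np

def bits(rng):
  # Hypothetical stand-in for random_utils.bits: fold an rng (an int
  # seed or an array-like key) into one fresh Python int that
  # tf.data, random.Random, and similar APIs can consume as a seed.
  if not isinstance(rng, int):
    rng = int(np.asarray(rng).ravel()[0])  # first word of the key
  return int(np.random.RandomState(rng % 2**32).randint(0, 2**31))

Chaining works because the output is itself a valid input: bits(bits(data_rng)) derives a second seed, which matches the file_data_rng / dataset_data_rng usage above.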
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/input_pipeline.py
@@ -234,7 +234,7 @@ def filter_fn(x):
   dataset = dataset.filter(length_filter(max_length))

   if shuffle:
-    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng[0])
+    dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng)

   if train:
     dataset = dataset.repeat()
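As a quick sanity check of the property all four fixes rely on: tf.data shuffling is deterministic for a fixed int seed within a given TensorFlow build, so two identically built pipelines visit elements in the same order (illustrative snippet, names hypothetical):

import tensorflow as tf

# Same buffer size and same int seed: identical first-epoch order.
ds_a = tf.data.Dataset.range(10).shuffle(10, seed=42)
ds_b = tf.data.Dataset.range(10).shuffle(10, seed=42)
assert list(ds_a.as_numpy_iterator()) == list(ds_b.as_numpy_iterator())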
