Skip to content

Commit

Permalink
Merge pull request #67 from argonne-lcf/feature/profile
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 authored Oct 14, 2024
2 parents 5e9eed0 + 94d5337 commit bb55e97
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,8 @@ def _add_data_args(parser):
group.add_argument('--data-file-list', type=str, default=None,
help='The file with the list of dataset and weights')

group.add_argument('--shuffle-sample', action='store_true', help="Whether to shuffle the samples within the dataset files")

group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
Expand Down
9 changes: 5 additions & 4 deletions megatron/data/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def Build(self):

class BuildConcatDataset(torch.utils.data.Dataset):
@dlp.log
def __init__(self, dataset_builders):
def __init__(self, dataset_builders, shuffle=False):
self.dataset_builders = dataset_builders
self.num_datasets = len(dataset_builders)
self.num_samples = np.sum([d.num_samples for d in dataset_builders])
Expand Down Expand Up @@ -163,7 +163,8 @@ def _build_indices():
self.dataset_index, self.dataset_sample_index = _build_indices()
np_rng = np.random.RandomState(seed=dataset_builders[0].seed)
self.shuffle_index = np.arange(self.num_samples)
np_rng.shuffle(self.shuffle_index)
if shuffle:
np_rng.shuffle(self.shuffle_index)
for i in range(self.num_datasets):
self.desc += dataset_builders[i].prefix + ","

Expand Down Expand Up @@ -196,7 +197,7 @@ def __getitem__(self, idx):
valid_datasets = []
test_datasets = []
# Build individual datasets.

args = get_args()
@dlp.log
def build_corpus_datasets(dataset_type="train"):
start_time = time.time()
Expand Down Expand Up @@ -242,7 +243,7 @@ def build_corpus_datasets(dataset_type="train"):
log.debug(" > number of samples for each corpus ")
corpus_weights_achieved = {}
for c in corpus_list:
datasets.append(BuildConcatDataset(corpus_builders[c]))
datasets.append(BuildConcatDataset(corpus_builders[c], args.shuffle_sample))
total += datasets[-1].num_samples
corpus_weights_achieved[c] = (
float(datasets[-1].num_samples) / train_num_samples
Expand Down

0 comments on commit bb55e97

Please sign in to comment.