Skip to content

Commit

Permalink
fixed bugs and added commandline option
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghh04 committed Oct 14, 2024
1 parent 573b668 commit 3dcb297
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 2 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,8 @@ def _add_data_args(parser):
group.add_argument('--data-file-list', type=str, default=None,
help='The file with the list of dataset and weights')

group.add_argument('--shuffle-sample', action='stored_true', help="Whether to shuffle the samples within in the dataset files")

group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
Expand Down
10 changes: 6 additions & 4 deletions megatron/data/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def Build(self):

class BuildConcatDataset(torch.utils.data.Dataset):
@dlp.log
def __init__(self, dataset_builders):
def __init__(self, dataset_builders, shuffle=False):
self.dataset_builders = dataset_builders
self.num_datasets = len(dataset_builders)
self.num_samples = np.sum([d.num_samples for d in dataset_builders])
Expand All @@ -117,7 +117,9 @@ def _build_indices():

self.dataset_index, self.dataset_sample_index = _build_indices()
np_rng = np.random.RandomState(seed=dataset_builders[0].seed)
self.shuffle_index=np_rng.shuffle(range(self.num_samples))
self.shuffle_index = np.arange(self.num_samples)
if shuffle:
np_rng.shuffle(self.shuffle_index)
for i in range(self.num_datasets):
self.desc += dataset_builders[i].prefix + ","

Expand Down Expand Up @@ -146,7 +148,7 @@ def __getitem__(self, idx):
valid_datasets = []
test_datasets = []
# Build individual datasets.

args = get_args()
@dlp.log
def build_corpus_datasets(dataset_type='train'):
start_time = time.time()
Expand All @@ -172,7 +174,7 @@ def build_corpus_datasets(dataset_type='train'):
print_rank_0(" > number of samples for each corpus ")
corpus_weights_achieved={}
for c in corpus_list:
datasets.append(BuildConcatDataset(corpus_builders[c]))
datasets.append(BuildConcatDataset(corpus_builders[c], args.shuffle_sample))
total += datasets[-1].num_samples
corpus_weights_achieved[c] = float(datasets[-1].num_samples)/train_num_samples
print_rank_0(f" {c}: {datasets[-1].num_samples} w={corpus_weights_achieved[c]} (expected: {corpus_weights[c]})")
Expand Down

0 comments on commit 3dcb297

Please sign in to comment.