Skip to content

Commit

Permalink
Merge pull request #68 from argonne-lcf/feature/blending_corpus
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 authored Oct 21, 2024
2 parents af4cba1 + 8a8472c commit dfd0643
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
4 changes: 3 additions & 1 deletion megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,9 @@ def _add_data_args(parser):
group.add_argument('--data-file-list', type=str, default=None,
help='The file with the list of dataset and weights')

group.add_argument('--shuffle-sample', action='store_true', help="Whether to shuffle the samples within in the dataset files")
group.add_argument('--shuffle-sample-in-corpus', action='store_true', help="Whether to shuffle the samples within in the dataset files")

group.add_argument('--blend-sample-in-corpus', action='store_true', help="Whether to blend different files in the same corpus")

group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
Expand Down
36 changes: 28 additions & 8 deletions megatron/data/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,35 @@ def Build(self):
self.build = True
return self.dataset

class BuildConcatDataset(torch.utils.data.Dataset):
class BuildCorpusDataset(torch.utils.data.Dataset):
@dlp.log
def __init__(self, dataset_builders, shuffle=False):
def __init__(self, dataset_builders):
self.dataset_builders = dataset_builders
self.num_datasets = len(dataset_builders)
self.num_samples = np.sum([d.num_samples for d in dataset_builders])
self.indices = np.zeros((self.num_samples, 2), dtype=np.uint64)
self.desc = "ConcatDataset:"
self.desc = "CorpusDataset:"
# m = 0
num_samples_list = np.array([d.num_samples for d in dataset_builders])
self.num_samples = np.sum(num_samples_list)
args = get_args()

def _build_indices():
@dlp.log
def _build_indices_blended():
start_time = time.time()
dataset_index = np.zeros(self.num_samples, dtype=np.int64)
dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64)
weights = num_samples_list / self.num_samples
helpers.build_blending_indices(
dataset_index, dataset_sample_index,
weights, self.num_datasets, self.num_samples,
torch.distributed.get_rank() == 0)
log.debug(f"> elapsed time for building blendable dataset indices for corpus {self.dataset_builders[0].corpus}: "
"{:.2f} (sec)".format(time.time() - start_time))
return dataset_index, dataset_sample_index


def _build_indices_concat():
start_time = time.time()
dataset_index = np.zeros(self.num_samples, dtype=np.int64)
dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64)
Expand All @@ -159,11 +175,15 @@ def _build_indices():
"{:.2f} (sec)".format(time.time() - start_time)
)
return dataset_index, dataset_sample_index

self.dataset_index, self.dataset_sample_index = _build_indices()

if args.blend_sample_in_corpus:
self.dataset_index, self.dataset_sample_index = _build_indices_blended()
else:
self.dataset_index, self.dataset_sample_index = _build_indices_concat()

np_rng = np.random.RandomState(seed=dataset_builders[0].seed)
self.shuffle_index = np.arange(self.num_samples)
if shuffle:
if args.shuffle_sample_in_corpus:
np_rng.shuffle(self.shuffle_index)
for i in range(self.num_datasets):
self.desc += dataset_builders[i].prefix + ","
Expand Down Expand Up @@ -247,7 +267,7 @@ def build_corpus_datasets(dataset_type="train"):
log.debug(" > number of samples for each corpus ")
corpus_weights_achieved = {}
for c in corpus_list:
datasets.append(BuildConcatDataset(corpus_builders[c], args.shuffle_sample))
datasets.append(BuildCorpusDataset(corpus_builders[c]))
total += datasets[-1].num_samples
corpus_weights_achieved[c] = (
float(datasets[-1].num_samples) / train_num_samples
Expand Down

0 comments on commit dfd0643

Please sign in to comment.