From d50239f1c5470fd2b2495affc01773e063942e37 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Mon, 14 Oct 2024 14:20:56 -0500
Subject: [PATCH 1/2] added support for blending samples across different files in the same corpus

---
 megatron/arguments.py        |  4 +++-
 megatron/data/gpt_dataset.py | 36 ++++++++++++++++++++++++++++--------
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 9ab3e40953..9b0e6ccb1a 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -1291,7 +1291,9 @@ def _add_data_args(parser):
     group.add_argument('--data-file-list', type=str, default=None,
                        help='The file with the list of dataset and weights')
 
-    group.add_argument('--shuffle-sample', action='store_true', help="Whether to shuffle the samples within in the dataset files")
+    group.add_argument('--shuffle-sample-in-corpus', action='store_true', help="Whether to shuffle the samples within the dataset files")
+
+    group.add_argument('--blend-sample-in-corpus', action='store_true', help="Whether to blend samples across the different files in the same corpus")
 
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index c412d02b31..38df556267 100755
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -131,19 +131,35 @@ def Build(self):
         self.build = True
         return self.dataset
 
-class BuildConcatDataset(torch.utils.data.Dataset):
+class BuildCorpusDataset(torch.utils.data.Dataset):
     @dlp.log
-    def __init__(self, dataset_builders, shuffle=False):
+    def __init__(self, dataset_builders):
         self.dataset_builders = dataset_builders
         self.num_datasets = len(dataset_builders)
         self.num_samples = np.sum([d.num_samples for d in dataset_builders])
         self.indices = np.zeros((self.num_samples, 2), dtype=np.uint64)
-        self.desc = "ConcatDataset:"
+        self.desc = "CorpusDataset:"
         # m = 0
         num_samples_list = np.array([d.num_samples for d in dataset_builders])
         self.num_samples = np.sum(num_samples_list)
+        args = get_args()
 
-        def _build_indices():
+        @dlp.log
+        def _build_indices_blended():
+            start_time = time.time()
+            dataset_index = np.zeros(self.num_samples, dtype=np.int64)
+            dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64)
+            weights = num_samples_list / self.num_samples
+            helpers.build_blending_indices(
+                dataset_index, dataset_sample_index,
+                weights, self.num_datasets, self.num_samples,
+                torch.distributed.get_rank() == 0)
+            log.debug('> elapsed time for building blendable dataset indices for corpus {self.dataset_builders[0].corpus}: '
+                      '{:.2f} (sec)'.format(time.time() - start_time))
+            return dataset_index, dataset_sample_index
+
+
+        def _build_indices_concat():
             start_time = time.time()
             dataset_index = np.zeros(self.num_samples, dtype=np.int64)
             dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64)
@@ -159,11 +175,15 @@ def _build_indices():
                 "{:.2f} (sec)".format(time.time() - start_time)
             )
             return dataset_index, dataset_sample_index
-
-        self.dataset_index, self.dataset_sample_index = _build_indices()
+
+        if args.blend_sample_in_corpus:
+            self.dataset_index, self.dataset_sample_index = _build_indices_blended()
+        else:
+            self.dataset_index, self.dataset_sample_index = _build_indices_concat()
+
         np_rng = np.random.RandomState(seed=dataset_builders[0].seed)
         self.shuffle_index = np.arange(self.num_samples)
-        if shuffle:
+        if args.shuffle_sample_in_corpus:
             np_rng.shuffle(self.shuffle_index)
         for i in range(self.num_datasets):
             self.desc += dataset_builders[i].prefix + ","
@@ -243,7 +263,7 @@ def build_corpus_datasets(dataset_type="train"):
     log.debug(" > number of samples for each corpus ")
     corpus_weights_achieved = {}
     for c in corpus_list:
-        datasets.append(BuildConcatDataset(corpus_builders[c], args.shuffle_sample))
+        datasets.append(BuildCorpusDataset(corpus_builders[c]))
         total += datasets[-1].num_samples
         corpus_weights_achieved[c] = (
             float(datasets[-1].num_samples) / train_num_samples

From 8a8472c7bd83a3c9bc13f27b54c099ab1cde98b6 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Mon, 21 Oct 2024 19:35:57 +0000
Subject: [PATCH 2/2] fixed print out bug

---
 megatron/data/gpt_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index 38df556267..9ff2703277 100755
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -154,8 +154,8 @@ def _build_indices_blended():
                 dataset_index, dataset_sample_index,
                 weights, self.num_datasets, self.num_samples,
                 torch.distributed.get_rank() == 0)
-            log.debug('> elapsed time for building blendable dataset indices for corpus {self.dataset_builders[0].corpus}: '
-                      '{:.2f} (sec)'.format(time.time() - start_time))
+            log.debug(f"> elapsed time for building blendable dataset indices for corpus {self.dataset_builders[0].corpus}: "
+                      "{:.2f} (sec)".format(time.time() - start_time))
             return dataset_index, dataset_sample_index
 
 
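
For intuition about what the new --blend-sample-in-corpus path does, here is a small, self-contained NumPy sketch of proportional blending across the files of a single corpus. It is an illustrative stand-in for the compiled helpers.build_blending_indices routine the patch calls (whose exact greedy rule may differ); the function name build_blending_indices_py and its argument num_samples_per_file are invented for this example and are not part of the patch.

import numpy as np

def build_blending_indices_py(num_samples_per_file):
    # Toy stand-in (not the Megatron helper): interleave samples from the files
    # of one corpus in proportion to their sizes, instead of walking the files
    # back-to-back as the concatenated path does.
    sizes = np.asarray(num_samples_per_file, dtype=np.int64)
    total = int(sizes.sum())
    weights = sizes / total
    dataset_index = np.zeros(total, dtype=np.int64)         # which file each global sample comes from
    dataset_sample_index = np.zeros(total, dtype=np.int64)  # position of that sample within its file
    drawn = np.zeros(len(sizes), dtype=np.int64)            # samples already taken from each file
    for i in range(total):
        # Among files with samples left, pick the one whose achieved fraction
        # lags its target weight the most (greedy proportional blending).
        available = drawn < sizes
        lag = np.where(available, drawn / max(i, 1) - weights, np.inf)
        d = int(np.argmin(lag))
        dataset_index[i] = d
        dataset_sample_index[i] = drawn[d]
        drawn[d] += 1
    return dataset_index, dataset_sample_index

# A corpus with files of 3 and 1 samples interleaves roughly 3:1, e.g.
# dataset_index -> [0, 1, 0, 0] instead of the concatenated order [0, 0, 0, 1].
print(build_blending_indices_py([3, 1]))

With --shuffle-sample-in-corpus, the patch additionally permutes the per-corpus shuffle_index, so the blended (or concatenated) samples are visited in random order rather than file order.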