Merge pull request #63 from argonne-lcf/hzheng-data-fix
[merge]: into `microsoft-main` ← from `hzheng-data-fix`
saforem2 authored Nov 5, 2024
2 parents fd94b37 + 160d6a6 commit 40db8c2
Showing 13 changed files with 1,292 additions and 689 deletions.
579 changes: 375 additions & 204 deletions ALCF/helpers.sh

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ALCF/requirements/requirements.txt
@@ -15,3 +15,4 @@ six
numpy<2
schedulefree
packaging>=20.0
wandb
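
The new wandb entry carries no version pin; presumably it backs Weights & Biases experiment tracking elsewhere in the training scripts. A minimal usage sketch, assuming a hypothetical project and run name (neither is specified in this PR):

import wandb

# Names here are illustrative only; they do not come from this commit.
run = wandb.init(project="Megatron-DeepSpeed", name="example-run")
run.log({"train/loss": 1.23, "iteration": 1})
run.finish()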
40 changes: 39 additions & 1 deletion ALCF/test_blendable_dataset.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
import time
import json
start_time = time.time()
from mpi4py import MPI
import os
@@ -37,7 +38,7 @@ def print_rank_0(msg):

os.makedirs(args.trace_dir, exist_ok=True)


corpus_all = []
data_file_list = args.data_file_list
print_rank_0(f"Reading data from {args.data_file_list}")
files = []
@@ -51,6 +52,9 @@ def print_rank_0(msg):
files.append(float(w))
files.append(fname)
files.append(c)
if c not in corpus_all:
corpus_all.append(c)

splits_string="100,0,0"

weights = np.array(weights)
@@ -82,6 +86,40 @@ def print_rank_0(msg):
print_rank_0(f"Total number of samples: {len(train_ds)}")
print_rank_0(f"Weights set: {weights[:min(8, num_datasets)]}")


def get_sample_info(blendable_dataset, idx):
# corpus dataset
cd = blendable_dataset.dataset_index[idx]
# index within the corpus dataset
cds = blendable_dataset.dataset_sample_index[idx]
# dataset index within each corpus
fcd = blendable_dataset.datasets[cd].dataset_index[cds]
# sample index within the dataset
fcds = blendable_dataset.datasets[cd].dataset_sample_index[cds]
# corresponding data file
prefix = blendable_dataset.datasets[cd].dataset_builders[fcd].prefix
corpus = blendable_dataset.datasets[cd].dataset_builders[fcd].corpus
#v = blendable_dataset[idx]['text']
#norm = np.linalg.norm(v)
return prefix, corpus, fcds

num_batches = args.train_iters
print(f"global_batch_size: {args.global_batch_size}")
print(f"number of batches: {num_batches}")

fout = open("samples_list.jsonl", "w")
if comm.rank == 0:
for i in range(num_batches):
ns_corpus = {}
for c in corpus_all:
ns_corpus[c] = 0
for j in range(args.global_batch_size):
prefix, corpus, idx = get_sample_info(train_ds, i*args.global_batch_size+j)
ns_corpus[corpus] +=1
fout.write(f"\u007b 'batch': {i}, 'sample': {j}, 'corpus': '{corpus}', 'prefix': '{prefix}', 'dataset_sample_index': {idx} \u007d\n")
fout.write(f"\u007b 'batch': {i}, 'histogram': {ns_corpus} \u007d \n")
comm.Barrier()
exit()
start_build_dataloader = time.time()
print_rank_0(f"Starting to build the data loader")
rank_in_parallel_group = mpu.get_sequence_parallel_rank()
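
Note that the tracing loop above writes single-quoted pseudo-JSON records to samples_list.jsonl (via the \u007b / \u007d brace escapes), even though json is now imported, and fout is never closed before exit(). A minimal sketch of the same loop emitting valid JSON Lines instead; variable names mirror the script, and int() guards against numpy integers, which json.dumps rejects:

if comm.rank == 0:
    with open("samples_list.jsonl", "w") as fout:
        for i in range(num_batches):
            ns_corpus = {c: 0 for c in corpus_all}
            for j in range(args.global_batch_size):
                prefix, corpus, idx = get_sample_info(train_ds, i * args.global_batch_size + j)
                ns_corpus[corpus] += 1
                record = {"batch": i, "sample": j, "corpus": corpus,
                          "prefix": prefix, "dataset_sample_index": int(idx)}
                fout.write(json.dumps(record) + "\n")
            fout.write(json.dumps({"batch": i, "histogram": ns_corpus}) + "\n")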
4 changes: 4 additions & 0 deletions megatron/arguments.py
@@ -1291,6 +1291,10 @@ def _add_data_args(parser):
group.add_argument('--data-file-list', type=str, default=None,
help='The file with the list of dataset and weights')

group.add_argument('--shuffle-sample-in-corpus', action='store_true', help="Whether to shuffle the samples within the dataset files")

group.add_argument('--blend-sample-in-corpus', action='store_true', help="Whether to blend different files in the same corpus")

group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
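
Both new options are plain store_true toggles: they default to False and flip on when passed. A self-contained sketch of the parsing behavior (argparse converts the dashes to underscores on the namespace):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data')
group.add_argument('--shuffle-sample-in-corpus', action='store_true',
                   help='Whether to shuffle the samples within the dataset files')
group.add_argument('--blend-sample-in-corpus', action='store_true',
                   help='Whether to blend different files in the same corpus')

args = parser.parse_args(['--shuffle-sample-in-corpus'])
assert args.shuffle_sample_in_corpus is True   # flag present
assert args.blend_sample_in_corpus is False    # store_true defaults to False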
20 changes: 11 additions & 9 deletions megatron/core/pipeline_parallel/p2p_communication.py
@@ -16,7 +16,8 @@

from megatron.core import ModelParallelConfig
from deepspeed.accelerator import get_accelerator

from megatron.utils import Profile
dlp = Profile("PIPELINE")
# Types
Shape = Union[List[int], torch.Size]

@@ -329,6 +330,7 @@ def _ring_exchange_wrapper(**kwargs):
return tensor_recv_prev, tensor_recv_next, reqs


@dlp.log
def recv_forward(tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
@@ -353,7 +355,7 @@ def recv_forward(tensor_shape: Shape,
config.timers('forward-recv').stop()
return input_tensor


@dlp.log
def recv_backward(tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
@@ -376,7 +378,7 @@ def recv_backward(tensor_shape: Shape,
config.timers('backward-recv').stop()
return output_tensor_grad


@dlp.log
def send_forward(output_tensor: torch.Tensor,
config: ModelParallelConfig) -> None:
"""Send tensor to next rank in pipeline (forward send).
@@ -397,7 +399,7 @@ def send_forward(output_tensor: torch.Tensor,
if config.timers is not None:
config.timers('forward-send').stop()


@dlp.log
def send_backward(input_tensor_grad: torch.Tensor,
config: ModelParallelConfig) -> None:
"""Send tensor to previous rank in pipeline (backward send).
@@ -417,7 +419,7 @@ def send_backward(input_tensor_grad: torch.Tensor,
if config.timers is not None:
config.timers('backward-send').stop()


@dlp.log
def send_forward_recv_backward(output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
@@ -441,7 +443,7 @@ def send_forward_recv_backward(output_tensor: torch.Tensor,
config.timers('forward-send-backward-recv').stop()
return output_tensor_grad


@dlp.log
def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
@@ -465,7 +467,7 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
config.timers('backward-send-forward-recv').stop()
return input_tensor


@dlp.log
def send_forward_recv_forward(output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
@@ -491,7 +493,7 @@ def send_forward_recv_forward(output_tensor: torch.Tensor,
return input_tensor, wait_handles
return input_tensor


@dlp.log
def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
@@ -517,7 +519,7 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
return output_tensor_grad, wait_handles
return output_tensor_grad


@dlp.log
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
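
Every point-to-point routine in this file now carries the @dlp.log decorator obtained from megatron.utils.Profile. That implementation is not part of this diff; as a rough sketch of the shape such a decorator takes (timing plus a category tag, explicitly not the real megatron.utils.Profile, which wraps a tracing backend):

import functools
import time

class Profile:
    # Hypothetical stand-in for megatron.utils.Profile, which this PR
    # imports but does not show.
    def __init__(self, category):
        self.category = category

    def log(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - t0
                print(f"[{self.category}] {func.__name__}: {elapsed:.6f}s")
        return wrapper

dlp = Profile("PIPELINE")

@dlp.log
def recv_forward_stub():
    time.sleep(0.01)  # stand-in for a real pipeline receive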
38 changes: 23 additions & 15 deletions megatron/data/blendable_dataset.py
@@ -6,14 +6,20 @@
import os
import time

import logging
import numpy as np
import torch

from deepspeed.accelerator import get_accelerator
from megatron import print_rank_0
# from megatron import print_rank_0
from megatron.core import mpu
from megatron.utils import Profile, PerfTrace
from mpi4py import MPI

from megatron.utils import get_logger

log = get_logger(__name__, rank_zero_only=True)

dlp = Profile("DATASET")
class BlendableDataset(torch.utils.data.Dataset):
@dlp.log
@@ -35,16 +41,18 @@ def __init__(self, datasets, weights, size, *,
# Build indices.
@dlp.log
def _build_indices():
start_time = time.time()
start_time = time.perf_counter()
dataset_index = np.zeros(self.size, dtype=np.int64)
dataset_sample_index = np.zeros(self.size, dtype=np.int64)

from megatron.data import helpers
helpers.build_blending_indices(dataset_index, dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print_rank_0('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
log.info(
"> elapsed time for building blendable dataset indices: "
f"{time.perf_counter() - start_time:.2f} (sec)"
)
return dataset_index, dataset_sample_index

desc = "Blendable dataset\n\n"
@@ -68,15 +76,15 @@ def _build_indices():
' dataset, building indices on rank 0 ...', flush=True)
dataset_index, dataset_sample_index = _build_indices()
try:
print_rank_0(" > saving index map files")
start_time = time.time()
log.debug(" > saving index map files")
start_time = time.perf_counter()
os.makedirs(os.path.dirname(index_path), exist_ok=True)
with open(desc_path, 'wt') as fd:
fd.write(desc)
np.save(index_path, dataset_index, allow_pickle=True)
np.save(sample_index_path, dataset_sample_index,
allow_pickle=True)
print_rank_0(f" > finished saving index map files in {time.time() - start_time} seconds")
log.info(f" > finished saving index map files in {time.perf_counter() - start_time} seconds")
except OSError:
print(f'There was an error trying to create the data cache directory ({data_cache_path})')
print('or a file in it. This is set with the --data-cache-path argument. Please')
@@ -93,21 +101,21 @@ def _build_indices():
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) //
torch.distributed.get_world_size(group=mpu.get_sequence_parallel_group())):
print_rank_0("Data index creation unsuccessful, exiting.")
log.info("Data index creation unsuccessful, exiting.")
exit()
'''
torch.distributed.barrier(group=mpu.get_data_parallel_group())
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
torch.distributed.barrier(group=mpu.get_data_parallel_group())
start_time = time.time()
print_rank_0(f'> loading blendable dataset index: {index_path}')

start_time = time.perf_counter()
log.info(f'> loading blendable dataset index: {index_path}')
self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_index.size == self.size
print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}')
log.info(f'> loading blendable dataset sample index: {sample_index_path}')
self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_sample_index.size == self.size
print_rank_0(f'> finished loading in {time.time() - start_time} seconds')
log.info(f'> finished loading in {time.perf_counter() - start_time} seconds')
else:
self.dataset_index, self.dataset_sample_index = _build_indices()

@@ -119,7 +127,7 @@ def _build_indices():
raise RuntimeError('BlendedDataset size is improperly bounded')
except IndexError:
pass
print_rank_0('> size of blendable dataset: '
log.info('> size of blendable dataset: '
'{} samples'.format(self.size))


Expand All @@ -133,4 +141,4 @@ def __getitem__(self, idx):
return {
"dataset_idx" : dataset_idx,
**self.datasets[dataset_idx][sample_idx],
}
}
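
For context, __getitem__ above resolves a global sample index through two arrays: dataset_index[idx] selects the component dataset and dataset_sample_index[idx] selects the sample within it. Those arrays come from helpers.build_blending_indices, a compiled helper not shown here; below is a pure-Python sketch of the greedy proportional interleave it is understood to perform (details of the compiled implementation may differ):

import numpy as np

def build_blending_indices_sketch(weights, size):
    # At each step, pick the dataset whose achieved sample count lags its
    # weighted target the most.
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    dataset_index = np.zeros(size, dtype=np.int64)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    counts = np.zeros(len(weights), dtype=np.int64)
    for i in range(size):
        errors = weights * (i + 1) - counts
        d = int(np.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = counts[d]
        counts[d] += 1
    return dataset_index, dataset_sample_index

# e.g. weights [0.5, 0.3, 0.2] over 10 samples yields roughly a 5/3/2 interleave
di, dsi = build_blending_indices_sketch([0.5, 0.3, 0.2], 10)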
(The diffs for the remaining seven changed files are not rendered here.)
