Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pull in from upstream #10

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ea0c3c7
fixed dftracer compatibility
zhenghh04 Aug 30, 2024
a0ac750
added requirements.txt
zhenghh04 Sep 3, 2024
de7f22f
Update utils.py
zhenghh04 Sep 4, 2024
12f6f8e
fix check
zhenghh04 Sep 12, 2024
5394156
shuffle concate dataset index
zhenghh04 Oct 12, 2024
573b668
fixed bugs
zhenghh04 Oct 12, 2024
89db92a
merge: `feature/profile` with data fix into `microsoft-main`
saforem2 Oct 12, 2024
9de83a9
Fix `shuffle_idx` in `megatron/data/gpt_dataset.py`
saforem2 Oct 12, 2024
d7a2594
Fix `shuffle_idx` in `megatron/data/gpt_dataset.py`
saforem2 Oct 12, 2024
3e33a6a
Update `ALCF/helpers.sh`, `train_aGPT_7B.sh`
saforem2 Oct 13, 2024
43cde2b
Update `pretrain_gpt_alcf.py`
saforem2 Oct 13, 2024
9f09733
Update `megatron/data/{blendable,gpt,indexed}_dataset.py`
saforem2 Oct 13, 2024
2b31b44
Update `ALCF/requirements/requirements.txt`
saforem2 Oct 13, 2024
5e9eed0
Update `megatron/utils.py`
saforem2 Oct 13, 2024
3dcb297
fixed bugs and added commandline option
zhenghh04 Oct 14, 2024
bec9b7a
Merge branch 'debug-logging' into feature/profile
saforem2 Oct 14, 2024
43fc2fe
fixed typo
zhenghh04 Oct 14, 2024
94d5337
Merge branch 'feature/profile' of github.com:argonne-lcf/Megatron-Dee…
zhenghh04 Oct 14, 2024
bb55e97
Merge pull request #67 from argonne-lcf/feature/profile
saforem2 Oct 14, 2024
d50239f
added support for blending samples across different files in the same…
zhenghh04 Oct 14, 2024
9b4f510
Merge pull request #64 from argonne-lcf/debug-logging
saforem2 Oct 14, 2024
324ef11
Merge branch 'alcf-hzheng-data-fix' into hzheng-data-fix
saforem2 Oct 15, 2024
45ff652
Discard changes to megatron/data/gpt_dataset.py
saforem2 Oct 15, 2024
52a406c
Consistent logging in `megatron/data/*.py`
saforem2 Oct 15, 2024
63b1901
Update `megatron/data/gpt_dataset.py`
saforem2 Oct 16, 2024
7ef26bf
Use `time.perf_counter` in `megatron/data/blendable_dataset.py`
saforem2 Oct 16, 2024
deb95cd
fix init issue for silently ignoring the deepspeed config (#452)
xylian86 Oct 17, 2024
68da2db
Update `ALCF/helpers.sh`
saforem2 Oct 17, 2024
ab3a8ec
Merge branch 'main' of https://github.com/microsoft/Megatron-DeepSpee…
saforem2 Oct 18, 2024
ed21bd9
Merge branch 'hzheng-data-fix' of https://github.com/argonne-lcf/Mega…
saforem2 Oct 18, 2024
6acc370
fix moe tflops (#445)
ranzhejiang Oct 18, 2024
467279b
Merge 'upstream/main' into `hzeng-data-fix`
saforem2 Oct 18, 2024
9e015cc
Remove duplicate `gradient_accumulation_steps` in DS config
saforem2 Oct 18, 2024
58dc2d7
Update default EVAL args
saforem2 Oct 21, 2024
277d308
Catch eval metrics in `megatron/training.py`
saforem2 Oct 21, 2024
af4cba1
Save git branch to env in `train_aGPT_7B.sh`
saforem2 Oct 21, 2024
8a8472c
fixed print out bug
zhenghh04 Oct 21, 2024
dfd0643
Merge pull request #68 from argonne-lcf/feature/blending_corpus
saforem2 Oct 21, 2024
6cb727d
Fix `args.shuffle` in `megatron/data/gpt_dataset.py`
saforem2 Oct 21, 2024
5d10179
Update `--{shuffle,blend}-sample-in-corpus` arg in `ALCF/helpers.sh`
saforem2 Oct 24, 2024
160d6a6
fix: `GRAD_ACC_STEPS` when `NHOSTS == 256`
saforem2 Oct 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
391 changes: 227 additions & 164 deletions ALCF/helpers.sh

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ALCF/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ six
numpy<2
schedulefree
packaging>=20.0
wandb
40 changes: 39 additions & 1 deletion ALCF/test_blendable_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python
import time
import json
start_time = time.time()
from mpi4py import MPI
import os
Expand Down Expand Up @@ -37,7 +38,7 @@ def print_rank_0(msg):

os.makedirs(args.trace_dir, exist_ok=True)


corpus_all = []
data_file_list = args.data_file_list
print_rank_0(f"Reading data from {args.data_file_list}")
files = []
Expand All @@ -51,6 +52,9 @@ def print_rank_0(msg):
files.append(float(w))
files.append(fname)
files.append(c)
if c not in corpus_all:
corpus_all.append(c)

splits_string="100,0,0"

weights = np.array(weights)
Expand Down Expand Up @@ -82,6 +86,40 @@ def print_rank_0(msg):
print_rank_0(f"Total number of samples: {len(train_ds)}")
print_rank_0(f"Weights set: {weights[:min(8, num_datasets)]}")


def get_sample_info(blendable_dataset, idx):
# corpus dataset
cd = blendable_dataset.dataset_index[idx]
# index within the corpus dataset
cds = blendable_dataset.dataset_sample_index[idx]
# dataset index within each corpus
fcd = blendable_dataset.datasets[cd].dataset_index[cds]
# sample index within the dataset
fcds = blendable_dataset.datasets[cd].dataset_sample_index[cds]
# corresponding data file
prefix = blendable_dataset.datasets[cd].dataset_builders[fcd].prefix
corpus = blendable_dataset.datasets[cd].dataset_builders[fcd].corpus
#v = blendable_dataset[idx]['text']
#norm = np.linalg.norm(v)
return prefix, corpus, fcds

num_batches = args.train_iters
print(f"global_batch_size: {args.global_batch_size}")
print(f"number of batches: {num_batches}")

fout = open("samples_list.jsonl", "w")
if comm.rank == 0:
for i in range(num_batches):
ns_corpus = {}
for c in corpus_all:
ns_corpus[c] = 0
for j in range(args.global_batch_size):
prefix, corpus, idx = get_sample_info(train_ds, i*args.global_batch_size+j)
ns_corpus[corpus] +=1
fout.write(f"\u007b 'batch': {i}, 'sample': {j}, 'corpus': '{corpus}', 'prefix': '{prefix}', 'dataset_sample_index': {idx} \u007d\n")
fout.write(f"\u007b 'batch': {i}, 'histogram': {ns_corpus} \u007d \n")
comm.Barrier()
exit()
start_build_dataloader = time.time()
print_rank_0(f"Starting to build the data loader")
rank_in_parallel_group = mpu.get_sequence_parallel_rank()
Expand Down
4 changes: 4 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,10 @@ def _add_data_args(parser):
group.add_argument('--data-file-list', type=str, default=None,
help='The file with the list of dataset and weights')

group.add_argument('--shuffle-sample-in-corpus', action='store_true', help="Whether to shuffle the samples within in the dataset files")

group.add_argument('--blend-sample-in-corpus', action='store_true', help="Whether to blend different files in the same corpus")

group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
Expand Down
20 changes: 11 additions & 9 deletions megatron/core/pipeline_parallel/p2p_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from megatron.core import ModelParallelConfig
from deepspeed.accelerator import get_accelerator

from megatron.utils import Profile
dlp = Profile("PIPELINE")
# Types
Shape = Union[List[int], torch.Size]

Expand Down Expand Up @@ -329,6 +330,7 @@ def _ring_exchange_wrapper(**kwargs):
return tensor_recv_prev, tensor_recv_next, reqs


@dlp.log
def recv_forward(tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
Expand All @@ -353,7 +355,7 @@ def recv_forward(tensor_shape: Shape,
config.timers('forward-recv').stop()
return input_tensor


@dlp.log
def recv_backward(tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
Expand All @@ -376,7 +378,7 @@ def recv_backward(tensor_shape: Shape,
config.timers('backward-recv').stop()
return output_tensor_grad


@dlp.log
def send_forward(output_tensor: torch.Tensor,
config: ModelParallelConfig) -> None:
"""Send tensor to next rank in pipeline (forward send).
Expand All @@ -397,7 +399,7 @@ def send_forward(output_tensor: torch.Tensor,
if config.timers is not None:
config.timers('forward-send').stop()


@dlp.log
def send_backward(input_tensor_grad: torch.Tensor,
config: ModelParallelConfig) -> None:
"""Send tensor to previous rank in pipeline (backward send).
Expand All @@ -417,7 +419,7 @@ def send_backward(input_tensor_grad: torch.Tensor,
if config.timers is not None:
config.timers('backward-send').stop()


@dlp.log
def send_forward_recv_backward(output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
Expand All @@ -441,7 +443,7 @@ def send_forward_recv_backward(output_tensor: torch.Tensor,
config.timers('forward-send-backward-recv').stop()
return output_tensor_grad


@dlp.log
def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig) -> torch.Tensor:
Expand All @@ -465,7 +467,7 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
config.timers('backward-send-forward-recv').stop()
return input_tensor


@dlp.log
def send_forward_recv_forward(output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
Expand All @@ -491,7 +493,7 @@ def send_forward_recv_forward(output_tensor: torch.Tensor,
return input_tensor, wait_handles
return input_tensor


@dlp.log
def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
Expand All @@ -517,7 +519,7 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
return output_tensor_grad, wait_handles
return output_tensor_grad


@dlp.log
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
Expand Down
37 changes: 15 additions & 22 deletions megatron/data/blendable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,9 @@
from megatron.utils import Profile, PerfTrace
from mpi4py import MPI

try:
import ezpz as ez
RANK = ez.get_rank()
except Exception:
RANK = torch.distributed.get_rank()

# NOTE: [logging]-----------------------------------------------------------
# - Set logging level to "INFO" on RANK == 0, "CRITICAL" on all other ranks
log = logging.getLogger(__name__)
LOG_LEVEL = str(os.environ.get("LOG_LEVEL", "INFO")).upper()
log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL")
# --------------------------------------------------------------------------
from megatron.utils import get_logger

log = get_logger(__name__, rank_zero_only=True)

dlp = Profile("DATASET")
class BlendableDataset(torch.utils.data.Dataset):
Expand All @@ -50,16 +41,18 @@ def __init__(self, datasets, weights, size, *,
# Build indicies.
@dlp.log
def _build_indices():
start_time = time.time()
start_time = time.perf_counter()
dataset_index = np.zeros(self.size, dtype=np.int64)
dataset_sample_index = np.zeros(self.size, dtype=np.int64)

from megatron.data import helpers
helpers.build_blending_indices(dataset_index, dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
log.info('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
log.info(
"> elapsed time for building blendable dataset indices: "
f"{time.perf_counter() - start_time:.2f} (sec)"
)
return dataset_index, dataset_sample_index

desc = "Blendable dataset\n\n"
Expand All @@ -83,15 +76,15 @@ def _build_indices():
' dataset, building indices on rank 0 ...', flush=True)
dataset_index, dataset_sample_index = _build_indices()
try:
log.info(" > saving index map files")
start_time = time.time()
log.debug(" > saving index map files")
start_time = time.perf_counter()
os.makedirs(os.path.dirname(index_path), exist_ok=True)
with open(desc_path, 'wt') as fd:
fd.write(desc)
np.save(index_path, dataset_index, allow_pickle=True)
np.save(sample_index_path, dataset_sample_index,
allow_pickle=True)
log.info(f" > finished saving index map files in {time.time() - start_time} seconds")
log.info(f" > finished saving index map files in {time.perf_counter() - start_time} seconds")
except OSError:
print(f'There was an error trying to create the data cache directory ({data_cache_path})')
print('or a file in it. This is set with the --data-cache-path argument. Please')
Expand All @@ -114,15 +107,15 @@ def _build_indices():
torch.distributed.barrier(group=mpu.get_data_parallel_group())
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
torch.distributed.barrier(group=mpu.get_data_parallel_group())
start_time = time.time()

start_time = time.perf_counter()
log.info(f'> loading blendable dataset index: {index_path}')
self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_index.size == self.size
log.info(f'> loading blendable dataset sample index: {sample_index_path}')
self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_sample_index.size == self.size
log.info(f'> finished loading in {time.time() - start_time} seconds')
log.info(f'> finished loading in {time.perf_counter() - start_time} seconds')
else:
self.dataset_index, self.dataset_sample_index = _build_indices()

Expand All @@ -148,4 +141,4 @@ def __getitem__(self, idx):
return {
"dataset_idx" : dataset_idx,
**self.datasets[dataset_idx][sample_idx],
}
}
Loading