Skip to content

Commit

Permalink
Consistent logging in megatron/data/*.py
Browse files: browse the repository at this point in the history
  • Loading branch information
saforem2 committed Oct 15, 2024
1 parent 45ff652 commit 52a406c
Show file tree
Hide file tree
Showing 3 changed files with 590 additions and 347 deletions.
14 changes: 8 additions & 6 deletions megatron/data/blendable_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,8 +49,10 @@ def _build_indices():
helpers.build_blending_indices(dataset_index, dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
log.info('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
log.info(
"> elapsed time for building blendable dataset indices: "
f"{time.perf_counter() - start_time:.2f} (sec)"
)
return dataset_index, dataset_sample_index

desc = "Blendable dataset\n\n"
Expand All @@ -74,7 +76,7 @@ def _build_indices():
' dataset, building indices on rank 0 ...', flush=True)
dataset_index, dataset_sample_index = _build_indices()
try:
log.info(" > saving index map files")
log.debug(" > saving index map files")
start_time = time.time()
os.makedirs(os.path.dirname(index_path), exist_ok=True)
with open(desc_path, 'wt') as fd:
Expand Down Expand Up @@ -105,15 +107,15 @@ def _build_indices():
torch.distributed.barrier(group=mpu.get_data_parallel_group())
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
torch.distributed.barrier(group=mpu.get_data_parallel_group())

start_time = time.time()
log.info(f'> loading blendable dataset index: {index_path}')
self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_index.size == self.size
log.info(f'> loading blendable dataset sample index: {sample_index_path}')
self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_sample_index.size == self.size
log.info(f'> finished loading in {time.time() - start_time} seconds')
log.info(f'> finished loading in {time.time() - start_time} seconds')
else:
self.dataset_index, self.dataset_sample_index = _build_indices()

Expand All @@ -139,4 +141,4 @@ def __getitem__(self, idx):
return {
"dataset_idx" : dataset_idx,
**self.datasets[dataset_idx][sample_idx],
}
}
Loading

0 comments on commit 52a406c

Please sign in to comment.