
Update pretrain_gpt_alcf.py
Remove `--train-range-to-skip` logic from `pretrain_gpt_alcf.py` and remove redundant code.
saforem2 committed Sep 17, 2024
1 parent 828f6a9 commit 295fcb3
Showing 1 changed file with 9 additions and 35 deletions.
@@ -2,13 +2,12 @@
 
 """Pretrain GPT"""
 import time
-from typing import Callable, Type
+from typing import Callable
 from mpi4py import MPI
 
 comm = MPI.COMM_WORLD
 comm.Barrier()
 python_start_time = time.time()
-from pathlib import Path
 
 import os
 from rich import print
@@ -189,6 +188,14 @@ def get_batch(data_iterator):
     keys = ["text"]
     datatype = torch.int64
     data = next(data_iterator) if data_iterator is not None else None
+
+    if (
+        args.iteration < 10
+        and RANK == 0
+        and os.environ.get("DUMP_TOKENS", None)
+        and data is not None
+    ):
+        log.info(f"{args.iteration=}: {data['text'][:10]=}")
     # # Broadcast data.
     # if data_iterator is not None:
     #     data = next(data_iterator)
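
The added block above is an opt-in debugging aid: during the first ten iterations, rank 0 logs the first ten entries of the incoming batch's `text` field, but only when the `DUMP_TOKENS` environment variable is set. A minimal standalone sketch of the same pattern (the helper name and logger setup are illustrative; `DUMP_TOKENS` and the data layout come from the diff):

    import logging
    import os

    log = logging.getLogger(__name__)

    def maybe_dump_tokens(iteration: int, rank: int, data: dict | None) -> None:
        # Opt-in debug logging: do nothing unless DUMP_TOKENS is set,
        # and only log from rank 0 during the first few iterations.
        if (
            iteration < 10
            and rank == 0
            and os.environ.get("DUMP_TOKENS", None)
            and data is not None
        ):
            log.info(f"{iteration=}: {data['text'][:10]=}")

A run that wants the dump would set the variable in the environment, e.g. `DUMP_TOKENS=1`; when it is unset, the check short-circuits cheaply and the hot path stays clean.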
@@ -388,13 +395,6 @@ def calculate_mos_loss(
     return mos_loss
 
 
-# ForwardStepOutput = Type[tuple[torch.Tensor | None, Callable[[torch.Tensor], torch.Tensor | None]]]
-
-
-def _return_none(_: torch.Tensor) -> torch.Tensor | None:
-    return None
-
-
 def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]:
     """Forward step."""
     args = get_args()
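
The deleted `ForwardStepOutput` alias and `_return_none` helper were dead code, but the commented-out alias still documents the contract that `forward_step`'s signature abbreviates to `tuple[torch.Tensor | None, Callable]`. Spelled out as a runtime alias (a restatement of the removed comment, not a new API; assumes Python 3.10+ for the `|` syntax):

    from typing import Callable

    import torch

    # forward_step returns the model's output tensor (or None) together
    # with a callable that turns that output into a loss (or None).
    ForwardStepOutput = tuple[
        torch.Tensor | None,
        Callable[[torch.Tensor], torch.Tensor | None],
    ]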
@@ -405,32 +405,6 @@ def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]:
     timers("batch-generator", log_level=2).start()
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
     timers("batch-generator").stop()
-    ranges_to_skip = None
-    if args.train_range_to_skip is not None:
-        assert (
-            len(args.train_range_to_skip) % 2 == 0
-        ), f"""Expected --train-range-to-skip to have an even number of values.
-        Received: {len(args.train_range_to_skip)}
-        """
-        ranges_to_skip = list(
-            zip(
-                args.train_range_to_skip[::2],
-                args.train_range_to_skip[1::2],
-            )
-        )
-    if ranges_to_skip is not None and any(
-        [i <= (args.iteration + 1) <= j for (i, j) in ranges_to_skip]
-    ):
-        log.info(
-            f"Caught {args.iteration} in 'forward_step', {tokens.shape()=}, {args.consumed_train_tokens=}'"
-        )
-        # log.info(f"Caught {args.iteration + 1} in 'ranges_to_skip', skipping!"
-        # return (None, _return_none)
-        return (
-            torch.tensor([0.0], device=tokens.device),
-            lambda _: torch.Tensor([0.0], device=tokens.device),
-            # lambda _: return torch.Tensor([0.0], deviec=tokens.device),
-        )
     if args.data_efficiency_curriculum_learning:
         args.curriculum_seqlen = tokens.size()[1]
     if hasattr(args, "data_efficiency_curriculum_learning_seqlen_type") and (
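
For reference, the removed branch worked in two steps: it paired the flat `--train-range-to-skip` list into inclusive `(start, end)` iteration ranges, then short-circuited `forward_step` with a zero loss whenever the upcoming iteration fell inside any range. A standalone sketch of that pairing logic, lifted out of the training loop (function names are illustrative; the flag name and the `iteration + 1` check come from the diff):

    def parse_ranges_to_skip(flat: list[int]) -> list[tuple[int, int]]:
        # A flat list of bounds becomes inclusive pairs, e.g.
        # [100, 110, 205, 207] -> [(100, 110), (205, 207)].
        assert len(flat) % 2 == 0, "expected an even number of values"
        return list(zip(flat[::2], flat[1::2]))

    def should_skip(iteration: int, ranges: list[tuple[int, int]]) -> bool:
        # Matches the removed check, which tested iteration + 1
        # against each inclusive (start, end) pair.
        return any(i <= (iteration + 1) <= j for (i, j) in ranges)

Note that the removed branch also contained a latent bug: `tokens.shape()` calls `Tensor.shape`, which is a property rather than a method, so the log line would have raised a `TypeError` had it ever run. Deleting the logic outright rather than patching it matches the commit's stated intent.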
