From 7cb9c1137f3bee4e84071c2b58fffdd301b44c54 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Sat, 14 Sep 2024 21:30:23 -0500 Subject: [PATCH 1/7] Update `pretrain_gpt_alcf.py` --- pretrain_gpt_alcf.py | 336 ++++++++++++++++++++++--------------------- 1 file changed, 171 insertions(+), 165 deletions(-) diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py index 04018d7918..4a6c3453da 100644 --- a/pretrain_gpt_alcf.py +++ b/pretrain_gpt_alcf.py @@ -2,7 +2,9 @@ """Pretrain GPT""" import time +from typing import Callable, Type from mpi4py import MPI + comm = MPI.COMM_WORLD comm.Barrier() python_start_time = time.time() @@ -14,6 +16,7 @@ import math from functools import partial from megatron import get_args + # from megatron import print_rank_0 from megatron import get_timers from megatron import get_tokenizer @@ -23,14 +26,19 @@ from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids -from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.utils import ( + average_losses_across_data_parallel_group, + update_rotary_pos_emb, +) from megatron.arguments import core_transformer_config_from_args + # from megatron.utils import Profile, PerfTrace import logging import deepspeed from deepspeed.runtime.utils import see_memory_usage + # from deepspeed.accelerator.real_accelerator import get_accelerator import subprocess import wandb @@ -38,7 +46,8 @@ from torch import nn import torch.nn.functional as F import ezpz as ez -dt_imports = time.time() - python_start_time + +dt_imports = time.time() - python_start_time t0_setup = time.time() # ---- [SETUP COMMS] ------------------------ @@ -62,19 +71,12 @@ log.info(f"ez.setup_torch time: {dt_setup} seconds") # ---- [SETUP WANDB FROM RANK 0] -------------- -WANDB_MODE = os.environ.get('WANDB_MODE', None) -DISABLE_WANDB = ( - WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' -) +WANDB_MODE = os.environ.get("WANDB_MODE", None) +DISABLE_WANDB = WANDB_MODE is not None and str(WANDB_MODE).lower() == "disabled" if RANK == 0 and not DISABLE_WANDB: - project_name = ( - os.environ.get( - 'WB_PROJECT', # look for WB_PROJECT in env - os.environ.get( - 'WANDB_PROJECT', # look for WANDB_PROJECT in env - 'AuroraGPT' - ), - ) + project_name = os.environ.get( + "WB_PROJECT", # look for WB_PROJECT in env + os.environ.get("WANDB_PROJECT", "AuroraGPT"), # look for WANDB_PROJECT in env ) log.info(f"Setting up W&B from: {RANK} with {project_name}") _ = ez.setup_wandb(project_name=project_name) @@ -83,16 +85,16 @@ @ez.dist.timeitlogit(rank=RANK) def model_provider(pre_process=True, post_process=True): """Build the model.""" - log.info('building GPT model ...') + log.info("building GPT model ...") see_memory_usage("Before Building Model", force=True) args = get_args() assert args is not None config = core_transformer_config_from_args(args) # if RANK == 0: # git_ds_info() - if hasattr(mpu, 'get_sequence_data_parallel_group'): + if hasattr(mpu, "get_sequence_data_parallel_group"): dpg = mpu.get_sequence_data_parallel_group() - elif hasattr(mpu, 'get_data_parallel_group'): + elif hasattr(mpu, "get_data_parallel_group"): dpg = mpu.get_data_parallel_group() else: dpg = None @@ -100,20 +102,14 @@ def model_provider(pre_process=True, post_process=True): if args.use_mics: deepspeed_zero_init = deepspeed.zero.MiCS_Init with deepspeed_zero_init( - data_parallel_group=dpg, - remote_device=( - None if args.remote_device == 'none' else 
args.remote_device - ), - config_dict_or_path=args.deepspeed_config_dict, - enabled=args.zero_stage == 3, - mpu=mpu + data_parallel_group=dpg, + remote_device=(None if args.remote_device == "none" else args.remote_device), + config_dict_or_path=args.deepspeed_config_dict, + enabled=args.zero_stage == 3, + mpu=mpu, ): if args.deepspeed and not args.no_pipeline_parallel: - model = GPTModelPipe( - config=config, - num_tokentypes=0, - parallel_output=True - ) + model = GPTModelPipe(config=config, num_tokentypes=0, parallel_output=True) # This is a hack to give us a reference to # get_batch_pipe from within training.py # We need to call model.set_batch_fn after deepspeed.initialize @@ -129,7 +125,7 @@ def model_provider(pre_process=True, post_process=True): ) ).view(1, 1, args.seq_length, args.seq_length) # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) + attention_mask = attention_mask < 0.5 if args.fp16: attention_mask = attention_mask.half() elif args.bf16: @@ -146,37 +142,33 @@ def model_provider(pre_process=True, post_process=True): num_tokentypes=0, parallel_output=True, pre_process=pre_process, - post_process=post_process + post_process=post_process, ) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - log.info(80 * '-') + log.info(80 * "-") log.info(f"Number of parameters in model: {num_params}") - log.info(80 * '-') + log.info(80 * "-") see_memory_usage("After Building Model", force=True) - if wandb is not None and getattr(wandb, 'run', None) is not None: + if wandb is not None and getattr(wandb, "run", None) is not None: assert wandb.run is not None tbdir = args.tensorboard_dir # tbdir = args.getattr('tensorboard_dir', None) if tbdir is not None: try: - log.info(f'Patching tensorboard from {tbdir}') + log.info(f"Patching tensorboard from {tbdir}") wandb.tensorboard.patch(root_logdir=tbdir) except ValueError as exc: log.exception(exc) - log.warning('Continuing without patching tensorboard!') - wandb.run.config.update({'num_params': num_params}) + log.warning("Continuing without patching tensorboard!") + wandb.run.config.update({"num_params": num_params}) if "args" not in wandb.run.config: log.info( f"Updating WandB run.config: [{wandb.run.name}]({wandb.run.get_url()})" ) try: - wandb.run.config.update( - {"args": dict(sorted(vars(args).items()))} - ) + wandb.run.config.update({"args": dict(sorted(vars(args).items()))}) except Exception: - log.error( - 'Unable to `wandb.run.config.update({"args": vars(args)})`' - ) + log.error('Unable to `wandb.run.config.update({"args": vars(args)})`') # try: # wandb.run.watch( # model, @@ -194,7 +186,7 @@ def get_batch(data_iterator): tokenizer = get_tokenizer() assert args is not None and tokenizer is not None # Items and their type. - keys = ['text'] + keys = ["text"] datatype = torch.int64 data = next(data_iterator) if data_iterator is not None else None # # Broadcast data. @@ -204,7 +196,7 @@ def get_batch(data_iterator): # data = None data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. - tokens_ = data_b['text'].long() + tokens_ = data_b["text"].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. 
@@ -215,7 +207,8 @@ def get_batch(data_iterator): args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss, - skip_mask) + skip_mask, + ) # For DS's sequence parallel seq_parallel_world_size = mpu.get_sequence_parallel_world_size() seq_parallel_world_rank = mpu.get_sequence_parallel_rank() @@ -240,24 +233,37 @@ def data_post_process(data, data_sampler_state_dict): args = get_args() assert args is not None if args.data_efficiency_curriculum_learning: - if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: - args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' - current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if "seqlen_truncate" in data_sampler_state_dict["current_difficulties"]: + args.data_efficiency_curriculum_learning_seqlen_type = "seqlen_truncate" + current_seqlen = data_sampler_state_dict["current_difficulties"][ + "seqlen_truncate" + ] if current_seqlen < args.seq_length: - data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() - elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: - args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' - current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + data["text"] = data["text"][:, : (current_seqlen + 1)].contiguous() + elif "seqlen_reshape" in data_sampler_state_dict["current_difficulties"]: + args.data_efficiency_curriculum_learning_seqlen_type = "seqlen_reshape" + current_seqlen = data_sampler_state_dict["current_difficulties"][ + "seqlen_reshape" + ] if current_seqlen < args.seq_length: - orig_num_token = torch.numel(data['text']) - reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) - data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), - data['text'][:, -(current_seqlen+1):]), 0).contiguous() - num_row = math.ceil(orig_num_token / (current_seqlen+1)) - num_row = min(num_row, data['text'].size()[0]) + orig_num_token = torch.numel(data["text"]) + reshape_len = (data["text"].size()[1] // (current_seqlen + 1)) * ( + current_seqlen + 1 + ) + data["text"] = torch.cat( + ( + data["text"][:, :reshape_len] + .contiguous() + .view(-1, current_seqlen + 1), + data["text"][:, -(current_seqlen + 1) :], + ), + 0, + ).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen + 1)) + num_row = min(num_row, data["text"].size()[0]) if num_row > 1 and num_row % 2 != 0: num_row -= 1 - data['text'] = data['text'][:num_row, :].contiguous() + data["text"] = data["text"][:num_row, :].contiguous() else: args.data_efficiency_curriculum_learning_seqlen_type = None return data @@ -272,12 +278,12 @@ def get_batch_pipe(data): tokenizer = get_tokenizer() assert args is not None # Items and their type. - keys = ['text'] + keys = ["text"] datatype = torch.int64 # Broadcast data. data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. - tokens_ = data_b['text'].long() + tokens_ = data_b["text"].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. 
@@ -286,19 +292,17 @@ def get_batch_pipe(data): tokenizer.eod, args.reset_position_ids, args.reset_attention_mask, - args.eod_mask_loss) - if ( - args.curriculum_learning_legacy - and args.curriculum_seqlen < tokens.size()[1] - ): + args.eod_mask_loss, + ) + if args.curriculum_learning_legacy and args.curriculum_seqlen < tokens.size()[1]: # seqlen-based curriculum learning # tokens, position_ids, labels, loss_mask # have size [batch size, seqlen] - tokens = tokens[:, :args.curriculum_seqlen].contiguous() - position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + tokens = tokens[:, : args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, : args.curriculum_seqlen].contiguous() if labels is not None: - labels = labels[:, :args.curriculum_seqlen].contiguous() - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + labels = labels[:, : args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, : args.curriculum_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) @@ -315,37 +319,32 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): loss = loss + moe_loss + mos_loss if args.mos: return loss, { - 'total loss': loss, - 'lm loss': averaged_loss[0], - 'moe loss': moe_loss, - 'mos loss': mos_loss + "total loss": loss, + "lm loss": averaged_loss[0], + "moe loss": moe_loss, + "mos loss": mos_loss, } elif args.kd: return loss, { - 'total loss': loss, - 'lm loss': averaged_loss[0], - 'moe loss': moe_loss, - 'kd loss': mos_loss + "total loss": loss, + "lm loss": averaged_loss[0], + "moe loss": moe_loss, + "kd loss": mos_loss, } log.info( - f'>>> total loss: {loss}, ' - f'lm loss {averaged_loss[0]}, ' - f'kd loss {mos_loss}' + f">>> total loss: {loss}, " + f"lm loss {averaged_loss[0]}, " + f"kd loss {mos_loss}" ) else: if max(args.num_experts) <= 1: - return loss, {'lm loss': averaged_loss[0]} + return loss, {"lm loss": averaged_loss[0]} loss = loss + moe_loss - return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + return loss, {"lm loss": averaged_loss[0], "moe loss": moe_loss} def calculate_mos_loss( - args, - stu_output, - teacher_model, - tokens, - position_ids, - attention_mask + args, stu_output, teacher_model, tokens, position_ids, attention_mask ): mos_loss = 0 alpha = args.kd_alpha_ce @@ -354,29 +353,25 @@ def calculate_mos_loss( if teacher_model: with torch.no_grad(): if ( - args.curriculum_learning_legacy and - args.curriculum_seqlen < args.seq_length + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length ): assert args.curriculum_seqlen is not None curriculum_seqlen = args.curriculum_seqlen tokens = tokens[:, :curriculum_seqlen].contiguous() position_ids = position_ids[:, :curriculum_seqlen].contiguous() csl = curriculum_seqlen - attention_mask = ( - attention_mask[:, :, :csl, :csl].contiguous() - ) + attention_mask = attention_mask[:, :, :csl, :csl].contiguous() # No need to truncate labels # as we do not need it for the teacher logits tea_output, tea_other_losses = teacher_model( - tokens, - position_ids, - attention_mask + tokens, position_ids, attention_mask ) assert stu_output.size() == tea_output.size(), ( - 'teacher and student output should match in size. ' - f'Student: {stu_output.size()}, ' - f'Teacher: {tea_output.size()}, ' - f'CL seq length {args.curriculum_seqlen}' + "teacher and student output should match in size. 
" + f"Student: {stu_output.size()}, " + f"Teacher: {tea_output.size()}, " + f"CL seq length {args.curriculum_seqlen}" ) student_logits = F.log_softmax(stu_output / kd_temp, dim=2) # The target logits is expected to be probabilities. @@ -384,67 +379,81 @@ def calculate_mos_loss( # then we need to set target_log to true # when initializing the KLDivLoss. tea_logits = F.softmax(tea_output / kd_temp, dim=2) - mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( - student_logits, - tea_logits + mos_loss = ( + kd_temp + * kd_temp + * nn.KLDivLoss(reduction="batchmean")(student_logits, tea_logits) ) mos_loss = mos_loss.div(args.seq_length) * beta return mos_loss -def forward_step(data_iterator, model): +# ForwardStepOutput = Type[tuple[torch.Tensor | None, Callable[[torch.Tensor], torch.Tensor | None]]] + + +def _return_none(_: torch.Tensor) -> torch.Tensor | None: + return None + + +def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]: """Forward step.""" args = get_args() timers = get_timers() assert args is not None assert timers is not None # Get the batch. - timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator - ) - timers('batch-generator').stop() + timers("batch-generator", log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) + timers("batch-generator").stop() + ranges_to_skip = None + if args.train_range_to_skip is not None: + assert ( + len(args.train_range_to_skip) % 2 == 0 + ), f"""Expected --train-range-to-skip to have an even number of values. + Received: {len(args.train_range_to_skip)} + """ + ranges_to_skip = list( + zip( + args.train_range_to_skip[::2], + args.train_range_to_skip[1::2], + ) + ) + if ranges_to_skip is not None and any( + [i <= (args.iteration + 1) <= j for (i, j) in ranges_to_skip] + ): + log.info( + f"Caught {args.iteration} in 'forward_step', {tokens.shape()=}, {args.consumed_train_tokens=}'" + ) + # log.info(f"Caught {args.iteration + 1} in 'ranges_to_skip', skipping!" + # return (None, _return_none) + return ( + torch.tensor([0.0], device=tokens.device), + lambda _: torch.Tensor([0.0], device=tokens.device), + # lambda _: return torch.Tensor([0.0], deviec=tokens.device), + ) if args.data_efficiency_curriculum_learning: args.curriculum_seqlen = tokens.size()[1] - if ( - hasattr( - args, - 'data_efficiency_curriculum_learning_seqlen_type') - and ( - args.data_efficiency_curriculum_learning_seqlen_type - == 'seqlen_reshape' - ) + if hasattr(args, "data_efficiency_curriculum_learning_seqlen_type") and ( + args.data_efficiency_curriculum_learning_seqlen_type == "seqlen_reshape" ): - args.data_efficiency_curriculum_learning_numel = ( - torch.numel(tokens) - ) + args.data_efficiency_curriculum_learning_numel = torch.numel(tokens) stu_output = None if args.mos or args.kd: # The forward func can return either the loss or the logits, # depending on whether passing in the labels or not. 
stu_output, other_losses = model(tokens, position_ids, attention_mask) - if ( - args.curriculum_learning_legacy - and args.curriculum_seqlen < args.seq_length - ): + if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: assert args.curriculum_seqlen is not None - labels = labels[:, :args.curriculum_seqlen].contiguous() + labels = labels[:, : args.curriculum_seqlen].contiguous() output_tensor = tensor_parallel.vocab_parallel_cross_entropy( - stu_output.contiguous().float(), - labels + stu_output.contiguous().float(), labels ) else: output_tensor, other_losses = model( - tokens, - position_ids, - attention_mask, - labels=labels + tokens, position_ids, attention_mask, labels=labels ) - if ( - args.curriculum_learning_legacy and - args.curriculum_seqlen < args.seq_length - ): - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: + loss_mask = loss_mask[:, : args.curriculum_seqlen].contiguous() moe_losses = [] for moe_loss in other_losses: @@ -462,7 +471,7 @@ def forward_step(data_iterator, model): args.teacher_model[0], tokens, position_ids, - attention_mask + attention_mask, ) # Output_tensor stores the standard loss, @@ -479,7 +488,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): # from ezpz.profile import get_context_manager # cm = get_context_manager(rank=RANK, outdir=args.save) # with cm: - log.info('> building train, validation, and test datasets for GPT ...') + log.info("> building train, validation, and test datasets for GPT ...") files = [] if args.data_file_list is not None: log.info(f"Reading datasets from {args.data_file_list}") @@ -492,7 +501,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): # - `/path/to/data_text_document` is the path to the text document # - `corpus` is the corpus (~ source, can be made up) where that # document came from (i.e. `books`, `arxiv`, etc.) 
- with open(args.data_file_list, 'r') as flist: + with open(args.data_file_list, "r") as flist: for f in flist.readlines(): if len(f.strip()) != 0: try: @@ -505,17 +514,11 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ) if fname.find(".bin") != -1: fname = fname.split(".bin")[0] - files.extend( - [ - float(w), # weight - fname, # filename - c # corpus - ] - ) + files.extend([float(w), fname, c]) # weight # filename # corpus elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): path = args.data_path[0] + "/" for f in os.listdir(path): - if (os.path.isfile(path + f) and f.find(".bin") != -1): + if os.path.isfile(path + f) and f.find(".bin") != -1: files.append(1) files.append(path + f.split(".bin")[0]) else: @@ -540,11 +543,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): def command_exists(cmd): - result = subprocess.Popen( - f'type {cmd}', - stdout=subprocess.PIPE, - shell=True - ) + result = subprocess.Popen(f"type {cmd}", stdout=subprocess.PIPE, shell=True) return result.wait() == 0 @@ -552,17 +551,18 @@ def git_ds_info(): if RANK != 0: return from deepspeed.env_report import main as ds_report + ds_report() # Write out version/git info git_hash_cmd = "git rev-parse --short HEAD" git_branch_cmd = "git rev-parse --abbrev-ref HEAD" - if command_exists('git'): + if command_exists("git"): try: result = subprocess.check_output(git_hash_cmd, shell=True) - git_hash = result.decode('utf-8').strip() + git_hash = result.decode("utf-8").strip() result = subprocess.check_output(git_branch_cmd, shell=True) - git_branch = result.decode('utf-8').strip() + git_branch = result.decode("utf-8").strip() except subprocess.CalledProcessError: git_hash = "unknown" git_branch = "unknown" @@ -570,21 +570,26 @@ def git_ds_info(): git_hash = "unknown" git_branch = "unknown" print( - f'**** Git info for Megatron: ' - f'git_hash={git_hash} git_branch={git_branch} ****' + f"**** Git info for Megatron: " + f"git_hash={git_hash} git_branch={git_branch} ****" ) def main(): - if os.getenv('TORCH_PROFILER_ENABLE') == '1': + if os.getenv("TORCH_PROFILER_ENABLE") == "1": # record_function from torch.profiler import profile, ProfilerActivity + try: - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU] + activities = [ + ProfilerActivity.CPU, + ProfilerActivity.CUDA, + ProfilerActivity.XPU, + ] except Exception as exc: log.exception(exc) - log.warning("TORCH PROFILER WARNING: XPU is not supported") - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA] + log.warning("TORCH PROFILER WARNING: XPU is not supported") + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] with profile(activities=activities) as prof: model = pretrain( train_valid_test_datasets_provider, @@ -592,7 +597,7 @@ def main(): ModelType.encoder_or_decoder, forward_step, # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - data_post_process=data_post_process + data_post_process=data_post_process, ) args = get_args() assert args is not None @@ -606,7 +611,7 @@ def main(): ModelType.encoder_or_decoder, forward_step, # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - data_post_process=data_post_process + data_post_process=data_post_process, ) # try: # from megatron.text_generation import generate_and_post_process @@ -641,6 +646,7 @@ def main(): # data_post_process=data_post_process) import sys import deepspeed.comm as dist + model = main() dist.log_summary() if wandb.run is not None: From e83de19527830ade3a49f3fcfb698739a84c4ac9 Mon Sep 17 00:00:00 
2001 From: Sam Foreman Date: Sat, 14 Sep 2024 21:33:05 -0500 Subject: [PATCH 2/7] Update `megatron/training_log.py` --- megatron/training_log.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/megatron/training_log.py b/megatron/training_log.py index cd6638e17d..24b8015264 100644 --- a/megatron/training_log.py +++ b/megatron/training_log.py @@ -645,6 +645,12 @@ def training_log( ) log_string += " [LM]TFLOPs={:.2f} |".format(tflops_lm_per_gpu) log_string += " [DS]TFLOPs={:.2f} |".format(tflops) + if wandb is not None and getattr(wandb, "run", None) is not None: + wandb_metrics |= { + "training/skiped_iterations": total_loss_dict[skipped_iters_key] + } + wandb_metrics |= {"training/nan_iterations": total_loss_dict[nan_iters_key]} + wandb.log(wandb_metrics) total_loss_dict[advanced_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 @@ -654,12 +660,6 @@ def training_log( # Report memory after optimizer state has been initialized. report_memory("(after {} iterations)".format(iteration)) report_memory_flag = False - if wandb is not None and getattr(wandb, "run", None) is not None: - wandb_metrics |= { - "training/skiped_iterations": total_loss_dict[skipped_iters_key] - } - wandb_metrics |= {"training/nan_iterations": total_loss_dict[nan_iters_key]} - wandb.log(wandb_metrics) if timers is not None: timers.log(timers_to_log, normalizer=args.log_interval) From 29756d6e1e97464e8a7fbe1eca4cf832fa2a916a Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Sat, 14 Sep 2024 21:34:33 -0500 Subject: [PATCH 3/7] Warn if mismatch b/w iters in `megatron/checkpointing.py` --- megatron/checkpointing.py | 4 +++- megatron/training.py | 17 ++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index b7f4b30bde..a4f82ec9d3 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -239,7 +239,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): """Save a model checkpoint.""" args = get_args() assert args is not None - iteration = args.iteration + args_iter = args.iteration + if args_iter != iteration: + log.warning(f"{args.iteration=} != {iteration} passed to 'save_checkpoint'") save_lr_state_dict() diff --git a/megatron/training.py b/megatron/training.py index 90a1250648..668aea930c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1004,7 +1004,7 @@ def train( model_module.train() # Tracking loss. total_loss_dict = {} - loss_dict = {} + loss_dict = {"skipped_iter": 0} # Iterations. 
iteration = args.iteration # Translate args to core configuration @@ -1061,11 +1061,22 @@ def train( ): log.info(f"Caught {iteration + 1} in 'ranges_to_skip', skipping!") # total_loss_dict = {"skipped iterations": } + # loss_dict skipped_iter = 1 - total_loss_dict["skipped iterations"] += skipped_iter - grad_norm = None + # grad_norm = None num_zeros_in_grad = None num_skipped_iters += 1 + increment = ( + get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + ) + model[0].skipped_steps += 1 + model[0].global_steps += 1 + model[0].micro_steps += 1 + model[0].global_samples += model[0].train_batch_size() + # model[0].step(lr_kwargs={"increment": increment}) + # grad_norm = model[0].get_global_grad_norm() + # update_successful = model[0].was_step_applied() + opt_param_scheduler.step(increment=increment) else: if os.getenv("TORCH_PROFILER_ENABLE") == "2": from torch.profiler import profile, ProfilerActivity From 1a7f03b67a260e2e7326fe3a0f12709a8908d616 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Sun, 15 Sep 2024 20:39:24 -0500 Subject: [PATCH 4/7] fix: `try/except` for non tensors in `megatron/training_log.py` --- megatron/training_log.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/training_log.py b/megatron/training_log.py index 24b8015264..3eb96c392d 100644 --- a/megatron/training_log.py +++ b/megatron/training_log.py @@ -92,7 +92,10 @@ def training_log( + loss_dict[key] ) else: - value = loss_dict[key].float().sum().item() + try: + value = loss_dict[key].float().sum().item() + except AttributeError: + value = loss_dict[key] is_nan = value == float("inf") or value == -float("inf") or value != value got_nan = got_nan or is_nan total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int( From 828f6a944627335336f7f6d2e348dd3989ffb021 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Tue, 17 Sep 2024 08:17:17 -0500 Subject: [PATCH 5/7] fix: Correctly draw `grad_acc_steps` batches of data when skipping step --- megatron/training.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index 668aea930c..8ffac6cb9c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -822,6 +822,8 @@ def train_step( timers = get_timers() accelerator = get_accelerator() assert args is not None and timers is not None and accelerator is not None + grad_norm = None + num_zeros_in_grad = None if args.deepspeed and args.ds_pipeline_enabled: num_zeros_in_grad = 0 assert isinstance(model[0], deepspeed.PipelineEngine) @@ -919,6 +921,10 @@ def train_step( if args.deepspeed: skipped_iter = 0 if update_successful else 1 grad_norm = model[0].get_global_grad_norm() + # Empty unused memory. + if args.empty_unused_memory_level >= 2 and accelerator is not None: + accelerator.empty_cache() + # XXX: [saforem2]: ---------------------------------------------------- # Is `num_zeros_in_grad` worth calculating (/ implementing) ?? # the `Megatron`-specific implementation is at: @@ -1002,6 +1008,7 @@ def train( # Turn on training mode which enables dropout. for model_module in model: model_module.train() + grad_norm = None # Tracking loss. 
total_loss_dict = {} loss_dict = {"skipped_iter": 0} @@ -1060,12 +1067,23 @@ def train( [i <= (iteration + 1) <= j for (i, j) in ranges_to_skip] ): log.info(f"Caught {iteration + 1} in 'ranges_to_skip', skipping!") - # total_loss_dict = {"skipped iterations": } - # loss_dict skipped_iter = 1 - # grad_norm = None - num_zeros_in_grad = None num_skipped_iters += 1 + num_zeros_in_grad = None + gas = args.deepspeed_config_dict["gradient_accumulation_steps"] + for microstep in range(gas): + _batch = next(train_data_iterator) + _tokens = _batch["text"] + if ( + iteration < 10 + and os.environ.get("DUMP_SKIPPED_ITERS", None) + and RANK == 0 + ): + log.info(f"{_tokens.shape}, {len(train_data_iterator)=}") + log.info( + f"{iteration=} [{microstep}/{gas}]: ({_tokens.shape})\n{_tokens[:10]=}" + ) + increment = ( get_num_microbatches() * args.micro_batch_size * args.data_parallel_size ) @@ -1073,9 +1091,6 @@ def train( model[0].global_steps += 1 model[0].micro_steps += 1 model[0].global_samples += model[0].train_batch_size() - # model[0].step(lr_kwargs={"increment": increment}) - # grad_norm = model[0].get_global_grad_norm() - # update_successful = model[0].was_step_applied() opt_param_scheduler.step(increment=increment) else: if os.getenv("TORCH_PROFILER_ENABLE") == "2": @@ -1085,7 +1100,7 @@ def train( activities = [ ProfilerActivity.CPU, ProfilerActivity.CUDA, - ProfilerActivity.XPU, + ProfilerActivity.XPU, # type:ignore ] except Exception: log.warning("TORCH PROFILER WARNING: XPU is not supported") From 295fcb3d57a40ec513a521aa8814d99a5c8827b8 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Tue, 17 Sep 2024 08:18:32 -0500 Subject: [PATCH 6/7] Update `pretrain_gpt_alcf.py` Remve `--train-range-to-skip` logic from `pretrain_gpt_alcf.py` and remove redundant code. --- pretrain_gpt_alcf.py | 44 +++++++++----------------------------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py index 4a6c3453da..12a05c5299 100644 --- a/pretrain_gpt_alcf.py +++ b/pretrain_gpt_alcf.py @@ -2,13 +2,12 @@ """Pretrain GPT""" import time -from typing import Callable, Type +from typing import Callable from mpi4py import MPI comm = MPI.COMM_WORLD comm.Barrier() python_start_time = time.time() -from pathlib import Path import os from rich import print @@ -189,6 +188,14 @@ def get_batch(data_iterator): keys = ["text"] datatype = torch.int64 data = next(data_iterator) if data_iterator is not None else None + + if ( + args.iteration < 10 + and RANK == 0 + and os.environ.get("DUMP_TOKENS", None) + and data is not None + ): + log.info(f"{args.iteration=}: {data['text'][:10]=}") # # Broadcast data. # if data_iterator is not None: # data = next(data_iterator) @@ -388,13 +395,6 @@ def calculate_mos_loss( return mos_loss -# ForwardStepOutput = Type[tuple[torch.Tensor | None, Callable[[torch.Tensor], torch.Tensor | None]]] - - -def _return_none(_: torch.Tensor) -> torch.Tensor | None: - return None - - def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]: """Forward step.""" args = get_args() @@ -405,32 +405,6 @@ def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]: timers("batch-generator", log_level=2).start() tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) timers("batch-generator").stop() - ranges_to_skip = None - if args.train_range_to_skip is not None: - assert ( - len(args.train_range_to_skip) % 2 == 0 - ), f"""Expected --train-range-to-skip to have an even number of values. 
- Received: {len(args.train_range_to_skip)} - """ - ranges_to_skip = list( - zip( - args.train_range_to_skip[::2], - args.train_range_to_skip[1::2], - ) - ) - if ranges_to_skip is not None and any( - [i <= (args.iteration + 1) <= j for (i, j) in ranges_to_skip] - ): - log.info( - f"Caught {args.iteration} in 'forward_step', {tokens.shape()=}, {args.consumed_train_tokens=}'" - ) - # log.info(f"Caught {args.iteration + 1} in 'ranges_to_skip', skipping!" - # return (None, _return_none) - return ( - torch.tensor([0.0], device=tokens.device), - lambda _: torch.Tensor([0.0], device=tokens.device), - # lambda _: return torch.Tensor([0.0], deviec=tokens.device), - ) if args.data_efficiency_curriculum_learning: args.curriculum_seqlen = tokens.size()[1] if hasattr(args, "data_efficiency_curriculum_learning_seqlen_type") and ( From cf80e6bb75c56cd8a26627e5364e4290da736bd8 Mon Sep 17 00:00:00 2001 From: Marieme Ngom Date: Mon, 23 Sep 2024 16:21:00 +0000 Subject: [PATCH 7/7] added sophia --- megatron/arguments.py | 10 ++ megatron/optimizer/__init__.py | 9 ++ megatron/optimizer/sophia.py | 202 +++++++++++++++++++++++++++++++++ 3 files changed, 221 insertions(+) create mode 100644 megatron/optimizer/sophia.py diff --git a/megatron/arguments.py b/megatron/arguments.py index b3ed06353e..2a0ac606ce 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -780,6 +780,15 @@ def _add_regularization_args(parser): help='Weight decay increment function.') group.add_argument('--clip-grad', type=float, default=1.0, help='Gradient clipping based on global L2 norm.') + group.add_argument('--sophiag-beta1', type=float, default=0.9, + help='First coefficient for computing running averages ' + 'of gradient and its hessian') + group.add_argument('--sophiag-beta2', type=float, default=0.95, + help='Second coefficient for computing running averages ' + 'of gradient and its hessian') + group.add_argument('--sophiag-rho', type=float, default=0.01, + help='SophiaG clipping threshhold') + group.add_argument('--adam-beta1', type=float, default=0.9, help='First coefficient for computing running averages ' 'of gradient and its square') @@ -946,6 +955,7 @@ def _add_training_args(parser): choices=[ 'adam', 'adamw', + 'sophiag', 'sgd', 'ds.fusedlamb', 'ipex.lamb', diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 48f2737a06..99145ff4f4 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -315,6 +315,15 @@ def optimizer_hook(p): weight_decay=args.weight_decay, momentum=args.sgd_momentum ) + elif str(args.optimizer).lower() == 'sophiag': + from .sophia import SophiaG + optimizer = SophiaG( + param_groups, + lr=args.lr, + betas=(args.sophiag_beta1, args.sophiag_beta2), + rho = args.sophiag_rho, + weight_decay=args.weight_decay + ) else: raise TypeError(f'{args.optimizer} optimizer is not supported.') if args.deepspeed: diff --git a/megatron/optimizer/sophia.py b/megatron/optimizer/sophia.py new file mode 100644 index 0000000000..4c4e074790 --- /dev/null +++ b/megatron/optimizer/sophia.py @@ -0,0 +1,202 @@ +import math +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing import List, Optional + + +#SOphiaG implementation from https://github.com/Liuhong99/Sophia/blob/main/sophia.py, copy pasted here because no pip and not sure about submodules + +class SophiaG(Optimizer): + def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04, + weight_decay=1e-1, *, maximize: bool = False, + capturable: bool = False): + if 
not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, rho=rho, + weight_decay=weight_decay, + maximize=maximize, capturable=capturable) + super(SophiaG, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('capturable', False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group['betas'] + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + + @torch.no_grad() + def step(self, closure=None, bs=5120): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError('Hero does not support sparse gradients') + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) 
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + state_steps.append(state['step']) + hessian.append(state['hessian']) + + if self.defaults['capturable']: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag(params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group['rho'], + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + capturable=group['capturable']) + + return loss + +def sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool): + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + + func = _single_tensor_sophiag + + func(params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable) + +def _single_tensor_sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool): + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step_size = lr + step_size_neg = step_size.neg() + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) + else: + step_size_neg = - lr + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
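
Note (not part of the patches above): a minimal usage sketch for the SophiaG optimizer introduced in PATCH 7/7. It assumes the patched tree is importable as `megatron.optimizer.sophia`; the toy model, data, hyperparameter values, and Hessian-refresh interval below are illustrative assumptions, not taken from the patch (in training, the optimizer is instead selected via the new `--optimizer sophiag`, `--sophiag-beta1/2`, and `--sophiag-rho` arguments added in `megatron/arguments.py`).

    import torch
    import torch.nn.functional as F
    from megatron.optimizer.sophia import SophiaG  # file added by PATCH 7/7

    # Toy setup (illustrative only): a linear "language model" over a tiny vocab.
    vocab, dim, tokens = 16, 32, 64
    model = torch.nn.Linear(dim, vocab)
    optimizer = SophiaG(
        model.parameters(),
        lr=1e-4,
        betas=(0.965, 0.99),   # constructor defaults; the CLI defaults are 0.9 / 0.95
        rho=0.04,              # constructor default; --sophiag-rho defaults to 0.01
        weight_decay=1e-1,
    )

    hessian_update_interval = 10  # assumption: how often to refresh the Hessian EMA

    for step in range(100):
        x = torch.randn(tokens, dim)
        y = torch.randint(0, vocab, (tokens,))
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step(bs=tokens)  # `bs` scales the Hessian term in the parameter update
        optimizer.zero_grad(set_to_none=True)

        if step % hessian_update_interval == hessian_update_interval - 1:
            # Gauss-Newton-Bartlett-style Hessian estimate: backprop a loss on labels
            # sampled from the model's own predictions, then fold the squared gradients
            # into the per-parameter Hessian EMA via update_hessian().
            logits = model(torch.randn(tokens, dim))
            sampled = torch.distributions.Categorical(logits=logits).sample()
            F.cross_entropy(logits, sampled).backward()
            optimizer.update_hessian()
            optimizer.zero_grad(set_to_none=True)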