From e330ae342fc2b0c86875665fd4379c36fa69f6b9 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 20 Mar 2024 18:17:03 +0000 Subject: [PATCH 01/50] changes for dpo --- dpo_gpt_alcf.py | 750 ++++++++++++++++++++++ dpo_training.py | 1602 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2352 insertions(+) create mode 100644 dpo_gpt_alcf.py create mode 100644 dpo_training.py diff --git a/dpo_gpt_alcf.py b/dpo_gpt_alcf.py new file mode 100644 index 0000000000..821c7cd8ab --- /dev/null +++ b/dpo_gpt_alcf.py @@ -0,0 +1,750 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain GPT""" + +import os +from rich import print +import torch +import math +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +# from megatron.training import pretrain +from megatron.dpo_training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import ( + report_memory, + throughput_calculator, + checkpoint_throughput_calculator +) +from pathlib import Path + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator +import subprocess +import wandb + +import time +from torch import nn +import torch.nn.functional as F + +# from ezpz import get_logger +from ezpz.dist import get_world_size, setup_wandb, get_rank + +# RANK = setup_torch( +# backend='deepspeed', +# port='5432', +# ) +RANK = get_rank() +WORLD_SIZE = get_world_size() +LEVEL = "DEBUG" if RANK == 0 else "CRITICAL" + +WANDB_MODE = os.environ.get('WANDB_MODE', None) +DISABLE_WANDB = ( + WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' +) + +if RANK == 0 and not DISABLE_WANDB: + project_name = ( + os.environ.get( + 'WB_PROJECT', + os.environ.get( + 'WANDB_PROJECT', + 'AuroraGPT' + ), + ) + ) + print('--------------------------------------------------') + print(f"Setting up W&B from: {RANK} with {project_name}") + print('--------------------------------------------------') + setup_wandb(project_name=project_name) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + print_rank_0('building GPT model ...') + see_memory_usage("Before Building Model", force=True) + args = get_args() + config = core_transformer_config_from_args(args) + + if wandb.run is not None: + print(f"Updating WandB run: [{wandb.run.name}]({wandb.run.url})") + wandb.run.config.update({"args": vars(args)}) + if RANK == 0: + git_ds_info() + if hasattr(mpu, 'get_sequence_parallel_group'): + dpg = mpu.get_sequence_parallel_group() + elif hasattr(mpu, 'get_data_parallel_group'): + dpg = mpu.get_data_parallel_group() + else: + dpg = None + if wandb is not None and wandb.run is not None: + assert wandb is not None and wandb.run is not None + print(f'Updating {wandb.run.name=} at {wandb.run.url=}') + wandb.run.config.update({'args': vars(args)}) + with deepspeed.zero.Init( + data_parallel_group=dpg, + remote_device=( + None if args.remote_device == 'none' else args.remote_device + ), + 
config_dict_or_path=args.deepspeed_config_dict, + enabled=args.zero_stage == 3, + mpu=mpu + ): + if args.deepspeed and not args.no_pipeline_parallel: + model = GPTModelPipe( + config=config, + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to + # get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. + # This avoids having to pipeline it + # as an activation during training. + # The mask is constant, and thus we can reuse it. + attention_mask = torch.tril( + torch.ones( + (1, args.seq_length, args.seq_length), + device=get_accelerator().current_device_name() + ) + ).view(1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + # For prertaining, since sequence length is fixed, + # cache rotary embedding in args, to avoid communicating around + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.seq_length) + + else: + model = GPTModel( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + # print_rank_0('\n ------------------------ ') + # print_rank_0(f'num of parameters {num_params}') + # print_rank_0('------------------------\n ') + print_rank_0(80 * '-') + print_rank_0(f"Number of parameters in model: {num_params}") + print_rank_0(80 * '-') + see_memory_usage("After Building Model", force=True) + if wandb.run is not None: + wandb.run.config.update({'num_params': num_params}) + # wandb.run.watch( + # model, + # log='all', + # log_graph=True, + # ) + # wandb.run.config.update({'num_params': num_params}) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + data = next(data_iterator) if data_iterator is not None else None + # # Broadcast data. + # if data_iterator is not None: + # data = next(data_iterator) + # else: + # data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
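# A minimal, self-contained illustration of the next-token shift and the
# causal-mask construction used above; the toy tensor and its length are
# arbitrary stand-ins for a real tokenized batch.
import torch

toy_tokens_ = torch.tensor([[5, 9, 2, 7, 1]])      # [batch, seq_len + 1]
toy_labels = toy_tokens_[:, 1:].contiguous()       # targets: inputs shifted left by one
toy_tokens = toy_tokens_[:, :-1].contiguous()      # inputs fed to the model
s = toy_tokens.size(1)
causal = torch.tril(torch.ones(1, 1, s, s))        # 1 on and below the diagonal
toy_attn_mask = (causal < 0.5)                     # True marks positions to mask out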
+ skip_mask = args.use_flash_attn or args.use_flash_attn_triton + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + skip_mask) + + # For DS's sequence parallel + seq_parallel_world_size = mpu.get_sequence_parallel_world_size() + seq_parallel_world_rank = mpu.get_sequence_parallel_rank() + + # For Megatron's sequence parallel + if args.sequence_parallel: + seq_parallel_world_size = mpu.get_tensor_model_parallel_world_size() + seq_parallel_world_rank = mpu.get_tensor_model_parallel_rank() + seq_length = tokens.size(1) + + assert seq_length % seq_parallel_world_size == 0 + sub_seq_length = seq_length // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_length + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length + + tokens = tokens[:, sub_seq_start:sub_seq_end] + position_ids = position_ids[:, sub_seq_start:sub_seq_end] + # For DS's sequence parallel + if mpu.get_sequence_parallel_world_size() > 1: + labels = labels[:, sub_seq_start:sub_seq_end] + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def data_post_process(data, data_sampler_state_dict): + args = get_args() + if args.data_efficiency_curriculum_learning: + if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if current_seqlen < args.seq_length: + data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() + elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + if current_seqlen < args.seq_length: + orig_num_token = torch.numel(data['text']) + reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) + data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), + data['text'][:, -(current_seqlen+1):]), 0).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen+1)) + num_row = min(num_row, data['text'].size()[0]) + if num_row > 1 and num_row % 2 != 0: + num_row -= 1 + data['text'] = data['text'][:num_row, :].contiguous() + else: + args.data_efficiency_curriculum_learning_seqlen_type = None + return data + + +def get_batch_pipe(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` + instead of `data_iterator` + """ + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
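# A standalone sketch of the per-rank sequence slice computed in get_batch
# above; rank and world size are illustrative and come from mpu in the real
# code.
import torch

def slice_for_sequence_parallel(tokens: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's contiguous slice along the sequence dimension."""
    seq_length = tokens.size(1)
    assert seq_length % world_size == 0, "sequence length must divide evenly across ranks"
    sub = seq_length // world_size
    return tokens[:, rank * sub:(rank + 1) * sub]

# e.g. a [1, 4096] batch split across 4 ranks gives each rank a [1, 1024] slice.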
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < tokens.size()[1] + ): + # seqlen-based curriculum learning + # tokens, position_ids, labels, loss_mask + # have size [batch size, seqlen] + tokens = tokens[:, :args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + if labels is not None: + labels = labels[:, :args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + else: + if max(args.num_experts) <= 1: + return loss, {'lm loss': averaged_loss[0]} + loss = loss + moe_loss + return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + +def dpo_loss_func(loss_mask, dpo_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. 
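# The token-level reduction used by loss_func and dpo_loss_func, restated as a
# standalone helper: average the per-token losses over unmasked positions only.
import torch

def masked_mean(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    loss_mask = loss_mask.view(-1).float()
    return torch.sum(per_token_loss.view(-1) * loss_mask) / loss_mask.sum()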
+ averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + # else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + # loss = loss + moe_loss + # return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + loss = dpo_loss + return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss} + + +def calculate_mos_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + tea_output, tea_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == tea_output.size(), ( + 'teacher and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Teacher: {tea_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. + tea_logits = F.softmax(tea_output / kd_temp, dim=2) + + mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( + student_logits, + tea_logits + ) + + mos_loss = mos_loss.div(args.seq_length) * beta + return mos_loss + +def calculate_dpo_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + kd_temp = 1.0 + beta = 0.1 + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + ref_output, ref_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == ref_output.size(), ( + 'ref and student output should match in size. 
' + f'Student: {stu_output.size()}, ' + f'Reference: {ref_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # Labels ? + logprobs = torch.gather(student_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. + ref_logits = F.softmax(ref_output / kd_temp, dim=2) + ref_logprobs = torch.gather(ref_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # Partial DPO loss (from preferred/unpreferred) + logprob_ratio = logprobs - ref_logprobs + #------------ [ToDo]------------- + # # Get ratios of unpreferred log probabilities from model and ref model + # unpreferred_logprob_ratio = unpreferred_logprobs - ref_unpreferred_logprobs + + # Difference of logprobs ratios scaled by beta + # scaled_diff_logprob_ratios = self.beta * (preferred_logprob_ratio - unpreferred_logprob_ratio) + #------------ [ToDo]------------- + scaled_diff_logprob_ratios = beta * (logprob_ratio) + + # Losses computed as negative logsigmoid of scaled difference + dpo_loss = -F.logsigmoid(scaled_diff_logprob_ratios) + + return dpo_loss + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + # #---------Return a tuple----------- + # tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + # data_iterator) + # timers('batch-generator').stop() + # #-------------------------- + + if args.data_efficiency_curriculum_learning: + args.curriculum_seqlen = tokens.size()[1] + if ( + hasattr( + args, + 'data_efficiency_curriculum_learning_seqlen_type') + and ( + args.data_efficiency_curriculum_learning_seqlen_type + == 'seqlen_reshape' + ) + ): + args.data_efficiency_curriculum_learning_numel = ( + torch.numel(tokens) + ) + + if args.mos or args.kd: + # The forward func can return either the loss or the logits, + # depending on whether passing in the labels or not. 
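# A sketch of the full paired DPO objective that the "[ToDo]" block in
# calculate_dpo_loss above points at. It assumes per-sequence log-probs of the
# preferred (chosen) and unpreferred (rejected) responses from both the policy
# and the frozen reference model; every argument name here is a placeholder.
import torch.nn.functional as F

def paired_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    chosen_ratio = policy_chosen_logps - ref_chosen_logps        # log(pi/ref) on preferred
    rejected_ratio = policy_rejected_logps - ref_rejected_logps  # log(pi/ref) on unpreferred
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()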
+ stu_output, other_losses = model(tokens, position_ids, attention_mask) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + labels = labels[:, :args.curriculum_seqlen].contiguous() + output_tensor = tensor_parallel.vocab_parallel_cross_entropy( + stu_output.contiguous().float(), + labels + ) + else: + output_tensor, other_losses = model( + tokens[0], + position_ids[0], + attention_mask[0], + labels=labels[0] + ) + output_tensor_u, other_losses_u = model( + tokens[1], + position_ids[1], + attention_mask[1], + labels=labels[1] + ) + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + moe_losses = [] + for moe_loss in other_losses: + if moe_loss is not None: + moe_losses.append(moe_loss) + moe_loss = sum(moe_losses) * args.moe_loss_coeff + + mos_loss = 0 + if args.mos or args.kd: + assert model.training + if args.teacher_forward and args.teacher_model is not None: + mos_loss = calculate_mos_loss( + args, + stu_output, + args.teacher_model[0], + tokens, + position_ids, + attention_mask + ) + + dpo_loss = 0 + if args.teacher_model is not None: + dpo_loss = calculate_dpo_loss( + args, + stu_output, + args.teacher_model[0], + tokens, + position_ids, + attention_mask + ) + + # Output_tensor stores the standard loss, + # loss_func calculates the total loss. + return output_tensor, partial(dpo_loss_func, loss_mask, dpo_loss) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for GPT ...') + files = [] + if args.data_file_list is not None: + with open(args.data_file_list, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files.append(float(w)) + files.append(fname) + elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): + path = args.data_path[0] + "/" + for f in os.listdir(path): + if (os.path.isfile(path + f) and f.find(".bin") != -1): + files.append(1) + files.append(path + f.split(".bin")[0]) + else: + files = args.data_path + print_rank_0(f"file list {files}") + + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=files, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating GPT datasets ...") + + # #------------ Preferred -------------- + # train_ds_p, valid_ds_p, test_ds_p = build_train_valid_test_datasets( + # data_prefix=files, + # data_impl=args.data_impl, + # splits_string=args.split, + # train_valid_test_num_samples=train_val_test_num_samples, + # seq_length=args.seq_length, + # seed=args.seed, + # skip_warmup=True, + # # skip_warmup=(not args.mmap_warmup), + # train_data_prefix=args.train_data_path, + # valid_data_prefix=args.valid_data_path, + # test_data_prefix=args.test_data_path, + # data_cache_path=args.data_cache_path) + # print_rank_0("> finished creating GPT datasets ...") + + # #------------ Unpreferred -------------- + # train_ds_u, valid_ds_u, test_ds_u = build_train_valid_test_datasets( + # 
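# A sketch of the "tuple of [u, p]" dataset wrapper described at the end of
# this function: pair a preferred and an unpreferred dataset so that each
# __getitem__ yields both halves of one preference pair. The class name is a
# placeholder, not an existing Megatron type.
from torch.utils.data import Dataset

class PairedPreferenceDataset(Dataset):
    def __init__(self, preferred_ds, unpreferred_ds):
        assert len(preferred_ds) == len(unpreferred_ds)
        self.preferred_ds = preferred_ds
        self.unpreferred_ds = unpreferred_ds

    def __len__(self):
        return len(self.preferred_ds)

    def __getitem__(self, idx):
        return self.preferred_ds[idx], self.unpreferred_ds[idx]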
data_prefix=files, + # data_impl=args.data_impl, + # splits_string=args.split, + # train_valid_test_num_samples=train_val_test_num_samples, + # seq_length=args.seq_length, + # seed=args.seed, + # skip_warmup=True, + # # skip_warmup=(not args.mmap_warmup), + # train_data_prefix=args.train_data_path, + # valid_data_prefix=args.valid_data_path, + # test_data_prefix=args.test_data_path, + # data_cache_path=args.data_cache_path) + # print_rank_0("> finished creating GPT datasets ...") + + # Create a new Dataiterator with __getitem__() overwritten to give a tuple of [u,p] in train_ds + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen( + f'type {cmd}', + stdout=subprocess.PIPE, + shell=True + ) + return result.wait() == 0 + + +def git_ds_info(): + if RANK != 0: + return + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print( + f'**** Git info for Megatron: ' + f'git_hash={git_hash} git_branch={git_branch} ****' + ) + + +def main(): + # if RANK == 0: + # setup_wandb() + if os.getenv('TORCH_PROFILER_ENABLED') == '1': + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + model = pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + data_post_process=data_post_process + ) + + prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") + else: + model = pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + data_post_process=data_post_process + ) + return model + + +if __name__ == "__main__": + # git_ds_info() + # pretrain(train_valid_test_datasets_provider, + # model_provider, + # ModelType.encoder_or_decoder, + # forward_step, + # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # data_post_process=data_post_process) + import sys + import deepspeed.comm as dist + model = main() + dist.log_summary() + if wandb.run is not None: + print(f"wandb.run.name: {wandb.run.name}") + print(f"wandb.run.url: {wandb.run.url}") + wandb.finish() + sys.exit() + diff --git a/dpo_training.py b/dpo_training.py new file mode 100644 index 0000000000..fbd5658648 --- /dev/null +++ b/dpo_training.py @@ -0,0 +1,1602 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain utilities.""" + +from datetime import datetime +import math +import sys +import time +import json +# The earliest we can measure the start time. 
+_TRAIN_START_TIME = time.time() +import torch +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + +from megatron import get_args +from megatron import get_signal_handler +from megatron import get_timers +from megatron import get_tensorboard_writer +from megatron import get_current_global_batch_size +from megatron import get_num_microbatches +from megatron import is_last_rank +from megatron import update_num_microbatches +from megatron.core import mpu, tensor_parallel +from megatron import print_rank_0, is_rank_0 +from megatron import print_rank_last +from megatron.checkpointing import load_checkpoint +from megatron.checkpointing import save_checkpoint +from megatron.model import Float16Module +from megatron.model import GPTModel +from megatron.core.enums import ModelType +from megatron.optimizer import get_megatron_optimizer +from megatron.initialize import initialize_megatron +from megatron.initialize import write_args_to_tensorboard +from megatron.initialize import set_jit_fusion_options +from megatron.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.utils import check_adlr_autoresume_termination +from megatron.utils import unwrap_model +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.utils import calc_params_l2_norm +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator, update_rotary_pos_emb +from megatron.model.vision.knn_monitor import compute_feature_bank +from megatron.arguments import core_transformer_config_from_args + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.compression.compress import init_compression, redundancy_clean +from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd +from megatron.model.transformer import ParallelTransformerLayer + +from deepspeed import comm as dist + +try: + import wandb +except (ImportError, ModuleNotFoundError): + wandb = None + + +def print_datetime(string): + """Note that this call will sync across all ranks.""" + torch.distributed.barrier() + time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print_rank_0('[' + string + '] datetime: {} '.format(time_str)) + +''' +Since v0.9.0, deepspeed.initialize() has forbidden simultaneous setting of args.deepspeed_config (Path) and ds_config dict. +So, we use ds_config dict which is the more flexible option. +''' +def _create_ds_config_dict(): + args = get_args() + if isinstance(args.deepspeed_config, dict) : + ds_config_dict = args.deepspeed_config + else: + with open(args.deepspeed_config, 'r', encoding='utf-8') as config_file: + ds_config_dict = json.load(config_file) + + if args.universal_checkpoint: + ds_config_dict["checkpoint"] = {"load_universal": True} + + # Clear config path + args.deepspeed_config = None + + return ds_config_dict + + +def pretrain(train_valid_test_dataset_provider, + model_provider, + model_type, + forward_step_func, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults={}, + data_post_process=None, + external_args={}): + """Main training program. + + This function will run the followings in the order provided: + 1) initialize Megatron. + 2) setup model, optimizer and lr schedule using the model_provider. + 3) call train_val_test_data_provider to get train/val/test datasets. 
+ 4) train the modle using the forward_step_func. + + Arguments: + train_valid_test_dataset_provider: a function that takes the size of + train/valid/test dataset and returns `train, valid, test` datasets. + model_provider: a function that returns a vanilla version of the + model. By vanilla we mean a simple model on cpu with no fp16 or ddp. + model_type: an enum that specifies the type of model being trained. + forward_step_func: a function that takes a `data iterator` and `model`, + and returns a `loss` scalar with a dictionary with key:values being + the info we would like to monitor during training, for example + `lm-loss: value`. We also require that this function add + `batch generator` to the timers class. + process_non_loss_data_func: a function to post process outputs of the + network. It can be used for dumping output tensors (e.g images) to + tensorboard. It takes `collected data`(list of tensors), + `current iteration index` and `tensorboard writer` as arguments. + extra_args_provider: a function that takes a parser and adds arguments + to it. It is used for programs to add their own arguments. + args_defaults: a dictionary from argument-name to argument-value. It + to set already parse arguments. + """ + + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron(extra_args_provider=extra_args_provider, + args_defaults=args_defaults, external_args=external_args) + # Set pytorch JIT layer fusion options and warmup JIT functions. + if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = get_accelerator().DoubleTensor([_TRAIN_START_TIME]) + torch.distributed.all_reduce(start_time_tensor, + op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( + time.time() - _TRAIN_START_TIME)) + print_datetime('after megatron is initialized') + + args = get_args() + timers = get_timers() + + if args.deepspeed: + args.deepspeed_config_dict = _create_ds_config_dict() + if "curriculum_learning" in args.deepspeed_config_dict and \ + "enabled" in args.deepspeed_config_dict["curriculum_learning"]: + args.curriculum_learning_legacy = args.deepspeed_config_dict[ \ + "curriculum_learning"]["enabled"] + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + from deepspeed.runtime.data_pipeline.curriculum_scheduler \ + import CurriculumScheduler + args.curriculum_scheduler = CurriculumScheduler( \ + args.deepspeed_config_dict["curriculum_learning"]) + if "compression_training" in args.deepspeed_config_dict: + args.compression_training = True + + # Model, optimizer, and learning rate. + timers('model-and-optimizer-setup', log_level=0).start(barrier=True) + model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + model_provider, model_type, teacher=False, data_post_process=data_post_process, + build_train_valid_test_datasets_provider=train_valid_test_dataset_provider) + timers('model-and-optimizer-setup').stop() + print_datetime('after model, optimizer, and learning rate ' + 'scheduler are built') + + # Data stuff. 
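# A compact restatement of the DeepSpeed config handling above: newer
# DeepSpeed releases reject passing both a config path and a config dict, so
# the JSON is read once and the dict is handed to deepspeed.initialize(). The
# file name here is illustrative.
import json

def load_ds_config(path: str = "ds_config.json", universal_checkpoint: bool = False) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    if universal_checkpoint:
        cfg["checkpoint"] = {"load_universal": True}
    return cfg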
+ timers('train/valid/test-data-iterators-setup', log_level=0).start( + barrier=True) + if args.virtual_pipeline_model_parallel_size is not None: + all_data_iterators = [ + build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + for _ in range(len(model)) + ] + train_data_iterator = [data_iterators[0] + for data_iterators in all_data_iterators] + valid_data_iterator = [data_iterators[1] + for data_iterators in all_data_iterators] + test_data_iterator = [data_iterators[2] + for data_iterators in all_data_iterators] + else: + train_data_iterator, valid_data_iterator, test_data_iterator \ + = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + if args.data_efficiency_curriculum_learning: + if args.deepspeed_dataloader is not None: + # We use args to pass the deepspeed_dataloader because adding + # output to setup_model_and_optimizer will break the API for other + # cases. We clear args.deepspeed_dataloader after updating + # train_data_iterator because args will be saved in checkpoint and + # attempting to save the whole deepspeed_dataloader will lead to + # "AttributeError: Can't pickle local object...". + train_data_iterator = iter(args.deepspeed_dataloader) + args.deepspeed_dataloader = None + else: + train_data_iterator = None + timers('train/valid/test-data-iterators-setup').stop() + print_datetime('after dataloaders are built') + + # args.teacher_model is used as global variable to pass the teacher model + # for knowledge distillation. Users do not need to set it in the command + # line to use kd, but users do need to provide teacher model configurations + # like args.num_layers_teacher as described in setup_teacher_model() + args.teacher_model = None + if args.mos or args.kd: # Set up teacher model + args.teacher_model = setup_teacher_model(args, model_provider) + + # ToDo + args.teacher_model, _, _ = load_model_weights_only(model_provider) + + # Print setup timing. 
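# A sketch of what the frozen reference model loaded above is used for in the
# DPO loss: per-sequence log-probabilities of the sampled labels. Names and
# shapes are illustrative: logits [batch, seq, vocab], labels and loss_mask
# [batch, seq].
import torch
import torch.nn.functional as F

def sequence_logprobs(logits, labels, loss_mask):
    logp = F.log_softmax(logits.float(), dim=-1)
    token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (token_logp * loss_mask).sum(dim=1)   # one scalar per sequence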
+ print_rank_0('done with setup ...') + timers.log(['model-and-optimizer-setup', + 'train/valid/test-data-iterators-setup'], barrier=True) + + if not args.skip_train: + print_rank_0('training ...') + + if args.dataloader_type == 'cyclic' and args.retro_add_retriever: + args.train_iters = args.retro_cyclic_train_iters + print_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration = train(forward_step_func, + model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func) + + print_datetime('after training is done') + # Clean the model + if args.compression_training: + model = [redundancy_clean(model[0], args.deepspeed_config_dict, mpu)] + + if args.save and iteration != 0: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + else: + print_rank_0('skipping training (--skip-train is on) ...') + + iteration = args.iteration + + config = core_transformer_config_from_args(args) + if args.do_valid: + prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set' + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=True, write_to_tensorboard=not args.skip_train) + + if args.do_test: + prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from test set' + evaluate_and_print_results(prefix, forward_step_func, + test_data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=True, write_to_tensorboard=not args.skip_train, test=True) + return model + + +def update_train_iters(args): + + # For iteration-based training, we don't need to do anything + if args.train_iters: + return + + # Constant batch size with sample-based training. + if args.rampup_batch_size is None: + args.train_iters = args.train_samples // args.global_batch_size + + else: + # Sample based training with rampup batch size. + iterations = 0 + consumed_samples = 0 + # Rampup phase. + while consumed_samples <= int(args.rampup_batch_size[2]): + update_num_microbatches(consumed_samples, consistency_check=False) + consumed_samples += get_current_global_batch_size() + iterations += 1 + # Reset + update_num_microbatches(0, consistency_check=False) + # Constant phase + # Note that we throw away any partial last batch. 
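# Worked example of the constant-phase count computed just below (illustrative
# numbers): with train_samples=10_000, consumed_samples=1_000 after rampup and
# global_batch_size=32, this adds (10_000 - 1_000) // 32 = 281 iterations; the
# 8 leftover samples are discarded as a partial batch.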
+ iterations += (args.train_samples - consumed_samples) // \ + args.global_batch_size + args.train_iters = iterations + + print_rank_0('setting training iterations to {}'.format(args.train_iters)) + + +def setup_teacher_model(args, model_provider): + + print_rank_0('***>>>>> Student model checkpoint iteration:{}'.format(args.iteration)) + iteration_stuent = args.iteration + num_layers_student = args.num_layers + num_experts_student = args.num_experts + hidden_size_student = args.hidden_size + num_attention_heads_student = args.num_attention_heads + load_student = args.load + + print_rank_0('***>>>>> Setting up the teacher model') + + args.num_layers = args.num_layers_teacher + args.num_experts = args.num_experts_teacher + args.hidden_size = args.hidden_size_teacher + args.num_attention_heads = args.num_attention_heads_teacher + args.load = args.load_teacher + teacher_model, _, _ = load_model_weights_only(model_provider) + print_rank_0('***>>>>> Teacher model:{}'.format(teacher_model)) + + args.num_layers = num_layers_student + args.num_experts = num_experts_student + args.hidden_size = hidden_size_student + args.num_attention_heads = num_attention_heads_student + args.load = load_student + args.iteration = iteration_stuent + + return teacher_model + +def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): + """Build the model.""" + args = get_args() + args.model_type = model_type + + # Build model. + if mpu.get_pipeline_model_parallel_world_size() > 1 and \ + args.virtual_pipeline_model_parallel_size is not None: + assert model_type != ModelType.encoder_and_decoder, \ + "Interleaved schedule not supported for model with both encoder and decoder" + model = [] + for i in range(args.virtual_pipeline_model_parallel_size): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. 
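# A sketch of what the two flags assigned below mean for a plain
# (non-interleaved) pipeline: only the first stage owns the embedding
# (pre_process) and only the last stage owns the LM head and loss
# (post_process). Ranks here are illustrative.
def stage_flags(pp_rank: int, pp_world_size: int):
    return pp_rank == 0, pp_rank == pp_world_size - 1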
+ pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + this_model = model_provider_func( + pre_process=pre_process, + post_process=post_process + ) + this_model.model_type = model_type + model.append(this_model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + add_encoder = True + add_decoder = True + if model_type == ModelType.encoder_and_decoder: + if mpu.get_pipeline_model_parallel_world_size() > 1: + assert args.pipeline_model_parallel_split_rank is not None, \ + "Split rank needs to be specified for model with both encoder and decoder" + rank = mpu.get_pipeline_model_parallel_rank() + split_rank = args.pipeline_model_parallel_split_rank + world_size = mpu.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == split_rank + post_process = (rank == (split_rank - 1)) or ( + rank == (world_size - 1)) + add_encoder = mpu.is_pipeline_stage_before_split() + add_decoder = mpu.is_pipeline_stage_after_split() + model = model_provider_func( + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder) + else: + model = model_provider_func( + pre_process=pre_process, + post_process=post_process + ) + model.model_type = model_type + + + if not isinstance(model, list): + model = [model] + + # Disallow training and inference with Transformer Engine + # for non-GPT models + args.allow_transformer_engine = all([type(m) == GPTModel for m in model]) + assert args.allow_transformer_engine or args.transformer_impl == 'local', \ + 'Transformer Engine is only approved for GPT models' + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on (tensor, pipeline) ' + 'model parallel rank ({}, {}): {}'.format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()]) + for model_module in model])), flush=True) + + if args.deepspeed: + return model + + # GPU allocation. + for model_module in model: + model_module.to(get_accelerator().current_device_name()) + + + # Fp16 conversion. + if args.fp16 or args.bf16: + model = [Float16Module(model_module, args) for model_module in model] + + if wrap_with_ddp: + if args.DDP_impl == 'torch': + i = get_accelerator().current_device() + model = [torchDDP(model_module, device_ids=[i], output_device=i, + process_group=mpu.get_data_parallel_group()) + for model_module in model] + + elif args.DDP_impl == 'local': + model = [LocalDDP(model_module, + args.accumulate_allreduce_grads_in_fp32, + args.use_contiguous_buffers_in_local_ddp) + for model_module in model] + # broad cast params from data parallel src rank to other data parallel ranks + if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + else: + raise NotImplementedError('Unknown DDP implementation specified: ' + '{}. 
Exiting.'.format(args.DDP_impl)) + + return model + + +def get_optimizer_param_scheduler(optimizer): + """Build the learning rate scheduler.""" + args = get_args() + + # Iteration-based training. + if args.train_iters: + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_decay_steps = args.lr_decay_iters * args.global_batch_size + wd_incr_steps = args.train_iters * args.global_batch_size + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + # Sample-based training. + elif args.train_samples: + # We need to set training iters for later use. Technically + # we need to adjust the training samples too (due to last + # batch being incomplete) but we leave it as is for now. + update_train_iters(args) + if args.lr_decay_samples is None: + args.lr_decay_samples = args.train_samples + lr_decay_steps = args.lr_decay_samples + wd_incr_steps = args.train_samples + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_samples + else: + raise Exception( + 'either train-iters or train-samples should be provided.') + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + max_lr=args.lr, + min_lr=args.min_lr, + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style=args.lr_decay_style, + start_wd=args.start_weight_decay, + end_wd=args.end_weight_decay, + wd_incr_steps=wd_incr_steps, + wd_incr_style=args.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=args.override_opt_param_scheduler) + + return opt_param_scheduler + +def load_model_weights_only(model_provider_func): + """Setup model and optimizer.""" + args = get_args() + print_rank_0('***>>>>> Args:{}'.format(args)) + + model = get_model(model_provider_func) + + optimizer = None + lr_scheduler = None + + if args.deepspeed: + # When loading just the model weights, ZeRO can be disabled. + if 'zero_optimization' in args.deepspeed_config_dict: + del args.deepspeed_config_dict['zero_optimization'] + + model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model[0], + config=args.deepspeed_config_dict + ) + + assert not isinstance(model, deepspeed.PipelineEngine), \ + 'Weight loading only mode is not supported in pipeline parallelism yet.' 
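# A sketch of the weight-only loading trick used above: strip ZeRO from the
# DeepSpeed config before wrapping the frozen reference model, since it needs
# no optimizer state. This variant copies the dict first, whereas the code
# above deletes the key from args.deepspeed_config_dict in place.
import copy

def weight_only_ds_config(ds_config: dict) -> dict:
    cfg = copy.deepcopy(ds_config)
    cfg.pop("zero_optimization", None)
    return cfg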
+ + model = [model] + + print_datetime('before load checkpoint') + if args.load is not None: + iteration = load_checkpoint(model, optimizer, lr_scheduler, strict=True, load_only_weights=True) + + print_datetime('after load checkpoint weights') + + return model, optimizer, lr_scheduler + + +def setup_model_and_optimizer(model_provider_func, + model_type, + no_wd_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, + teacher=False, + data_post_process=None, + build_train_valid_test_datasets_provider=None): + """Setup model and optimizer.""" + args = get_args() + + model = get_model(model_provider_func, model_type) + + # initialize the compression here + student_global_steps = 0 + if args.kd or args.mos: + model, _, _, _ = deepspeed.initialize( + model=model[0], + args=args, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + model = [model] + if args.load is not None: + args.iteration = load_checkpoint(model, None, None, strict=False) + else: + args.iteration = 0 + student_global_steps = model[0].global_steps + print_rank_0('***>>>>> Student model, global step:{}'.format(student_global_steps)) + + if args.compression_training: + model, _, _, _ = deepspeed.initialize( + model=model[0], + args=args, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + model = [model] + model = [init_compression(model[0].module, args.deepspeed_config_dict, mpu)] + + unwrapped_model = unwrap_model(model, + (torchDDP, LocalDDP, Float16Module)) + + if args.inference: + optimizer = None + opt_param_scheduler = None + else: + if teacher: + optimizer = None + else: + optimizer = get_megatron_optimizer(model, no_wd_decay_cond, + scale_lr_cond, lr_mult) + # opt_param_scheduler is the old lr_scheduler plus weight decay scheduling + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + + if args.deepspeed: + print_rank_0("DeepSpeed is enabled.") + pp = mpu.get_pipeline_model_parallel_world_size() + if args.data_efficiency_curriculum_learning and build_train_valid_test_datasets_provider is not None: + train_ds = None + # Only need to build dataset on tp rank 0 since Megatron has the + # broadcast_data() function that broadcast data from tp rank 0. + if mpu.get_tensor_model_parallel_rank() == 0: + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + update_train_iters(args) + else: + train_samples = args.train_iters * args.global_batch_size + # eval_iters and test_iters here are not actually used, only for + # satisfying the input of build_train_valid_test_datasets_provider. + # We only need to build the training data here. And we follow + # baseline's logic to build eval/test dataset later in + # build_train_valid_test_data_iterators. + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + # Build the datasets. 
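# Worked example of the sample counts assembled above (illustrative numbers):
# with train_iters=1000, global_batch_size=32, eval_interval=100 and
# eval_iters=10:
#   train samples = 1000 * 32                   = 32_000
#   valid samples = (1000 // 100 + 1) * 10 * 32 = 3_520
#   test samples  = 10 * 32                     = 320
# Only the training split is consumed here; eval and test iterators are built
# later by build_train_valid_test_data_iterators.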
+ train_ds, _, _ = build_train_valid_test_datasets_provider( + train_val_test_num_samples) + model, optimizer, args.deepspeed_dataloader, opt_param_scheduler = deepspeed.initialize( + model=model[0], + optimizer=optimizer, + args=args, + lr_scheduler=opt_param_scheduler, + training_data=train_ds, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + model.set_data_post_process_func(data_post_process) + else: + model, optimizer, _, opt_param_scheduler = deepspeed.initialize( + model=model[0], + optimizer=optimizer, + args=args, + lr_scheduler=opt_param_scheduler, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + if isinstance(model, deepspeed.PipelineEngine): + # hack to get batch_fn from pretrain_gpt.py + model.set_batch_fn(model.module._megatron_batch_fn) + + assert model.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank() + assert model.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank() + assert model.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank() + model = [model] + + # Compression has its own checkpoint loading path (e.g, loading both teacher and student models). So if compression is enabled, we skip the following checkpoint loading. + no_post_init_checkpoint_loading = args.kd or args.mos + if not no_post_init_checkpoint_loading: + if args.load is not None: + timers = get_timers() + timers('load-checkpoint', log_level=0).start(barrier=True) + args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) + timers('load-checkpoint').stop(barrier=True) + timers.log(['load-checkpoint']) + else: + args.iteration = 0 + else: + model[0].global_steps = student_global_steps + + # We only support local DDP with multiple micro-batches. + if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: + assert args.DDP_impl == 'local' + + # get model without FP16 and/or TorchDDP wrappers + if args.iteration == 0 and len(unwrapped_model) == 1 \ + and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): + print_rank_0("Initializing ICT from pretrained BERT model") + unwrapped_model[0].init_state_dict_from_bert() + if args.fp16: + optimizer.reload_model_params() + + # random-LTD requires converting transformer layers + if args.random_ltd: + model[0] = convert_to_random_ltd(model[0], ParallelTransformerLayer) + + return model, optimizer, opt_param_scheduler + + + +def train_step(forward_step_func, data_iterator, + model, optimizer, opt_param_scheduler, config): + """Single training step.""" + args = get_args() + timers = get_timers() + + if args.deepspeed and args.ds_pipeline_enabled: + skipped_iter = 0 + num_zeros_in_grad = 0 + assert isinstance(model[0], deepspeed.PipelineEngine) + loss = model[0].train_batch(data_iter=data_iterator) + grad_norm = model[0].get_global_grad_norm() + return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad + + # Set grad to zero. + if not args.deepspeed: + if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp: + for partition in model: + partition.zero_grad_buffer() + optimizer.zero_grad() + + # Forward pass. + timers('forward-backward', log_level=1).start( + barrier=args.barrier_with_L1_time) + forward_backward_func = get_forward_backward_func() + if args.mos or args.kd: + # args.teacher_forward is used as global variable to enable kd loss + # calculation in forward pass. Users do not need to set it in the + # command line to use kd. 
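# Worked example of the consumed-sample increment used further down in this
# function (illustrative numbers): with num_microbatches=8, micro_batch_size=4
# and data_parallel_size=16, each train_step advances the schedulers by
# 8 * 4 * 16 = 512 samples, i.e. one global batch.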
+ args.teacher_forward = True + + # set timers to None if none of the timers in fwd_bwd are active, just to save the checks + if args.timing_log_level < 2: + config.timers = None + + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False) + + # reset timers if necessary + if config.timers is None: + config.timers = timers + timers('forward-backward').stop() + if args.mos or args.kd: + args.teacher_forward = False + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Reduce gradients. + if not args.deepspeed: + optimizer.reduce_model_grads(args, timers) + + # Vision gradients. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0], + (torchDDP, LocalDDP, Float16Module)) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) + if args.deepspeed: + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + model[0].step(lr_kwargs={'increment': increment}) + update_successful = model[0].was_step_applied() + else: + update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) + timers('optimizer').stop() + + # Gather params. + if not args.deepspeed and update_successful: + optimizer.gather_model_params(args, timers) + + # Vision momentum. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0], + (torchDDP, LocalDDP, Float16Module)) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if args.deepspeed: + skipped_iter = 0 + grad_norm = None + num_zeros_in_grad = None + + loss_reduced = {} + for key in losses_reduced[0]: + losses_reduced_for_key = [x[key] for x in losses_reduced] + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + else: + if update_successful: + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0]: + losses_reduced_for_key = [x[key] for x in losses_reduced] + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad + + +def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + loss_scale, report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, + model=None, optimizer=None): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' + skipped_iters_key = 'skipped iterations' + nan_iters_key = 'nan iterations' + # Advanced iterations. 
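# A sketch of the reduction train_step applies before loss_dict reaches this
# function: every key is averaged over the list of per-microbatch loss dicts
# returned by the forward-backward schedule.
def average_microbatch_losses(losses_reduced):
    keys = losses_reduced[0].keys()
    return {k: sum(d[k] for d in losses_reduced) / len(losses_reduced) for k in keys}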
+ if not skipped_iter: + total_loss_dict[advanced_iters_key] = total_loss_dict.get( + advanced_iters_key, 0) + 1 + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = total_loss_dict.get( + skipped_iters_key, 0) + skipped_iter + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = total_loss_dict.get( + key, get_accelerator().FloatTensor([0.0])) + loss_dict[key] + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or \ + value == -float('inf') or \ + value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get( + nan_iters_key, 0) + int(got_nan) + + # Logging. + timers_to_log = [ + 'forward-backward', + 'forward-compute', + 'backward-compute', + 'batch-generator', + 'forward-recv', + 'forward-send', + 'backward-recv', + 'backward-send', + 'forward-send-forward-recv', + 'forward-send-backward-recv', + 'backward-send-forward-recv', + 'backward-send-backward-recv', + 'forward-backward-send-forward-backward-recv', + 'layernorm-grads-all-reduce', + 'embedding-grads-all-reduce', + 'grads-all-reduce', + 'grads-reduce-scatter', + 'params-all-gather', + 'optimizer-copy-to-main-grad', + 'optimizer-unscale-and-check-inf', + 'optimizer-clip-main-grad', + 'optimizer-count-zeros', + 'optimizer-inner-step', + 'optimizer-copy-main-to-model-params', + 'optimizer'] + + # Calculate batch size. + batch_size = args.micro_batch_size * args.data_parallel_size * \ + get_num_microbatches() + + total_iterations = total_loss_dict[advanced_iters_key] + \ + total_loss_dict[skipped_iters_key] + + # Tensorboard values. + # Timer requires all the ranks to call. 
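# The NaN/Inf test used above ("value != value" catches NaN), written as a
# named helper for clarity; torch.isfinite gives the same answer in one call.
import math

def is_non_finite(value: float) -> bool:
    return math.isinf(value) or math.isnan(value)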
+ if args.log_timers_to_tensorboard and \ + (iteration % args.tensorboard_log_interval == 0): + timers.write(timers_to_log, writer, iteration, + normalizer=total_iterations) + if writer and (iteration % args.tensorboard_log_interval == 0): + writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples) + writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration) + writer.add_scalar('steps-vs-tokens/y=steps,x=tokens', iteration, args.consumed_train_tokens) + writer.add_scalar('steps-vs-tokens/y=tokens,x=steps', args.consumed_train_tokens, iteration) + if args.log_learning_rate_to_tensorboard: + writer.add_scalar('learning-rate/learning-rate', learning_rate, iteration) + writer.add_scalar('learning-rate/learning-rate vs samples', learning_rate, + args.consumed_train_samples) + writer.add_scalar('learning-rate/learning-rate vs tokens', learning_rate, + args.consumed_train_tokens) + if args.log_batch_size_to_tensorboard: + writer.add_scalar('batch-size/batch-size', batch_size, iteration) + writer.add_scalar('batch-size/batch-size vs samples', batch_size, + args.consumed_train_samples) + writer.add_scalar('batch-size/batch-size vs tokens', batch_size, + args.consumed_train_tokens) + for key in loss_dict: + writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration) + writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key], + args.consumed_train_samples) + writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key], + args.consumed_train_tokens) + if args.fp16 and args.log_loss_scale_to_tensorboard: + writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration) + writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale, + args.consumed_train_samples) + writer.add_scalar('loss-scale/loss-scale vs tokens', loss_scale, + args.consumed_train_tokens) + if args.log_world_size_to_tensorboard: + writer.add_scalar('world-size/world-size', args.world_size, iteration) + writer.add_scalar('world-size/world-size vs samples', args.world_size, + args.consumed_train_samples) + writer.add_scalar('world-size/world-size vs tokens', args.world_size, + args.consumed_train_tokens) + if grad_norm is not None: + writer.add_scalar('grad-norm/grad-norm', grad_norm, iteration) + writer.add_scalar('grad-norm/grad-norm vs samples', grad_norm, + args.consumed_train_samples) + writer.add_scalar('grad-norm/grad-norm vs tokens', grad_norm, + args.consumed_train_tokens) + if num_zeros_in_grad is not None: + writer.add_scalar('num-zeros/num-zeros', num_zeros_in_grad, iteration) + writer.add_scalar('num-zeros/num-zeros vs samples', num_zeros_in_grad, + args.consumed_train_samples) + writer.add_scalar('num-zeros/num-zeros vs tokens', num_zeros_in_grad, + args.consumed_train_tokens) + if params_norm is not None: + writer.add_scalar('params-norm/params-norm', params_norm, iteration) + writer.add_scalar('params-norm/params-norm vs samples', params_norm, + args.consumed_train_samples) + writer.add_scalar('params-norm/params-norm vs tokens', params_norm, + args.consumed_train_tokens) + if hasattr(args, 'actual_seq_length'): + writer.add_scalar('seqlen/actual_seq_length', args.actual_seq_length, + iteration) + writer.add_scalar('seqlen/actual_seq_length vs samples', args.actual_seq_length, + args.consumed_train_samples) + writer.add_scalar('seqlen/actual_seq_length vs tokens', args.actual_seq_length, + args.consumed_train_tokens) + if args.curriculum_learning_legacy or 
args.data_efficiency_curriculum_learning: + writer.add_scalar('seqlen/curriculum_seqlen', args.curriculum_seqlen, + iteration) + writer.add_scalar('seqlen/curriculum_seqlen vs samples', args.curriculum_seqlen, + args.consumed_train_samples) + writer.add_scalar('seqlen/curriculum_seqlen vs tokens', args.curriculum_seqlen, + args.consumed_train_tokens) + if args.random_ltd: + writer.add_scalar('seqlen/random_ltd_reserved_length', args.random_ltd_reserved_length, + iteration) + writer.add_scalar('seqlen/random_ltd_reserved_length vs samples', args.random_ltd_reserved_length, + args.consumed_train_samples) + writer.add_scalar('seqlen/random_ltd_reserved_length vs tokens', args.random_ltd_reserved_length, + args.consumed_train_tokens) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + "mem-reserved-bytes", + mem_stats["reserved_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-bytes", + mem_stats["allocated_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-count", + mem_stats["allocation.all.current"], + iteration, + ) + + if iteration % args.tensorboard_log_interval == 0: + # This logging write various optimizer states to tensorboard. This + # feature may consume extra GPU memory thus is set at false by default. + if args.log_optimizer_states_to_tensorboard and optimizer is not None: + opt_stats = [0.0] * 8 + opt_stats_2 = [0.0] * 4 + for _, group in enumerate(optimizer.param_groups): + for _, param in enumerate(group['params']): + opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item())**2 + opt_stats[1] += (torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt()).item())**2 + opt_stats[2] += (torch.norm(optimizer.state[param]['exp_avg']).item())**2 + opt_stats[3] += (torch.norm(param).item())**2 + opt_stats[4] += torch.norm(optimizer.state[param]['exp_avg_sq'],p=1).item() + opt_stats[5] += torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt(),p=1).item() + opt_stats[6] += torch.norm(optimizer.state[param]['exp_avg'],p=1).item() + opt_stats[7] += torch.norm(param,p=1).item() + opt_stats_2[0] = max(opt_stats_2[0], abs(optimizer.state[param]['exp_avg_sq'].max().item()), abs(optimizer.state[param]['exp_avg_sq'].min().item())) + opt_stats_2[1] = max(opt_stats_2[1], optimizer.state[param]['exp_avg_sq'].sqrt().abs_().max().item()) + opt_stats_2[2] = max(opt_stats_2[2], abs(optimizer.state[param]['exp_avg'].max().item()), abs(optimizer.state[param]['exp_avg'].min().item())) + opt_stats_2[3] = max(opt_stats_2[3], abs(param.max().item()), abs(param.min().item())) + # print('step {} rank {} before sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) + if args.zero_stage > 0: + # ZeRO partiions optimizer states + opt_stats = get_accelerator().FloatTensor(opt_stats) + torch.distributed.all_reduce(opt_stats, group=mpu.get_sequence_data_parallel_group()) + opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX, + group=mpu.get_sequence_data_parallel_group()) + + if args.tensor_model_parallel_size > 1: + opt_stats = get_accelerator().FloatTensor(opt_stats) + torch.distributed.all_reduce(opt_stats, group=mpu.get_tensor_model_parallel_group()) + opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX, + group=mpu.get_tensor_model_parallel_group()) + + if args.pipeline_model_parallel_size > 1: + 
opt_stats = get_accelerator().FloatTensor(opt_stats) + torch.distributed.all_reduce(opt_stats, group=mpu.get_pipeline_model_parallel_group()) + opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX, + group=mpu.get_pipeline_model_parallel_group()) + + # print('step {} rank {} after sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) + if writer and is_last_rank(): + writer.add_scalar('optimizer/variance_l2 vs tokens', opt_stats[0]**0.5, args.consumed_train_tokens) + writer.add_scalar('optimizer/variance_sqrt_l2 vs tokens', opt_stats[1]**0.5, args.consumed_train_tokens) + writer.add_scalar('optimizer/momentum_l2 vs tokens', opt_stats[2]**0.5, args.consumed_train_tokens) + writer.add_scalar('optimizer/weight_l2 vs tokens', opt_stats[3]**0.5, args.consumed_train_tokens) + writer.add_scalar('optimizer/variance_l1 vs tokens', opt_stats[4], args.consumed_train_tokens) + writer.add_scalar('optimizer/variance_sqrt_l1 vs tokens', opt_stats[5], args.consumed_train_tokens) + writer.add_scalar('optimizer/momentum_l1 vs tokens', opt_stats[6], args.consumed_train_tokens) + writer.add_scalar('optimizer/weight_l1 vs tokens', opt_stats[7], args.consumed_train_tokens) + writer.add_scalar('optimizer/variance_abs_max vs tokens', opt_stats_2[0], args.consumed_train_tokens) + writer.add_scalar('optimizer/variance_sqrt_abs_max vs tokens', opt_stats_2[1], args.consumed_train_tokens) + writer.add_scalar('optimizer/momentum_abs_max vs tokens', opt_stats_2[2], args.consumed_train_tokens) + writer.add_scalar('optimizer/weight_abs_max vs tokens', opt_stats_2[3], args.consumed_train_tokens) + + writer.add_scalar('optimizer/variance_l2', opt_stats[0]**0.5, iteration) + writer.add_scalar('optimizer/variance_sqrt_l2', opt_stats[1]**0.5, iteration) + writer.add_scalar('optimizer/momentum_l2', opt_stats[2]**0.5, iteration) + writer.add_scalar('optimizer/weight_l2', opt_stats[3]**0.5, iteration) + writer.add_scalar('optimizer/variance_l1', opt_stats[4], iteration) + writer.add_scalar('optimizer/variance_sqrt_l1', opt_stats[5], iteration) + writer.add_scalar('optimizer/momentum_l1', opt_stats[6], iteration) + writer.add_scalar('optimizer/weight_l1', opt_stats[7], iteration) + writer.add_scalar('optimizer/variance_abs_max', opt_stats_2[0], iteration) + writer.add_scalar('optimizer/variance_sqrt_abs_max', opt_stats_2[1], iteration) + writer.add_scalar('optimizer/momentum_abs_max', opt_stats_2[2], iteration) + writer.add_scalar('optimizer/weight_abs_max', opt_stats_2[3], iteration) + + assert args is not None + if iteration % args.log_interval == 0: + elapsed_time = timers('interval-time').elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + seq_len = args.seq_length + if hasattr(args, 'actual_seq_length'): + seq_len = args.actual_seq_length + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + model, + args, + elapsed_time, + total_iterations + ) + samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size + tokens_per_sec = samples_per_sec * seq_len + tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size + tokens_per_gpu_per_second = tokens_per_sec / args.world_size + tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / args.data_parallel_size + wandb_metrics = {} + if wandb is not None and getattr(wandb, 'run', None) is not None: + assert wandb.run is not None + wandb_metrics = { + 
'throughput/iteration-time': elapsed_time_per_iteration, # 1000 ms / s + 'throughput/samples_per_sec': samples_per_sec, + 'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica, + 'throughput/tokens_per_sec': tokens_per_sec, + 'throughput/tokens_per_sec_per_replica': tokens_per_sec_per_replica, + 'throughput/tokens_per_gpu_per_sec': tokens_per_gpu_per_second, + 'throughput/tokens_per_gpu_per_sec_per_replica': tokens_per_gpu_per_second_per_replica, + 'throughput/tflops': tflops, + 'throughput/approx_params_in_billions': approx_parameters_in_billions, + 'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration, + 'throughput/iteration': iteration, + } + if loss_dict is not None: + wandb_metrics |= { + 'loss/iteration': iteration, + **{f'loss/{k}': v for k, v in loss_dict.items()} + } + if writer and args.log_timers_to_tensorboard: + writer.add_scalar('iteration-time/iteration-time', + elapsed_time_per_iteration, iteration) + writer.add_scalar('iteration-time/iteration-time vs samples', + elapsed_time_per_iteration, args.consumed_train_samples) + writer.add_scalar('iteration-time/iteration-time vs tokens', + elapsed_time_per_iteration, args.consumed_train_tokens) + log_string = ' iteration {:8d}/{:8d} |'.format( + iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format( + args.consumed_train_samples) + log_string += ' consumed tokens: {:12d} |'.format( + args.consumed_train_tokens) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time_per_iteration * 1000.0) + log_string += ' learning rate: {:.3E} |'.format(learning_rate) + log_string += ' global batch size: {:5d} |'.format(batch_size) + if wandb is not None and getattr(wandb, 'run', None) is not None: + wandb_metrics |= { + 'training/iteration': iteration, + 'training/iteration_time': elapsed_time_per_iteration, + 'training/iteration_time_vs_tokens': ( + (elapsed_time_per_iteration + / args.consumed_train_tokens) + ), + 'training/iteration_time_vs_samples': ( + (elapsed_time_per_iteration + / args.consumed_train_samples), + ), + 'training/consumed_samples': args.consumed_train_samples, + 'training/consumed_tokens': args.consumed_train_tokens, + } + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, + nan_iters_key]: + avg = total_loss_dict[key].item() / \ + float(max(1, total_loss_dict[advanced_iters_key])) + if avg > 0.0: + log_string += ' {}: {:.6E} |'.format(key, avg) + total_loss_dict[key] = get_accelerator().FloatTensor([0.0]) + if loss_scale is not None: + log_string += ' loss scale: {:.1f} |'.format(loss_scale) + wandb_metrics |= {'loss/loss_scale': loss_scale} + if grad_norm is not None: + log_string += ' grad norm: {:.3f} |'.format(grad_norm) + wandb_metrics |= {'loss/grad_norm': grad_norm} + if num_zeros_in_grad is not None: + log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) + wandb_metrics |= {'loss/num_zeros_in_grad': num_zeros_in_grad} + if params_norm is not None: + log_string += ' params norm: {:.3f} |'.format(params_norm) + wandb_metrics |= {'loss/params_norm': params_norm} + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + log_string += ' curriculum seqlen: {:5d} |'.format(args.curriculum_seqlen) + if args.random_ltd: + log_string += ' random ltd reserved length: {:5d} |'.format(args.random_ltd_reserved_length) + log_string += ' actual seqlen: {:5d} |'.format(seq_len) + log_string += ' number of skipped iterations: {:3d} |'.format( + total_loss_dict[skipped_iters_key]) 
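All of the throughput metrics reported above derive from one measurement, the elapsed time over the logging interval. A minimal sketch of that arithmetic with made-up values, assuming a fixed sequence length (the real code gets samples/sec and TFLOPs from throughput_calculator):

    # Illustrative values only.
    elapsed_time = 120.0            # seconds spent in the logging interval
    total_iterations = 10           # iterations completed in that interval
    global_batch_size = 64          # samples per iteration
    seq_len = 4096                  # tokens per sample
    world_size = 16                 # total GPUs
    data_parallel_size = 4          # data-parallel replicas

    elapsed_per_iter = elapsed_time / total_iterations            # 12 s/iter
    samples_per_sec = global_batch_size / elapsed_per_iter        # ~5.33
    samples_per_sec_per_replica = samples_per_sec / data_parallel_size
    tokens_per_sec = samples_per_sec * seq_len                    # ~21845
    tokens_per_gpu_per_sec = tokens_per_sec / world_size          # ~1365

    print(round(samples_per_sec, 2), round(tokens_per_sec), round(tokens_per_gpu_per_sec))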
+ log_string += ' number of nan iterations: {:3d} |'.format( + total_loss_dict[nan_iters_key]) + log_string += ' samples per second: {:.3f} |'.format(samples_per_sec) + log_string += ' tokens per gpu per second (tgs): {:.3f} |'.format(tokens_per_gpu_per_second) + log_string += ' TFLOPs: {:.2f} |'.format(tflops) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + print_rank_last(log_string) + if report_memory_flag and learning_rate > 0.: + # Report memory after optimizer state has been initialized. + report_memory('(after {} iterations)'.format(iteration)) + report_memory_flag = False + if wandb is not None and getattr(wandb, 'run', None) is not None: + wandb_metrics |= {'training/skiped_iterations': total_loss_dict[skipped_iters_key]} + wandb_metrics |= {'training/nan_iterations': total_loss_dict[nan_iters_key]} + wandb.log(wandb_metrics) + if timers is not None: + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag + + +def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): + timers = get_timers() + # Extra barrier is added to make sure + # all ranks report the max time. + # assert timers is not None + timers('save-checkpoint', log_level=0).start(barrier=True) + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + timers('save-checkpoint').stop(barrier=True) + checkpoint_throughput_calculator(model, timers('save-checkpoint').elapsed(reset=False)) + timers.log(['save-checkpoint']) + + +def train(forward_step_func, model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func): + """Train the model function.""" + args = get_args() + timers = get_timers() + + # Write args to tensorboard + write_args_to_tensorboard() + + if args.random_ltd: + # random-ltd requires different randomness on each rank + import random + random.seed(args.seed + torch.distributed.get_rank()) + + # Turn on training mode which enables dropout. + for model_module in model: + model_module.train() + + # Tracking loss. + total_loss_dict = {} + + # Iterations. 
+ iteration = args.iteration + + # Translate args to core configuration + config = core_transformer_config_from_args(args) + if not args.deepspeed: + config.grad_scale_func = optimizer.scale_loss + config.timers = timers + + timers('interval-time', log_level=0).start(barrier=True) + print_datetime('before the start of training step') + report_memory_flag = True + if args.random_ltd: + assert model[0].random_ltd_enabled() + args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() + + while iteration < args.train_iters and (args.train_tokens is None or \ + args.consumed_train_tokens < args.train_tokens): + update_num_microbatches(args.consumed_train_samples) + if args.deepspeed: + # inform deepspeed of any batch size changes + global_batch_size = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + model[0].set_train_batch_size(global_batch_size) + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(curriculum_seqlen) + args.curriculum_seqlen = curriculum_seqlen + args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) + iteration += 1 + args.iteration = iteration + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + args.consumed_train_samples += new_samples + # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. + args.actual_seq_length = args.seq_length + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + args.actual_seq_length = args.curriculum_seqlen + if args.random_ltd: + args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq() + if args.random_ltd_reserved_length < args.actual_seq_length: + args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + if hasattr(args, 'data_efficiency_curriculum_learning_numel'): + act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen + act_token = act_mbsz * args.actual_seq_length + args.consumed_train_tokens += mpu.get_data_parallel_world_size() * \ + get_num_microbatches() * act_token + else: + args.consumed_train_tokens += new_samples * args.actual_seq_length + else: + args.consumed_train_tokens += new_samples * args.actual_seq_length + + # Logging. 
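The consumed-token accounting above is "new samples x effective sequence length", where random-LTD shortens the effective length through a per-layer weighted average. A sketch of that average under assumed values (illustrative, not taken from this configuration):

    # Effective sequence length when random-LTD reserves a shorter length on
    # some layers; mirrors the weighted average used in the training loop above.
    seq_length = 4096
    num_layers = 32
    random_ltd_layer_num = 24          # layers running on the reserved length
    random_ltd_reserved_length = 2048  # shorter length used on those layers

    actual_seq_length = (
        seq_length * (num_layers - random_ltd_layer_num)
        + random_ltd_reserved_length * random_ltd_layer_num
    ) // num_layers                     # -> 2560

    new_samples = 64                    # samples consumed this iteration
    consumed_train_tokens = new_samples * actual_seq_length
    print(actual_seq_length, consumed_train_tokens)   # 2560 163840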
+ if args.deepspeed: + if hasattr(model[0].optimizer, 'cur_scale'): + loss_scale = model[0].optimizer.cur_scale + else: + loss_scale = None + else: + loss_scale = optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log(loss_dict, total_loss_dict, + optimizer.param_groups[0]['lr'], + iteration, loss_scale, + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, + model, optimizer) + + # Autoresume + if args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and \ + args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, False) + + # Checkpointing + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + print_datetime('exiting program after receiving SIGTERM.') + sys.exit() + + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + saved_checkpoint = True + + # Exiting based on duration + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = get_accelerator().IntTensor( + [train_time > args.exit_duration_in_mins]) + torch.distributed.all_reduce( + done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + if not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + print_datetime('exiting program after {} minutes'.format(train_time)) + sys.exit() + + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if args.save and not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + torch.distributed.barrier() + print_datetime('exiting program at iteration {}'.format(iteration)) + sys.exit() + + + return iteration + + +def evaluate(forward_step_func, + data_iterator, + model, + process_non_loss_data_func, + config, + verbose=False): + """Evaluation.""" + args = get_args() + + if args.vision_pretraining and args.vision_pretraining_type == "dino": + compute_feature_bank(model) + + # Turn on evaluation mode which disables dropout. + for model_module in model: + model_module.eval() + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + # When curriculum learning is used with pipeline parallelism, we need + # this logic to ensure that the eval data is not truncated. If there + # is a seqlen change due to that, we need to call + # reset_activation_shape() to reset some buffers in deepspeed pipeline + # engine. 
+ if args.curriculum_seqlen < args.seq_length: + args.curriculum_seqlen = args.seq_length + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.curriculum_seqlen) + model[0].reset_activation_shape() + + total_loss_dict = {} + + with torch.no_grad(): + iteration = 0 + while iteration < args.eval_iters: + iteration += 1 + if verbose and iteration % args.log_interval == 0: + print_rank_0('Evaluating iter {}/{}'.format(iteration, + args.eval_iters)) + + forward_backward_func = get_forward_backward_func() + # Don't care about timing during evaluation + config.timers = None + if args.deepspeed and args.ds_pipeline_enabled: + # DeepSpeed uses eval_batch() and already aggregates losses. + assert isinstance(model, list) and len(model) == 1 + loss = model[0].eval_batch(data_iterator) + loss_dicts = [{'lm loss' : loss}] * get_num_microbatches() + else: + loss_dicts = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True) + config.timers = get_timers() + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. + for loss_dict in loss_dicts: + for key in loss_dict: + if 'moe' not in key: + total_loss_dict[key] = total_loss_dict.get( + key, get_accelerator().FloatTensor([0.0])) + loss_dict[key] + + args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ + * args.micro_batch_size \ + * get_num_microbatches() + collected_non_loss_data = None + if process_non_loss_data_func is not None and is_last_rank(): + collected_non_loss_data = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True, + collect_non_loss_data=True) + + # Move model back to the train mode. + for model_module in model: + model_module.train() + + for key in total_loss_dict: + total_loss_dict[key] /= args.eval_iters * get_num_microbatches() + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + # roll back to actual curriculum seqlen at the end of eval. 
+ args.curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if args.curriculum_seqlen < args.seq_length: + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.curriculum_seqlen) + model[0].reset_activation_shape() + + return total_loss_dict, collected_non_loss_data + +def evaluate_and_print_results(prefix, forward_step_func, + data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=False, write_to_tensorboard=True, test=False): + """Helper function to evaluate and dump results on screen.""" + args = get_args() + if write_to_tensorboard: + writer = get_tensorboard_writer() + else: + writer = None + + total_loss_dict, collected_non_loss_data = evaluate( + forward_step_func, data_iterator, model, + process_non_loss_data_func, config, verbose) + string = ' validation loss at {} | '.format(prefix) + for key in total_loss_dict: + string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) + ppl = math.exp(min(20, total_loss_dict[key].item())) + string += '{} PPL: {:.6E} | '.format(key, ppl) + if writer and is_last_rank(): + data_type = 'test' if test else 'validation' + writer.add_scalar(f'lm-loss-validation/{key} {data_type}', + total_loss_dict[key].item(), + iteration) + writer.add_scalar(f'lm-loss-validation/{key} {data_type} vs samples', + total_loss_dict[key].item(), + args.consumed_train_samples) + writer.add_scalar(f'lm-loss-validation/{key} {data_type} vs tokens', + total_loss_dict[key].item(), + args.consumed_train_tokens) + if args.log_validation_ppl_to_tensorboard: + writer.add_scalar(f'lm-loss-validation/{key} {data_type} ppl', ppl, + iteration) + writer.add_scalar(f'lm-loss-validation/{key} {data_type} ppl vs samples', + ppl, args.consumed_train_samples) + writer.add_scalar(f'lm-loss-validation/{key} {data_type} ppl vs tokens', + ppl, args.consumed_train_tokens) + + if process_non_loss_data_func is not None and writer and is_last_rank(): + process_non_loss_data_func(collected_non_loss_data, iteration, writer) + + length = len(string) + 1 + print_rank_last('-' * length) + print_rank_last(string) + print_rank_last('-' * length) + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + + +def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): + """Build pretraining datasets.""" + + args = get_args() + + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + + # Build the datasets. + return build_train_valid_test_datasets_provider(train_val_test_num_samples) + + +def build_train_valid_test_data_loaders( + build_train_valid_test_datasets_provider): + """Build pretraining data loaders.""" + + args = get_args() + + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + print_rank_0('> building train, validation, and test datasets ...') + + # Backward compatibility, assume fixed batch size. 
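Two bits of arithmetic from the code above, the capped validation perplexity and the train/valid/test target sizes, written out as a standalone sketch with assumed settings (illustrative numbers only):

    import math

    # Validation perplexity as printed above; the exponent is capped to avoid overflow.
    lm_loss = 2.0
    ppl = math.exp(min(20, lm_loss))              # ~7.389

    # Dataset target sizes (minimum samples per split) for iteration-based training.
    train_iters, eval_interval, eval_iters = 10000, 1000, 10
    global_batch_size = 64

    train_samples = train_iters * global_batch_size
    # The "+ 1" covers the final evaluation after the last interval.
    valid_samples = (train_iters // eval_interval + 1) * eval_iters * global_batch_size
    test_samples = eval_iters * global_batch_size
    print(train_samples, valid_samples, test_samples)   # 640000 7040 640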
+ if args.iteration > 0 and args.consumed_train_samples == 0: + assert args.train_samples is None, \ + 'only backward compatiblity support for iteration-based training' + args.consumed_train_samples = args.iteration * args.global_batch_size + if args.iteration > 0 and args.consumed_valid_samples == 0: + if args.train_samples is None: + args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ + args.eval_iters * args.global_batch_size + + # Data loader only on rank 0 of each model parallel group. + ds_sequence_parallel = mpu.get_sequence_parallel_world_size() > 1 or args.force_ds_sequence_parallel + rank_in_parallel_group = mpu.get_sequence_parallel_rank() if ds_sequence_parallel else mpu.get_tensor_model_parallel_rank() + if rank_in_parallel_group == 0: + # Build datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + build_train_valid_test_datasets_provider) + + # Build dataloders. + train_dataloader = build_pretraining_data_loader( + train_ds, args.consumed_train_samples) + valid_dataloader = build_pretraining_data_loader( + valid_ds, args.consumed_valid_samples) + test_dataloader = build_pretraining_data_loader(test_ds, 0) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and args.train_iters > 0 + do_valid = valid_dataloader is not None and args.eval_iters > 0 + do_test = test_dataloader is not None and args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = get_accelerator().LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = get_accelerator().LongTensor([0, 0, 0]) + + # Broadcast num tokens. + if ds_sequence_parallel: + torch.distributed.broadcast(flags, + mpu.get_sequence_parallel_src_rank(), + group=mpu.get_sequence_parallel_group()) + else: + torch.distributed.broadcast(flags, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + args.do_train = flags[0].item() + args.do_valid = flags[1].item() + args.do_test = flags[2].item() + + return train_dataloader, valid_dataloader, test_dataloader + + +def build_train_valid_test_data_iterators( + build_train_valid_test_datasets_provider): + """Build pretraining data iterators.""" + + args = get_args() + + # Build loaders. + train_dataloader, valid_dataloader, test_dataloader = \ + build_train_valid_test_data_loaders( + build_train_valid_test_datasets_provider) + + # Build iterators. 
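Only one rank per model-parallel group builds the dataloaders above, so the do_train/do_valid/do_test decisions are packed into a single tensor and broadcast to the rest of the group. A minimal single-process sketch of the packing and unpacking (the torch.distributed.broadcast call is shown only as a comment, since it needs an initialized process group):

    import torch

    # Pack the per-split flags into one integer tensor: one broadcast instead of three.
    do_train, do_valid, do_test = True, True, False
    flags = torch.tensor([int(do_train), int(do_valid), int(do_test)], dtype=torch.long)

    # torch.distributed.broadcast(flags, src_rank, group=...) would run here.

    # Every rank then unpacks the same values.
    do_train = bool(flags[0].item())
    do_valid = bool(flags[1].item())
    do_test = bool(flags[2].item())
    print(do_train, do_valid, do_test)   # True True False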
+ dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader)) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloader)) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator + From 5df25e746c676bb1f8bc9949d48cbf52267de816 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 20 Mar 2024 18:19:16 +0000 Subject: [PATCH 02/50] moved dpo_training script --- dpo_training.py => megatron/dpo_training.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dpo_training.py => megatron/dpo_training.py (100%) diff --git a/dpo_training.py b/megatron/dpo_training.py similarity index 100% rename from dpo_training.py rename to megatron/dpo_training.py From 461524eed92039652921d522010f4ee6b70fea31 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:19:23 +0000 Subject: [PATCH 03/50] commented out assert in load checkpoint func --- megatron/checkpointing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 3e4d20035e..f67ab00bf4 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -597,8 +597,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri # Check arguments. reset_train_valid_samples = args.reset_iteration if not load_only_weights and not reset_train_valid_samples: - assert args.consumed_train_samples == 0 - assert args.consumed_valid_samples == 0 + # assert args.consumed_train_samples == 0 + # assert args.consumed_valid_samples == 0 if 'args' in state_dict and not args.finetune: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) From 2c6677112e56e60030501f2df3215c48d220b4db Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:21:47 +0000 Subject: [PATCH 04/50] tuple structure removed --- dpo_gpt_alcf.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/dpo_gpt_alcf.py b/dpo_gpt_alcf.py index 821c7cd8ab..c4ecd21c6e 100644 --- a/dpo_gpt_alcf.py +++ b/dpo_gpt_alcf.py @@ -539,16 +539,10 @@ def forward_step(data_iterator, model): ) else: output_tensor, other_losses = model( - tokens[0], - position_ids[0], - attention_mask[0], - labels=labels[0] - ) - output_tensor_u, other_losses_u = model( - tokens[1], - position_ids[1], - attention_mask[1], - labels=labels[1] + tokens, + position_ids, + attention_mask, + labels=labels ) if ( args.curriculum_learning_legacy and From f983241bef011350af5ed19ee05610ba1b06269c Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:23:27 +0000 Subject: [PATCH 05/50] loading correct env --- ALCF/helpers.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index bc2adb26fa..00843d0018 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -217,7 +217,8 @@ setEnv() { elif [[ $(hostname) == x3* ]]; then echo "Running on Polaris !!" 
# ---- [load conda] --------------------- - module load conda/2023-10-04; conda activate cu118-pt221 ; unset PYTHONUSERBASE + # module load conda/2023-10-04; conda activate cu118-pt221 ; unset PYTHONUSERBASE + module load conda/2023-10-04; conda activate base # module load conda/2023-10-04 ; conda activate /lus/eagle/projects/datascience/foremans/miniconda3/envs/polaris/py311-cu118 # ; conda activate /lus/eagle/projects/datascience/foremans/miniconda3/envs/polaris/2024-03-06 # export PYTHONUSERBASE="${HOME}/.local/polaris/conda/py311-cu118" From 1aeebcaa8be204d43f70f8d824f91eb4f4114bf6 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:24:59 +0000 Subject: [PATCH 06/50] modified main func --- pretrain_gpt_modified.py | 682 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 682 insertions(+) create mode 100644 pretrain_gpt_modified.py diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py new file mode 100644 index 0000000000..5055c90cd4 --- /dev/null +++ b/pretrain_gpt_modified.py @@ -0,0 +1,682 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain GPT""" + +import os +from rich import print +import torch +import math + +# The earliest we can measure the start time. +import time +from datetime import datetime + +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import ( + report_memory, + throughput_calculator, + checkpoint_throughput_calculator +) +from pathlib import Path + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator +import subprocess +import wandb + +import time +from torch import nn +import torch.nn.functional as F + +# from ezpz import get_logger +from ezpz.dist import get_world_size, setup_wandb, get_rank + +# More imports +from megatron.initialize import initialize_megatron +from megatron.initialize import set_jit_fusion_options +from megatron.training import print_datetime, _create_ds_config_dict +from megatron.training import setup_model_and_optimizer +from megatron.training import load_model_weights_only, get_model +from megatron.training import load_model_weights_only_modified +from megatron.training import get_optimizer_param_scheduler +from megatron.optimizer import get_megatron_optimizer +from megatron.checkpointing import load_checkpoint + +# RANK = setup_torch( +# backend='deepspeed', +# port='5432', +# ) +RANK = get_rank() +WORLD_SIZE = get_world_size() +LEVEL = "DEBUG" if RANK == 0 else "CRITICAL" + +WANDB_MODE = os.environ.get('WANDB_MODE', None) +DISABLE_WANDB = ( + WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' +) + +if RANK == 0 and not DISABLE_WANDB: + project_name = ( + os.environ.get( + 'WB_PROJECT', + os.environ.get( + 'WANDB_PROJECT', + 'AuroraGPT' + ), + ) + ) + print('--------------------------------------------------') + print(f"Setting up W&B from: {RANK} with {project_name}") + 
print('--------------------------------------------------') + setup_wandb(project_name=project_name) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + print_rank_0('building GPT model ...') + see_memory_usage("Before Building Model", force=True) + args = get_args() + config = core_transformer_config_from_args(args) + if wandb.run is not None: + print(f"Updating WandB run: [{wandb.run.name}]({wandb.run.url})") + wandb.run.config.update({"args": vars(args)}, allow_val_change=True) + if RANK == 0: + git_ds_info() + if hasattr(mpu, 'get_sequence_parallel_group'): + dpg = mpu.get_sequence_parallel_group() + elif hasattr(mpu, 'get_data_parallel_group'): + dpg = mpu.get_data_parallel_group() + else: + dpg = None + if wandb is not None and wandb.run is not None: + assert wandb is not None and wandb.run is not None + print(f'Updating {wandb.run.name=} at {wandb.run.url=}') + wandb.run.config.update({'args': vars(args)}, allow_val_change=True) + with deepspeed.zero.Init( + data_parallel_group=dpg, + remote_device=( + None if args.remote_device == 'none' else args.remote_device + ), + config_dict_or_path=args.deepspeed_config_dict, + enabled=args.zero_stage == 3, + mpu=mpu + ): + if args.deepspeed and not args.no_pipeline_parallel: + model = GPTModelPipe( + config=config, + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to + # get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. + # This avoids having to pipeline it + # as an activation during training. + # The mask is constant, and thus we can reuse it. + attention_mask = torch.tril( + torch.ones( + (1, args.seq_length, args.seq_length), + device=get_accelerator().current_device_name() + ) + ).view(1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + # For prertaining, since sequence length is fixed, + # cache rotary embedding in args, to avoid communicating around + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.seq_length) + + else: + print(f'Building model check..') + model = GPTModel( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + # print_rank_0('\n ------------------------ ') + # print_rank_0(f'num of parameters {num_params}') + # print_rank_0('------------------------\n ') + print_rank_0(80 * '-') + print_rank_0(f"Number of parameters in model: {num_params}") + print_rank_0(80 * '-') + see_memory_usage("After Building Model", force=True) + if wandb.run is not None: + wandb.run.config.update({'num_params': num_params}, allow_val_change=True) + # wandb.run.watch( + # model, + # log='all', + # log_graph=True, + # ) + # wandb.run.config.update({'num_params': num_params}) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + data = next(data_iterator) if data_iterator is not None else None + # # Broadcast data. 
+ # if data_iterator is not None: + # data = next(data_iterator) + # else: + # data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + skip_mask = args.use_flash_attn or args.use_flash_attn_triton + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + skip_mask) + + # For DS's sequence parallel + seq_parallel_world_size = mpu.get_sequence_parallel_world_size() + seq_parallel_world_rank = mpu.get_sequence_parallel_rank() + + # For Megatron's sequence parallel + if args.sequence_parallel: + seq_parallel_world_size = mpu.get_tensor_model_parallel_world_size() + seq_parallel_world_rank = mpu.get_tensor_model_parallel_rank() + seq_length = tokens.size(1) + + assert seq_length % seq_parallel_world_size == 0 + sub_seq_length = seq_length // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_length + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length + + tokens = tokens[:, sub_seq_start:sub_seq_end] + position_ids = position_ids[:, sub_seq_start:sub_seq_end] + # For DS's sequence parallel + if mpu.get_sequence_parallel_world_size() > 1: + labels = labels[:, sub_seq_start:sub_seq_end] + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def data_post_process(data, data_sampler_state_dict): + args = get_args() + if args.data_efficiency_curriculum_learning: + if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if current_seqlen < args.seq_length: + data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() + elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + if current_seqlen < args.seq_length: + orig_num_token = torch.numel(data['text']) + reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) + data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), + data['text'][:, -(current_seqlen+1):]), 0).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen+1)) + num_row = min(num_row, data['text'].size()[0]) + if num_row > 1 and num_row % 2 != 0: + num_row -= 1 + data['text'] = data['text'][:num_row, :].contiguous() + else: + args.data_efficiency_curriculum_learning_seqlen_type = None + return data + + +def get_batch_pipe(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` + instead of `data_iterator` + """ + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
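get_batch above splits each sequence evenly across the sequence-parallel group, with every rank keeping one contiguous slice. A sketch of the slicing arithmetic using assumed ranks and lengths and a plain dummy tensor in place of the real batch:

    import torch

    # Illustrative sizes; the real code gets these from mpu.
    seq_length = 4096
    seq_parallel_world_size = 4
    seq_parallel_world_rank = 1
    assert seq_length % seq_parallel_world_size == 0

    sub_seq_length = seq_length // seq_parallel_world_size        # 1024 tokens per rank
    sub_seq_start = seq_parallel_world_rank * sub_seq_length      # 1024
    sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length  # 2048

    tokens = torch.arange(seq_length).unsqueeze(0)                # [1, seq_length] dummy batch
    local_tokens = tokens[:, sub_seq_start:sub_seq_end]
    print(local_tokens.shape)                                     # torch.Size([1, 1024])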
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < tokens.size()[1] + ): + # seqlen-based curriculum learning + # tokens, position_ids, labels, loss_mask + # have size [batch size, seqlen] + tokens = tokens[:, :args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + if labels is not None: + labels = labels[:, :args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + else: + if max(args.num_experts) <= 1: + return loss, {'lm loss': averaged_loss[0]} + loss = loss + moe_loss + return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + + +def calculate_mos_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + tea_output, tea_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == tea_output.size(), ( + 'teacher and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Teacher: {tea_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. + tea_logits = F.softmax(tea_output / kd_temp, dim=2) + + mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( + student_logits, + tea_logits + ) + + mos_loss = mos_loss.div(args.seq_length) * beta + return mos_loss + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
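loss_func in the patch above reduces to a loss-mask-weighted mean over token positions (the data-parallel averaging and MoE/KD terms are omitted here). A self-contained sketch with dummy tensors:

    import torch

    # Per-token losses and a mask that zeroes out positions that should not count,
    # e.g. padding or EOD positions; shapes are illustrative.
    losses = torch.tensor([[2.0, 1.0, 3.0, 0.5]])      # [batch, seq] per-token loss
    loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])   # last position excluded

    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    print(loss)   # tensor(2.) -> mean over the 3 unmasked positions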
+ timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + if args.data_efficiency_curriculum_learning: + args.curriculum_seqlen = tokens.size()[1] + if ( + hasattr( + args, + 'data_efficiency_curriculum_learning_seqlen_type') + and ( + args.data_efficiency_curriculum_learning_seqlen_type + == 'seqlen_reshape' + ) + ): + args.data_efficiency_curriculum_learning_numel = ( + torch.numel(tokens) + ) + + if args.mos or args.kd: + # The forward func can return either the loss or the logits, + # depending on whether passing in the labels or not. + stu_output, other_losses = model(tokens, position_ids, attention_mask) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + labels = labels[:, :args.curriculum_seqlen].contiguous() + output_tensor = tensor_parallel.vocab_parallel_cross_entropy( + stu_output.contiguous().float(), + labels + ) + else: + output_tensor, other_losses = model( + tokens, + position_ids, + attention_mask, + labels=labels + ) + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + moe_losses = [] + for moe_loss in other_losses: + if moe_loss is not None: + moe_losses.append(moe_loss) + moe_loss = sum(moe_losses) * args.moe_loss_coeff + + mos_loss = 0 + if args.mos or args.kd: + assert model.training + if args.teacher_forward and args.teacher_model is not None: + mos_loss = calculate_mos_loss( + args, + stu_output, + args.teacher_model[0], + tokens, + position_ids, + attention_mask + ) + + # Output_tensor stores the standard loss, + # loss_func calculates the total loss. 
+ return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for GPT ...') + files = [] + if args.data_file_list is not None: + with open(args.data_file_list, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files.append(float(w)) + files.append(fname) + elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): + path = args.data_path[0] + "/" + for f in os.listdir(path): + if (os.path.isfile(path + f) and f.find(".bin") != -1): + files.append(1) + files.append(path + f.split(".bin")[0]) + else: + files = args.data_path + print_rank_0(f"file list {files}") + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=files, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen( + f'type {cmd}', + stdout=subprocess.PIPE, + shell=True + ) + return result.wait() == 0 + + +def git_ds_info(): + if RANK != 0: + return + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print( + f'**** Git info for Megatron: ' + f'git_hash={git_hash} git_branch={git_branch} ****' + ) + + +def main(): + # if RANK == 0: + # setup_wandb() + if os.getenv('TORCH_PROFILER_ENABLED') == '1': + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. + if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + # model = model_provider() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + + prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") + else: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. 
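The --data-file-list file read by train_valid_test_datasets_provider above is expected to hold one "<weight> <dataset prefix>" pair per line, flattened into the [w0, prefix0, w1, prefix1, ...] list the dataset builder consumes. A small sketch of that parsing; the file contents and paths below are hypothetical:

    from io import StringIO

    # Stand-in for open(args.data_file_list); contents are hypothetical.
    data_file_list = StringIO("0.7 /data/dolma/books_text_document\n"
                              "0.3 /data/dolma/web_text_document\n")

    files = []
    for line in data_file_list:
        weight, prefix = line.split()
        files.append(float(weight))
        files.append(prefix)

    print(files)  # [0.7, '/data/dolma/books_text_document', 0.3, '/data/dolma/web_text_document']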
+ if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + if args.deepspeed: + args.deepspeed_config_dict = _create_ds_config_dict() + if "curriculum_learning" in args.deepspeed_config_dict and \ + "enabled" in args.deepspeed_config_dict["curriculum_learning"]: + args.curriculum_learning_legacy = args.deepspeed_config_dict[ \ + "curriculum_learning"]["enabled"] + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + from deepspeed.runtime.data_pipeline.curriculum_scheduler \ + import CurriculumScheduler + args.curriculum_scheduler = CurriculumScheduler( \ + args.deepspeed_config_dict["curriculum_learning"]) + if "compression_training" in args.deepspeed_config_dict: + args.compression_training = True + + # model = model_provider() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error + model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? + # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0) + opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2) + model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( + model=model_ref[0], + optimizer=optimizer_2, + args=args, + lr_scheduler=opt_param_scheduler_2, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + if isinstance(model_ref, deepspeed.PipelineEngine): + print(f'Doing assertion checks on model_ref..') + # hack to get batch_fn from pretrain_gpt.py + model_ref.set_batch_fn(model_ref.module._megatron_batch_fn) + + assert model_ref.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank() + assert model_ref.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank() + assert model_ref.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank() + + model_ref = [model_ref] + iteration2 = load_checkpoint(model_ref, optimizer_2, opt_param_scheduler_2) # THIS WORKED!! 
After commenting out assert args.consumed_train_samples == 0 in load_checkpoint() + + # THINGS THAT DID NOT WORK FOR LOADING FROM CHECKPOINT + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only(model_provider) # DID NOT WORK - train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 32 != 8 * 1 * 8 + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only_modified(model_provider) # DID NOT WORK - optimizer = FusedAdam(TypeError: FusedAdam.__init__() got an unexpected keyword argument 'beta1' + + return model + +# def main(): +# # if RANK == 0: +# # setup_wandb() +# if os.getenv('TORCH_PROFILER_ENABLED') == '1': +# from torch.profiler import profile, record_function, ProfilerActivity +# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) + +# prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") +# else: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) +# return model + + +if __name__ == "__main__": + # git_ds_info() + # pretrain(train_valid_test_datasets_provider, + # model_provider, + # ModelType.encoder_or_decoder, + # forward_step, + # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # data_post_process=data_post_process) + import sys + import deepspeed.comm as dist + model = main() + dist.log_summary() + if wandb.run is not None: + print(f"wandb.run.name: {wandb.run.name}") + print(f"wandb.run.url: {wandb.run.url}") + wandb.finish() + sys.exit() From aa9d9b295dfc8470333890e0614816ec4f3b5453 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:25:45 +0000 Subject: [PATCH 07/50] run script with modified exec --- train_llama_polaris_modified.sh | 110 ++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 train_llama_polaris_modified.sh diff --git a/train_llama_polaris_modified.sh b/train_llama_polaris_modified.sh new file mode 100644 index 0000000000..723fd06083 --- /dev/null +++ b/train_llama_polaris_modified.sh @@ -0,0 +1,110 @@ +#!/bin/bash --login +#PBS -l walltime=06:00:00 +#PBS -A argonne_tpc +#PBS -q prod +#PBS -l select=48 +#PBS -l filesystems=eagle:home + +function sourceFile() { + fp="$1" + echo "source-ing ${fp}" + if [[ -f "${fp}" ]]; then + # shellcheck source="${fp}" + source "${fp}" + else + echo "ERROR: UNABLE TO SOURCE ${fp}" + fi +} + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# ---- 0. Navigate into `$PBS_O_WORKDIR` ------------------------------------- +cd "${PBS_O_WORKDIR}" || exit +HERE=$(python3 -c 'import os; print(os.getcwd())') +export HERE +# ---- 1. Assert `./pretrain_gpt_alcf.py` exists: ----------------------------- +# export EXEC="${HERE}/pretrain_gpt_alcf.py" +export EXEC="${HERE}/pretrain_gpt_modified.py" +[ -f "${EXEC}" ] || exit +# ---- 2. `source ./ALCF/helpers_alcf.sh`: ------------------------------------ +sourceFile "${HERE}/ALCF/helpers.sh" || exit +# ---- 3. Call fns from `./ALCF/helpers_alcf.sh` ------------------------------ +setEnv || exit # 1. load `conda` environment +saveDSenv || exit # 2. 
save env vars to `.deepspeed_env` +ezpz || exit # 3. determine WORLD_SIZE, etc. from `PBS_*` vars +makeHostfiles || exit # 4. create `deepspeed` hostfile from `$PBS_NODEFILE` +setParams || exit # 5. set command line arguments to pass to `"${EXEC}"` +buildDSconfig || exit # 6. create `deepspeed_config.json` from runtime params from ^ +setOutput || exit # 7. specify output directory for {logs, checkpoints, etc.} +setArgs || exit # 8. specify additional `deepspeed` arguments +setData "${DATA_FILE_LIST}"|| exit # 9. specify `DATA_FILE_LIST` for dolma dataset +setDSlauncher "${HERE}" || exit # 10. set `launcher` args for `deepspeed ${launcher} ${EXEC} ${args}` +printJobInfo || exit # 11. print job info +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +# Take custom args +custom_args=" $@" + +# Assert `./hostfile_deepspeed` exists +export hfds="${HERE}/hostfile_deepspeed" && [ -f "${hfds}" ] || exit + +# source "${HERE}/venvs/polaris/2024-03-14/bin/activate" || exit +# echo "Using $(which python3)" +# --launcher_args='--pmi=pmix' + # deepspeed --hostfile $hfds --launcher ${LAUNCHER} ${EXEC} \ + # ${launch_cmd} \ +run_cmd=" + deepspeed --hostfile $hfds --launcher MPICH ${EXEC} \ + --use-flash-attn-v2 \ + --$DTYPE \ + --num-workers 0 \ + --split 100,0,0 \ + --log-interval 1 \ + --no-bias-gelu-fusion \ + --lr-decay-style cosine \ + --no-bias-dropout-fusion \ + --no-masked-softmax-fusion \ + --tokenizer-type Llama2Tokenizer \ + --no-gradient-accumulation-fusion \ + --accumulate-allreduce-grads-in-fp32 \ + --use-checkpoint-opt_param-scheduler \ + --lr ${LR} \ + --seq-length $SEQ \ + --save ${CKPT_DIR} \ + --load ${CKPT_DIR} \ + --num-layers ${NLAYERS} \ + --hidden-size ${HIDDEN} \ + --train-iters ${TRAIN_ITER} \ + --eval-iters ${EVAL_ITERS} \ + --distributed-backend ${BE} \ + --num-attention-heads ${HEADS} \ + --save-interval ${SAVE_INTERVAL} \ + --eval-interval ${EVAL_INTERVAL} \ + --max-position-embeddings ${SEQ} \ + --micro-batch-size ${MICRO_BATCH} \ + --data-file-list ${DATA_FILE_LIST} \ + --tensor-model-parallel-size ${TP} \ + --global-batch-size ${GLOBAL_BATCH} \ + --pipeline-model-parallel-size ${PP} \ + --num-key-value-heads ${NUM_KV_HEAD} \ + --data-cache-path ${DATA_CACHE_PATH} \ + --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ + --tokenizer-model ${TOKENIZER_MODEL} \ + ${LLAMA_ARGS} \ + $ds_args \ + ${gpt_args[*]} \ + $custom_args \ + |& tee ${OUTPUT_LOG} + " + + +echo "All DeepSpeed(s): $(which -a deepspeed)" +echo "Using $(which deepspeed)" +ds_report + +echo "${run_cmd}" + +printf "[!! \e[1;31m%s\e[0m] View output at:\n" "NOTE" +printf "\e[1;34m%s\e[0m\n" "${OUTPUT_LOG}" +# echo "${OUTPUT_LOG}" +eval "${run_cmd}" +set +x From ded92e35da1dd18915a9374d7d6ffbe31469fead Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 16:27:06 +0000 Subject: [PATCH 08/50] added modified load model weights func --- megatron/training.py | 50 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/megatron/training.py b/megatron/training.py index 7e6c7dc6bb..7d123194fb 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -488,6 +488,7 @@ def load_model_weights_only(model_provider_func): lr_scheduler = None if args.deepspeed: + print(f'On args.deepspeed branch in load_model_weights_only..') # When loading just the model weights, ZeRO can be disabled. 
if 'zero_optimization' in args.deepspeed_config_dict: del args.deepspeed_config_dict['zero_optimization'] @@ -504,6 +505,46 @@ def load_model_weights_only(model_provider_func): print_datetime('before load checkpoint') if args.load is not None: + print(f'On args.load is not None branch in load_model_weights_only..') + iteration = load_checkpoint(model, optimizer, lr_scheduler, strict=True, load_only_weights=True) + + print_datetime('after load checkpoint weights') + + return model, optimizer, lr_scheduler + +def load_model_weights_only_modified(model_provider_func): + """Setup model and optimizer.""" + args = get_args() + print_rank_0('***>>>>> Args:{}'.format(args)) + + model = get_model(model_provider_func) + + optimizer = None + lr_scheduler = None + + if args.deepspeed: + print(f'On args.deepspeed branch in load_model_weights_only..') + # When loading just the model weights, ZeRO can be disabled. + if 'zero_optimization' in args.deepspeed_config_dict: + del args.deepspeed_config_dict['zero_optimization'] + + model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model[0], + # optimizer=optimizer, + args=args, + # lr_scheduler=lr_scheduler, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + + assert not isinstance(model, deepspeed.PipelineEngine), \ + 'Weight loading only mode is not supported in pipeline parallelism yet.' + + model = [model] + + print_datetime('before load checkpoint') + if args.load is not None: + print(f'On args.load is not None branch in load_model_weights_only..') iteration = load_checkpoint(model, optimizer, lr_scheduler, strict=True, load_only_weights=True) print_datetime('after load checkpoint weights') @@ -526,7 +567,9 @@ def setup_model_and_optimizer(model_provider_func, # initialize the compression here student_global_steps = 0 + print(f'Setting up model and optimizer..') if args.kd or args.mos: + print(f'On args.kd or args.mos branch..') model, _, _, _ = deepspeed.initialize( model=model[0], args=args, @@ -542,6 +585,7 @@ def setup_model_and_optimizer(model_provider_func, print_rank_0('***>>>>> Student model, global step:{}'.format(student_global_steps)) if args.compression_training: + print(f'On args.compression_training branch..') model, _, _, _ = deepspeed.initialize( model=model[0], args=args, @@ -555,6 +599,7 @@ def setup_model_and_optimizer(model_provider_func, (torchDDP, LocalDDP, Float16Module)) if args.inference: + print(f'On args.inference branch..') optimizer = None opt_param_scheduler = None else: @@ -605,6 +650,7 @@ def setup_model_and_optimizer(model_provider_func, ) model.set_data_post_process_func(data_post_process) else: + print(f'On deepspeed initialize branch without curriculum learning..') model, optimizer, _, opt_param_scheduler = deepspeed.initialize( model=model[0], optimizer=optimizer, @@ -625,13 +671,16 @@ def setup_model_and_optimizer(model_provider_func, # Compression has its own checkpoint loading path (e.g, loading both teacher and student models). So if compression is enabled, we skip the following checkpoint loading. 
no_post_init_checkpoint_loading = args.kd or args.mos if not no_post_init_checkpoint_loading: + print(f'On not no_post_init_checkpoint_loading branch..') if args.load is not None: + print(f'On not no_post_init_checkpoint_loading branch and args.load is not None..') timers = get_timers() timers('load-checkpoint', log_level=0).start(barrier=True) args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) timers('load-checkpoint').stop(barrier=True) timers.log(['load-checkpoint']) else: + print(f'On not no_post_init_checkpoint_loading branch and args.load is None..') args.iteration = 0 else: model[0].global_steps = student_global_steps @@ -650,6 +699,7 @@ def setup_model_and_optimizer(model_provider_func, # random-LTD requires converting transformer layers if args.random_ltd: + print(f'On args.random_ltd branch..') model[0] = convert_to_random_ltd(model[0], ParallelTransformerLayer) return model, optimizer, opt_param_scheduler From 96a1d262d3f7319f5ddf6c191873a46312ce3ade Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 21 Mar 2024 17:05:24 +0000 Subject: [PATCH 09/50] component wise loss func --- pretrain_gpt_modified.py | 111 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index 5055c90cd4..43f139700e 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -340,6 +340,47 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): loss = loss + moe_loss return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} +def dpo_loss_func(loss_mask, dpo_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. 
+ averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + # else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + # loss = loss + moe_loss + # return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + loss = dpo_loss + return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss} + def calculate_mos_loss( args, @@ -397,6 +438,76 @@ def calculate_mos_loss( mos_loss = mos_loss.div(args.seq_length) * beta return mos_loss +def calculate_dpo_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + kd_temp = 1.0 + beta = 0.1 + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + ref_output, ref_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == ref_output.size(), ( + 'ref and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Reference: {ref_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # Labels ? + logprobs = torch.gather(student_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. 
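# Illustrative sketch: the computation below only folds in one half of the preference
# pair (see the ToDo comments); the full pairwise DPO objective it builds toward,
# written self-contained with hypothetical per-sequence log-prob inputs, looks
# roughly like this:
import torch
import torch.nn.functional as F

def pairwise_dpo_loss(logp_p, logp_u, ref_logp_p, ref_logp_u, beta=0.1):
    ratio_p = logp_p - ref_logp_p   # policy vs. reference, preferred responses
    ratio_u = logp_u - ref_logp_u   # policy vs. reference, unpreferred responses
    losses = -F.logsigmoid(beta * (ratio_p - ratio_u))
    # implicit reward accuracy, detached quantities used for logging only
    rewards = (beta * ratio_p > beta * ratio_u).float().mean()
    return losses.mean(), rewards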
+ ref_logits = F.softmax(ref_output / kd_temp, dim=2) + ref_logprobs = torch.gather(ref_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # Partial DPO loss (from preferred/unpreferred) + logprob_ratio = logprobs - ref_logprobs + #------------ [ToDo]------------- + # # Get ratios of unpreferred log probabilities from model and ref model + # unpreferred_logprob_ratio = unpreferred_logprobs - ref_unpreferred_logprobs + + # Difference of logprobs ratios scaled by beta + # scaled_diff_logprob_ratios = self.beta * (preferred_logprob_ratio - unpreferred_logprob_ratio) + #------------ [ToDo]------------- + scaled_diff_logprob_ratios = beta * (logprob_ratio) + + # Losses computed as negative logsigmoid of scaled difference + dpo_loss = -F.logsigmoid(scaled_diff_logprob_ratios) + + return dpo_loss + def forward_step(data_iterator, model): """Forward step.""" From fa587cf32bb56b6f9dbd646924e836eaf8f73191 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 10:18:42 +0000 Subject: [PATCH 10/50] added arguments and datapaths for pref unpref datasets --- ALCF/helpers.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index 00843d0018..b2b6aeb33e 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -258,11 +258,16 @@ setData() { # ---- [dfl: abbrv. for DATA_FILE_LIST] ------------------------- elif [[ $(hostname) == x1* ]]; then dfl_fallback="/gila/Aurora_deployment/AuroraGPT/datasets/dolma/data_file_list_reweighted.txt" elif [[ $(hostname) == x3* ]]; then - dfl_fallback="/eagle/datasets/dolma/data_file_list_reweighted.txt" + # dfl_fallback="/eagle/datasets/dolma/data_file_list_reweighted.txt" + dfl_fallback="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list.txt" + dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_u.txt" + dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_p.txt" else echo "Unknown hostname. Must manually specify DATA_FILE_LIST." fi dfl="${1:-${dfl_fallback}}" + dflu="${1:-${dfl_fallback_u}}" + dflp="${1:-${dfl_fallback_p}}" # dfl_fallback="/eagle/datasets/dolma/data_file_list_reweighted.txt" printf "Calling: \`setData()\` with %s\n" "${dfl}" ndocs=$(wc -l < "${dfl}") @@ -271,6 +276,8 @@ setData() { # ---- [dfl: abbrv. 
for DATA_FILE_LIST] ------------------------- dcp="${HERE}/.cache/${dfl_stem}/index-cache" mkdir -p dcp export DATA_FILE_LIST="${dfl}" + export DATA_FILE_LIST_U="${dflu}" + export DATA_FILE_LIST_P="${dflp}" export NUM_DOCS="${ndocs}" export WEIGHT_SUM="${ws}" export DFL_STEM="${dfl_stem}" From 3c5b89a5043cc98493480e212bc52327be0535cf Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 10:19:37 +0000 Subject: [PATCH 11/50] added arguments for pref unpref datapaths --- megatron/arguments.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/megatron/arguments.py b/megatron/arguments.py index d83fe99856..270a886596 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1214,6 +1214,10 @@ def _add_data_args(parser): '--*-data-path args') group.add_argument('--data-file-list', type=str, default=None, help='The file with the list of dataset and weights') + group.add_argument('--data-file-list-u', type=str, default=None, + help='The file with the list of unpreferred dataset and weights') + group.add_argument('--data-file-list-p', type=str, default=None, + help='The file with the list of preferred dataset and weights') group.add_argument('--split', type=str, default='969, 30, 1', help='Comma-separated list of proportions for training,' From 4ef1df7f7806469b7c287ead0168afb6420dbc60 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 10:22:27 +0000 Subject: [PATCH 12/50] pref unpref dataloaders and iterators --- pretrain_gpt_modified.py | 106 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index 43f139700e..cdf85b16df 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -51,9 +51,10 @@ from megatron.training import setup_model_and_optimizer from megatron.training import load_model_weights_only, get_model from megatron.training import load_model_weights_only_modified -from megatron.training import get_optimizer_param_scheduler +from megatron.training import get_optimizer_param_scheduler, cyclic_iter from megatron.optimizer import get_megatron_optimizer from megatron.checkpointing import load_checkpoint +from megatron.data.data_samplers import build_pretraining_data_loader # RANK = setup_torch( # backend='deepspeed', @@ -715,6 +716,8 @@ def main(): # model = model_provider() model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + + # ---------- Reference model ------------- # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? 
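# One concrete way to resolve the question above, sketched after the pattern of
# load_model_weights_only_modified (the optimizer-free engine wrap and the None
# optimizer/scheduler are assumptions, mirroring the inference path in training.py):
model_ref = get_model(model_provider, ModelType.encoder_or_decoder)   # randomly initialized
ref_engine, _, _, _ = deepspeed.initialize(
    model=model_ref[0],
    args=args,
    mpu=mpu if args.no_pipeline_parallel else None,
    config=args.deepspeed_config_dict,
)
model_ref = [ref_engine]
_ = load_checkpoint(model_ref, None, None, load_only_weights=True)    # pull in checkpoint weights only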
# TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) @@ -743,6 +746,107 @@ def main(): # THINGS THAT DID NOT WORK FOR LOADING FROM CHECKPOINT # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only(model_provider) # DID NOT WORK - train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 32 != 8 * 1 * 8 # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only_modified(model_provider) # DID NOT WORK - optimizer = FusedAdam(TypeError: FusedAdam.__init__() got an unexpected keyword argument 'beta1' + # ---------------------------------------- + + if args.data_file_list_u is not None: + print(f'data files list unpreferred: {args.data_file_list_u}') + + # Number of train/valid/test samples. + if args.train_samples: + print(f'args.train_samples: {args.train_samples}') + train_samples = args.train_samples + else: + print(f'args.train_iters: {args.train_iters}') + print(f'args.global_batch_size: {args.global_batch_size}') + train_samples = args.train_iters * args.global_batch_size + + print(f'args.eval_interval: {args.eval_interval}') + print(f'args.eval_iters: {args.eval_iters}') + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print(f'train_val_test_num_samples: {train_val_test_num_samples}') + # print(f'args.data_impl: {args.data_impl}') + # print(f'args.split: {args.split}') + # print(f'args.seq_length: {args.seq_length}') + # print(f'args.seed: {args.seed}') + # print(f'args.train_data_path: {args.train_data_path}') + # print(f'args.valid_data_path: {args.valid_data_path}') + # print(f'args.test_data_path: {args.test_data_path}') + # print(f'args.data_cache_path: {args.data_cache_path}') + + files_u = [] + with open(args.data_file_list_u, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_u.append(float(w)) + files_u.append(fname) + train_ds_u, valid_ds_u, test_ds_u = build_train_valid_test_datasets( + data_prefix=files_u, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating unpreferred GPT datasets ...") + + if args.data_file_list_p is not None: + print(f'data files list preferred: {args.data_file_list_p}') + + files_p = [] + with open(args.data_file_list_p, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_p.append(float(w)) + files_p.append(fname) + train_ds_p, valid_ds_p, test_ds_p = build_train_valid_test_datasets( + data_prefix=files_p, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating preferred GPT datasets ...") + + # Data iterator + print(f'args.consumed_train_samples: {args.consumed_train_samples}') + 
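# The split() parsing above implies a plain two-column "weight path" format for
# --data-file-list-u / --data-file-list-p. A hypothetical list file (placeholder
# paths, Megatron-style *_text_document prefixes) would contain lines like:
#
#   1.0 /path/to/ultrafeedback/unpreferred_chunk00_text_document
#   1.0 /path/to/ultrafeedback/unpreferred_chunk01_text_document
#
# which the loop above turns into the alternating [weight, prefix, weight, prefix, ...]
# list that build_train_valid_test_datasets expects.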
print(f'args.dataloader_type: {args.dataloader_type}') + train_dataloader_u = build_pretraining_data_loader( + train_ds_u, args.consumed_train_samples) + + # Build train iterators + dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader_u is not None: + print(f'train_dataloader_u is not None..') + train_data_iterator_u = iter(train_dataloader_u) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader_u)) + print_rank_0("> finished creating train_data_iterator_u ...") + + # Get batch + timers = get_timers() + timers('batch-generator-u', log_level=2).start() + tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( + train_data_iterator_u) + timers('batch-generator-u').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for train_data_iterator_u ...") return model From d89121e47efbd5f763c7f53edfeda0932135d092 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 10:23:24 +0000 Subject: [PATCH 13/50] testing model forward and loss compute --- pretrain_gpt_modified.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index cdf85b16df..0b900f7537 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -848,6 +848,23 @@ def main(): # print(f'tokens shape: {tokens_u.shape}') print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for train_data_iterator_u ...") + # Model forward + # output_tensor, other_losses = model[0]( + # tokens_u, + # position_ids_u, + # attention_mask_u, + # labels=labels_u + # ) # OUT OF MEMORY ERROR even with 4 nodes + + stu_output, other_losses = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes + print_rank_0("> finished a forward pass to get logits ...") + + output_tensor = tensor_parallel.vocab_parallel_cross_entropy( + stu_output.contiguous().float(), + labels_u + ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR + print(f'Computed output_tensor: {output_tensor}') + return model # def main(): From e550a4336a14048871fe596a633d934ae060aa29 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 10:25:01 +0000 Subject: [PATCH 14/50] additional args for pref unpref data files --- train_llama_polaris_modified.sh | 2 ++ 1 file changed, 2 insertions(+) mode change 100644 => 100755 train_llama_polaris_modified.sh diff --git a/train_llama_polaris_modified.sh b/train_llama_polaris_modified.sh old mode 100644 new mode 100755 index 723fd06083..d9f9864189 --- a/train_llama_polaris_modified.sh +++ b/train_llama_polaris_modified.sh @@ -82,6 +82,8 @@ run_cmd=" --max-position-embeddings ${SEQ} \ --micro-batch-size ${MICRO_BATCH} \ --data-file-list ${DATA_FILE_LIST} \ + --data-file-list-u ${DATA_FILE_LIST_U} \ + --data-file-list-p ${DATA_FILE_LIST_P} \ --tensor-model-parallel-size ${TP} \ --global-batch-size ${GLOBAL_BATCH} \ --pipeline-model-parallel-size ${PP} \ From a40fb1480e8ea355995198bbffe4d489bba1e41c Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 23:19:44 +0000 Subject: [PATCH 15/50] returning tensor parallel logps --- megatron/core/tensor_parallel/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py index 9dcdc0459f..1a2de5044e 100644 --- a/megatron/core/tensor_parallel/cross_entropy.py +++ 
b/megatron/core/tensor_parallel/cross_entropy.py @@ -86,12 +86,13 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): mean_log_probs = log_probs.mean(dim=-1) loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + log_probs = torch.log(exp_logits) ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size # Store softmax, target-mask and masked-target for backward pass. ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - return loss + return loss, log_probs @staticmethod def backward(ctx, grad_output): From e1f039626faad6120f7d908e249e2d3873fe1ab4 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 23:34:42 +0000 Subject: [PATCH 16/50] tensor parallel logps for pref and unpref batches --- pretrain_gpt_modified.py | 52 ++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index 0b900f7537..bebd7232f5 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -823,30 +823,44 @@ def main(): data_cache_path=args.data_cache_path) print_rank_0("> finished creating preferred GPT datasets ...") - # Data iterator + # Data loaders print(f'args.consumed_train_samples: {args.consumed_train_samples}') print(f'args.dataloader_type: {args.dataloader_type}') train_dataloader_u = build_pretraining_data_loader( train_ds_u, args.consumed_train_samples) + train_dataloader_p = build_pretraining_data_loader( + train_ds_p, args.consumed_train_samples) # Build train iterators dl_type = args.dataloader_type assert dl_type in ['single', 'cyclic'] if train_dataloader_u is not None: - print(f'train_dataloader_u is not None..') + print(f'unpreferred train_dataloader is not None..') train_data_iterator_u = iter(train_dataloader_u) if dl_type == 'single' \ else iter(cyclic_iter(train_dataloader_u)) - print_rank_0("> finished creating train_data_iterator_u ...") + print_rank_0("> finished creating unpreferred train_data_iterator...") + if train_dataloader_p is not None: + print(f'preferred train_dataloader is not None..') + train_data_iterator_p = iter(train_dataloader_p) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader_p)) + print_rank_0("> finished creating preferred train_data_iterator...") # Get batch timers = get_timers() - timers('batch-generator-u', log_level=2).start() + timers('batch-generator-unpreferred', log_level=2).start() tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( - train_data_iterator_u) - timers('batch-generator-u').stop() + train_data_iterator_u) + timers('batch-generator-unpreferred').stop() # print(f'tokens shape: {tokens_u.shape}') - print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for train_data_iterator_u ...") + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for unpref train_data_iterator ...") + + timers('batch-generator-preferred', log_level=2).start() + tokens_p, labels_p, loss_mask_p, attention_mask_p, position_ids_p = get_batch( + train_data_iterator_p) + timers('batch-generator-preferred').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. 
for pref train_data_iterator ...") # Model forward # output_tensor, other_losses = model[0]( @@ -856,14 +870,26 @@ def main(): # labels=labels_u # ) # OUT OF MEMORY ERROR even with 4 nodes - stu_output, other_losses = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes - print_rank_0("> finished a forward pass to get logits ...") + # Computing logits and logps for preferred and unpreferred data batches + output_u, other_losses_u = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model + print_rank_0("> finished a forward pass to get unpref logits ...") - output_tensor = tensor_parallel.vocab_parallel_cross_entropy( - stu_output.contiguous().float(), + output_tensor_u, logprobs_u = tensor_parallel.vocab_parallel_cross_entropy( + output_u.contiguous().float(), labels_u - ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR - print(f'Computed output_tensor: {output_tensor}') + ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + print(f'Computed unpreferred output_tensor: {output_tensor_u}') + print(f'Computed unpreferred logprobs: {logprobs_u}') + + output_p, other_losses_p = model[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model + print_rank_0("> finished a forward pass to get pref logits ...") + + output_tensor_p, logprobs_p = tensor_parallel.vocab_parallel_cross_entropy( + output_p.contiguous().float(), + labels_p + ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + print(f'Computed preferred output_tensor: {output_tensor_p}') + print(f'Computed preferred logprobs: {logprobs_p}') return model From c22d090fba417b466d23adc2c72a9b7868b46024 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 22 Mar 2024 23:42:11 +0000 Subject: [PATCH 17/50] computing pref and unpref logps from reference model --- pretrain_gpt_modified.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index bebd7232f5..8c157c69de 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -891,6 +891,29 @@ def main(): print(f'Computed preferred output_tensor: {output_tensor_p}') print(f'Computed preferred logprobs: {logprobs_p}') + # Reference model in inference mode + # Computing logits and logps for preferred and unpreferred data batches from ref model + with torch.no_grad(): + ref_output_u, _ = model_ref[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model + print_rank_0("> finished a forward pass to get unpref logits ...") + + ref_output_tensor_u, ref_logprobs_u = tensor_parallel.vocab_parallel_cross_entropy( + ref_output_u.contiguous().float(), + labels_u + ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + print(f'Computed unpreferred output_tensor: {ref_output_tensor_u}') + print(f'Computed unpreferred logprobs: {ref_logprobs_u}') + + ref_output_p, _ = model_ref[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model + print_rank_0("> finished a forward pass to get pref logits ...") + + ref_output_tensor_p, ref_logprobs_p = tensor_parallel.vocab_parallel_cross_entropy( + ref_output_p.contiguous().float(), + labels_p + ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + print(f'Computed preferred output_tensor: {ref_output_tensor_p}') + print(f'Computed preferred 
logprobs: {ref_logprobs_p}') + return model # def main(): From 5c3db719b385a9a22f8a14ca58951d99d93edae9 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Sat, 23 Mar 2024 08:42:38 +0000 Subject: [PATCH 18/50] dpo loss func restructured --- pretrain_gpt_modified.py | 44 ++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index 8c157c69de..cfd01e1816 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -443,6 +443,10 @@ def calculate_dpo_loss( args, stu_output, teacher_model, + logprobs_p, + logprobs_u, + ref_logprobs_p, + ref_logprobs_u, tokens, position_ids, attention_mask @@ -452,7 +456,7 @@ def calculate_dpo_loss( beta = args.kd_beta_ce kd_temp = args.kd_temp kd_temp = 1.0 - beta = 0.1 + beta = 0.1 # add to cmdline args if teacher_model: with torch.no_grad(): @@ -490,24 +494,34 @@ def calculate_dpo_loss( # If we use log_softmax, # then we need to set target_log to true # when initializing the KLDivLoss. - ref_logits = F.softmax(ref_output / kd_temp, dim=2) - ref_logprobs = torch.gather(ref_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) - # Partial DPO loss (from preferred/unpreferred) - logprob_ratio = logprobs - ref_logprobs - #------------ [ToDo]------------- - # # Get ratios of unpreferred log probabilities from model and ref model - # unpreferred_logprob_ratio = unpreferred_logprobs - ref_unpreferred_logprobs + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p - # Difference of logprobs ratios scaled by beta - # scaled_diff_logprob_ratios = self.beta * (preferred_logprob_ratio - unpreferred_logprob_ratio) - #------------ [ToDo]------------- - scaled_diff_logprob_ratios = beta * (logprob_ratio) + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u - # Losses computed as negative logsigmoid of scaled difference - dpo_loss = -F.logsigmoid(scaled_diff_logprob_ratios) + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) - return dpo_loss + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # preferred dpo rewards + pref_dpo_rewards = (beta * logprob_ratio_p).detach() + + # unpreferred dpo rewards + unpref_dpo_rewards = (beta * logprob_ratio_u).detach() + + # Implicit DPO rewards + implicit_dpo_rewards = (pref_dpo_rewards > unpref_dpo_rewards).float() + rewards = implicit_dpo_rewards.cpu().mean() + + # Compute mean loss + dpo_loss = losses.mean() + # print(f'Loss dtype: {loss.dtype}') + + return dpo_loss, rewards def forward_step(data_iterator, model): From 8915f2ab2e632904c299b48b0a038c50f99f8522 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Mon, 25 Mar 2024 08:45:48 +0000 Subject: [PATCH 19/50] tracing the forward backward func --- megatron/core/pipeline_parallel/schedules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 407bb16d56..486eaf78c5 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -102,6 +102,7 @@ def forward_step(data_iterator, model): else: forward_backward_func = forward_backward_pipelining_without_interleaving else: + print(f'On forward_backward_no_pipelining branch ..') forward_backward_func = 
forward_backward_no_pipelining return forward_backward_func From eec8697ef6cfbd2e04ba5864e915826d30cf4456 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Mon, 25 Mar 2024 08:46:45 +0000 Subject: [PATCH 20/50] train step func --- megatron/training.py | 90 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/megatron/training.py b/megatron/training.py index 7d123194fb..04adf5551c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -704,6 +704,95 @@ def setup_model_and_optimizer(model_provider_func, return model, optimizer, opt_param_scheduler +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +import contextlib +from megatron.core.utils import get_attr_wrapped_model, get_model_type, get_model_config +def train_step_dpo(data_iterator, model, + optimizer, opt_param_scheduler, config, + loss, + forward_only=False): + """Single training step.""" + args = get_args() + timers = get_timers() + + if args.deepspeed and args.ds_pipeline_enabled: + print(f'In train step if args.deepspeed and args.ds_pipeline_enabled..') + skipped_iter = 0 + num_zeros_in_grad = 0 + assert isinstance(model[0], deepspeed.PipelineEngine) + loss = model[0].train_batch(data_iter=data_iterator) + grad_norm = model[0].get_global_grad_norm() + return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad + + # Set grad to zero + if not args.deepspeed: + print(f'In train step and NOT args.deepspeed with optimi zero_grad..') + if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp: + for partition in model: + partition.zero_grad_buffer() + optimizer.zero_grad() + + # Forward backward pass + timers('forward-backward', log_level=1).start( + barrier=args.barrier_with_L1_time) + forward_backward_func = get_forward_backward_func() + + # losses_reduced = forward_backward_func( + # ) + if isinstance(model, list): + assert len(model) == 1, \ + "list of models .." 
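# For orientation: the remainder of this function, plus the optimizer update in the
# training loop of a later patch, follow the usual DeepSpeed engine pattern. A
# compressed sketch (names are illustrative; `increment` is the sample count
# computed in that training loop):
def dpo_engine_step(engine, loss, increment):
    engine.set_gradient_accumulation_boundary(False)  # accumulating micro-batches
    engine.backward(loss)                             # engine-managed backward (loss scaling, ZeRO hooks)
    engine.set_gradient_accumulation_boundary(True)   # last micro-batch of the global batch
    engine.step(lr_kwargs={'increment': increment})   # optimizer + LR-scheduler step
    return engine.was_step_applied()                  # False if the step was skipped (e.g. grad overflow)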
+ print(f'number of models in the list: {len(model)}') + model = model[0] + mconfig = get_model_config(model) + + no_sync_func = mconfig.no_sync_func + print(f'no_syn_func: {no_sync_func}') + if no_sync_func is None and isinstance(model, torchDDP): + print(f'On no_sync_func is None and isinstance(model, torchDDP) branch ..') + no_sync_func = model.no_sync + if no_sync_func is None: + print(f'On no_sync_func is None branch ..') + no_sync_func = contextlib.nullcontext + + if args.deepspeed: + print(f'setting gradient accumulation boundary to false ..') + model.set_gradient_accumulation_boundary(False) + + model_type = get_model_type(model) + + num_microbatches = get_num_microbatches() + print(f'num_microbatches: {num_microbatches}') + + forward_data_store = [] + input_tensor, output_tensor_grad = None, None + + # if False: + with no_sync_func(): + + if mconfig.enable_autocast: + print(f'mconfig.enable_autocast: {mconfig.enable_autocast}') + context_manager = torch.autocast("cuda", dtype=mconfig.autocast_dtype) + else: + context_manager = contextlib.nullcontext() + + with context_manager: + print(f'context_manager: {context_manager}') + output_tensor = loss + # for i in range(num_microbatches - 1): + # output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches, + # input_tensor, forward_data_store, config, collect_non_loss_data) + if not forward_only: + # backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config, model) + # Backward pass + if args.deepspeed: + model.backward(output_tensor) + + if args.deepspeed: + model.set_gradient_accumulation_boundary(True) + + return None + def train_step(forward_step_func, data_iterator, @@ -713,6 +802,7 @@ def train_step(forward_step_func, data_iterator, timers = get_timers() if args.deepspeed and args.ds_pipeline_enabled: + print(f'In train step if args.deepspeed and args.ds_pipeline_enabled..') skipped_iter = 0 num_zeros_in_grad = 0 assert isinstance(model[0], deepspeed.PipelineEngine) From 09aa61b065a3d35b22e5073269dada95f28ef761 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Mon, 25 Mar 2024 08:47:52 +0000 Subject: [PATCH 21/50] training iteration --- pretrain_gpt_modified.py | 155 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 4 deletions(-) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index cfd01e1816..5b425743e4 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -52,9 +52,15 @@ from megatron.training import load_model_weights_only, get_model from megatron.training import load_model_weights_only_modified from megatron.training import get_optimizer_param_scheduler, cyclic_iter +from megatron.training import train, train_step +from megatron.training import train_step_dpo from megatron.optimizer import get_megatron_optimizer from megatron.checkpointing import load_checkpoint from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.arguments import core_transformer_config_from_args +from megatron import update_num_microbatches +from megatron import get_num_microbatches # RANK = setup_torch( # backend='deepspeed', @@ -382,6 +388,31 @@ def dpo_loss_func(loss_mask, dpo_loss, output_tensor): loss = dpo_loss return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss} +def batch_seq_logprobs(logits, labels): + """ Function to compute a batch of sequence log probabilities """ + + logits = logits[:, :-1, :] # skip last logit + 
logits_logsoftmax = logits.log_softmax(-1) # compute log softmax of logits + + labels = labels[:, 1:].clone() # clone labels + + # # Loss mask to avoid padded tokens while computing loss + # loss_mask = labels != tokenizer.pad_token_id + + # print(f'Labels shape: {labels.shape}') + # print(f'loss_mask shape: {loss_mask.shape}') + # print(f'loss_mask dtype: {loss_mask.dtype}') + + # Gather logps and squeeze last dimension + logprobs = torch.gather(logits_logsoftmax, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # print(f'seq_logprobs shape: {logprobs.shape}') + + # Weighted sum over logprobs using loss mask + # seq_logprobs = (logprobs * loss_mask).sum(-1) + seq_logprobs = logprobs.sum(-1) + + return seq_logprobs + def calculate_mos_loss( args, @@ -523,6 +554,28 @@ def calculate_dpo_loss( return dpo_loss, rewards +def compute_dp_loss(logprobs_p, ref_logprobs_p, + logprobs_u, ref_logprobs_u, + beta=0.1): + + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p + + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u + + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) + + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # Compute mean loss + dp_loss = losses.mean() + + return dp_loss + + def forward_step(data_iterator, model): """Forward step.""" @@ -892,7 +945,8 @@ def main(): output_u.contiguous().float(), labels_u ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - print(f'Computed unpreferred output_tensor: {output_tensor_u}') + # logprobs_u = batch_seq_logprobs(output_u, labels_u) + # print(f'Computed unpreferred output_tensor: {output_tensor_u}') print(f'Computed unpreferred logprobs: {logprobs_u}') output_p, other_losses_p = model[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model @@ -902,7 +956,8 @@ def main(): output_p.contiguous().float(), labels_p ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - print(f'Computed preferred output_tensor: {output_tensor_p}') + # logprobs_p = batch_seq_logprobs(output_p, labels_p) + # print(f'Computed preferred output_tensor: {output_tensor_p}') print(f'Computed preferred logprobs: {logprobs_p}') # Reference model in inference mode @@ -915,7 +970,8 @@ def main(): ref_output_u.contiguous().float(), labels_u ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - print(f'Computed unpreferred output_tensor: {ref_output_tensor_u}') + # ref_logprobs_u = batch_seq_logprobs(ref_output_u, labels_u) + # print(f'Computed unpreferred output_tensor: {ref_output_tensor_u}') print(f'Computed unpreferred logprobs: {ref_logprobs_u}') ref_output_p, _ = model_ref[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model @@ -925,9 +981,100 @@ def main(): ref_output_p.contiguous().float(), labels_p ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - print(f'Computed preferred output_tensor: {ref_output_tensor_p}') + # ref_logprobs_p = batch_seq_logprobs(ref_output_p, labels_p) + # print(f'Computed preferred output_tensor: {ref_output_tensor_p}') print(f'Computed preferred logprobs: {ref_logprobs_p}') + # Compute loss + loss = 
compute_dp_loss(logprobs_p, ref_logprobs_p, + logprobs_u, ref_logprobs_u, + beta=0.1) + print(f'Computed loss: {loss}') + + print(f'args.ds_pipeline_enabled: {args.ds_pipeline_enabled}') + if args.deepspeed and args.ds_pipeline_enabled: + print(f'In train step if args.deepspeed and args.ds_pipeline_enabled..') + skipped_iter = 0 + num_zeros_in_grad = 0 + assert isinstance(model[0], deepspeed.PipelineEngine) + loss = model[0].train_batch(data_iter=train_data_iterator_p) + grad_norm = model[0].get_global_grad_norm() + # return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad + + if not args.skip_train: + print_rank_0('training ...') + + if args.dataloader_type == 'cyclic' and args.retro_add_retriever: + args.train_iters = args.retro_cyclic_train_iters + print_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.train_iters > 0: + print(f'In train step if args.train_iters: {args.train_iters} ..') + + # Turn on training mode which enables dropout. + for model_module in model: + model_module.train() + + # Tracking loss. + total_loss_dict = {} + + # Iterations. + iteration = args.iteration + + # Translate args to core configuration + config = core_transformer_config_from_args(args) + + config.timers = timers + + timers('interval-time', log_level=0).start(barrier=True) + print_datetime('before the start of training step') + report_memory_flag = True + + while iteration < args.train_iters and (args.train_tokens is None or \ + args.consumed_train_tokens < args.train_tokens): + print(f'args.train_tokens: {args.train_tokens}, args.consumed_train_tokens: {args.consumed_train_tokens}') + print(f'args.consumed_train_samples: {args.consumed_train_samples}') + update_num_microbatches(args.consumed_train_samples) + + if args.deepspeed: + # inform deepspeed of any batch size changes + global_batch_size = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + print(f'global batch size: {global_batch_size}') + model[0].set_train_batch_size(global_batch_size) + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + print(f'args.curriculum_learning_legacy and not args.no_pipeline_parallel is {args.curriculum_learning_legacy} and {args.no_pipeline_parallel} ..') + curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(curriculum_seqlen) + args.curriculum_seqlen = curriculum_seqlen + + args.curr_iteration = iteration + print(f'iteration: {iteration}') + + # model[0].backward(loss) + results = train_step_dpo(train_data_iterator_p, + model, + optimizer, + opt_param_scheduler, + config, + loss, + forward_only=False) # FAILS HERE with the model backward step: TypeError: _VocabParallelCrossEntropy.backward() takes 2 positional arguments but 3 were given + + # Update parameters + if args.deepspeed: + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + model[0].step(lr_kwargs={'increment': increment}) + update_successful = model[0].was_step_applied() + + return model # def main(): From 6402e6247017146602742f825888474a8bb8c74f Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 09:56:54 +0000 Subject: [PATCH 22/50] restored return in cross entropy --- megatron/core/tensor_parallel/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py index 1a2de5044e..cd86583013 100644 --- a/megatron/core/tensor_parallel/cross_entropy.py +++ b/megatron/core/tensor_parallel/cross_entropy.py @@ -92,7 +92,8 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): # Store softmax, target-mask and masked-target for backward pass. ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - return loss, log_probs + return loss + # return loss, log_probs @staticmethod def backward(ctx, grad_output): From 1a04e84c3d17aeb0f8f4b594b136c37bfd92604d Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 09:57:56 +0000 Subject: [PATCH 23/50] post processing cross entropy return --- pretrain_gpt_modified.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pretrain_gpt_modified.py b/pretrain_gpt_modified.py index 5b425743e4..73c4fbbe77 100644 --- a/pretrain_gpt_modified.py +++ b/pretrain_gpt_modified.py @@ -391,10 +391,10 @@ def dpo_loss_func(loss_mask, dpo_loss, output_tensor): def batch_seq_logprobs(logits, labels): """ Function to compute a batch of sequence log probabilities """ - logits = logits[:, :-1, :] # skip last logit + logits = logits[:-1, :, :] # skip last logit logits_logsoftmax = logits.log_softmax(-1) # compute log softmax of logits - labels = labels[:, 1:].clone() # clone labels + labels = labels[1:, :].clone() # clone labels # # Loss mask to avoid padded tokens while computing loss # loss_mask = labels != tokenizer.pad_token_id @@ -941,10 +941,11 @@ def main(): output_u, other_losses_u = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model print_rank_0("> finished a forward pass to get unpref logits ...") - output_tensor_u, logprobs_u = tensor_parallel.vocab_parallel_cross_entropy( + output_tensor_u = tensor_parallel.vocab_parallel_cross_entropy( output_u.contiguous().float(), labels_u ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + logprobs_u = torch.exp(output_tensor_u) # logprobs_u = batch_seq_logprobs(output_u, labels_u) # print(f'Computed unpreferred output_tensor: {output_tensor_u}') print(f'Computed unpreferred logprobs: {logprobs_u}') @@ -952,12 +953,13 @@ def main(): output_p, other_losses_p = model[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model print_rank_0("> finished a forward pass to get pref logits ...") - output_tensor_p, logprobs_p = tensor_parallel.vocab_parallel_cross_entropy( + output_tensor_p = tensor_parallel.vocab_parallel_cross_entropy( output_p.contiguous().float(), labels_p ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + logprobs_p = torch.exp(output_tensor_p) # logprobs_p = batch_seq_logprobs(output_p, labels_p) - # print(f'Computed preferred output_tensor: {output_tensor_p}') + print(f'Computed preferred output_tensor: {output_tensor_p}') print(f'Computed preferred logprobs: {logprobs_p}') # Reference model in inference mode @@ -966,10 +968,11 @@ def main(): ref_output_u, _ = model_ref[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model print_rank_0("> finished a forward pass to get unpref logits ...") - ref_output_tensor_u, ref_logprobs_u = tensor_parallel.vocab_parallel_cross_entropy( + ref_output_tensor_u = tensor_parallel.vocab_parallel_cross_entropy( ref_output_u.contiguous().float(), labels_u ) # BUT THIS DID NOT 
WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + ref_logprobs_u = torch.exp(ref_output_tensor_u) # ref_logprobs_u = batch_seq_logprobs(ref_output_u, labels_u) # print(f'Computed unpreferred output_tensor: {ref_output_tensor_u}') print(f'Computed unpreferred logprobs: {ref_logprobs_u}') @@ -977,10 +980,11 @@ def main(): ref_output_p, _ = model_ref[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model print_rank_0("> finished a forward pass to get pref logits ...") - ref_output_tensor_p, ref_logprobs_p = tensor_parallel.vocab_parallel_cross_entropy( + ref_output_tensor_p = tensor_parallel.vocab_parallel_cross_entropy( ref_output_p.contiguous().float(), labels_p ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + ref_logprobs_p = torch.exp(ref_output_tensor_p) # ref_logprobs_p = batch_seq_logprobs(ref_output_p, labels_p) # print(f'Computed preferred output_tensor: {ref_output_tensor_p}') print(f'Computed preferred logprobs: {ref_logprobs_p}') From f956f3070bc53ba07b2fecd9284f1f480447df7f Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 09:59:00 +0000 Subject: [PATCH 24/50] calls a different script --- train_llama_polaris_modified.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_llama_polaris_modified.sh b/train_llama_polaris_modified.sh index d9f9864189..fe3902b586 100755 --- a/train_llama_polaris_modified.sh +++ b/train_llama_polaris_modified.sh @@ -23,7 +23,7 @@ HERE=$(python3 -c 'import os; print(os.getcwd())') export HERE # ---- 1. Assert `./pretrain_gpt_alcf.py` exists: ----------------------------- # export EXEC="${HERE}/pretrain_gpt_alcf.py" -export EXEC="${HERE}/pretrain_gpt_modified.py" +export EXEC="${HERE}/dpo_training.py" [ -f "${EXEC}" ] || exit # ---- 2. `source ./ALCF/helpers_alcf.sh`: ------------------------------------ sourceFile "${HERE}/ALCF/helpers.sh" || exit From 66ae84413ed3b37ae875988c8c9eea4ee16462d5 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 09:59:42 +0000 Subject: [PATCH 25/50] working backprop --- dpo_training.py | 1109 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1109 insertions(+) create mode 100644 dpo_training.py diff --git a/dpo_training.py b/dpo_training.py new file mode 100644 index 0000000000..025b30f542 --- /dev/null +++ b/dpo_training.py @@ -0,0 +1,1109 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain GPT""" + +import os +from rich import print +import torch +import math + +# The earliest we can measure the start time. 
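# One note on the exp() post-processing in the previous patch: vocab_parallel_cross_entropy
# returns the per-token negative log-likelihood, so per-sequence log-probabilities are
# normally recovered as a negated, masked sum rather than an exponential. A sketch
# (Megatron often lays the loss out as [seq, batch], so the summed dim may need adjusting):
per_token_nll = tensor_parallel.vocab_parallel_cross_entropy(output_u.contiguous().float(), labels_u)
seq_logprobs_u = -(per_token_nll * loss_mask_u).sum(dim=-1)   # one log-probability per sequence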
+import time +from datetime import datetime + +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import ( + report_memory, + throughput_calculator, + checkpoint_throughput_calculator +) +from pathlib import Path + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator +import subprocess +import wandb + +import time +from torch import nn +import torch.nn.functional as F + +# from ezpz import get_logger +from ezpz.dist import get_world_size, setup_wandb, get_rank + +# More imports +from megatron.initialize import initialize_megatron +from megatron.initialize import set_jit_fusion_options +from megatron.training import print_datetime, _create_ds_config_dict +from megatron.training import setup_model_and_optimizer +from megatron.training import load_model_weights_only, get_model +from megatron.training import load_model_weights_only_modified +from megatron.training import get_optimizer_param_scheduler, cyclic_iter +from megatron.training import train, train_step +from megatron.training import train_step_dpo +from megatron.optimizer import get_megatron_optimizer +from megatron.checkpointing import load_checkpoint +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.arguments import core_transformer_config_from_args +from megatron import update_num_microbatches +from megatron import get_num_microbatches + +# RANK = setup_torch( +# backend='deepspeed', +# port='5432', +# ) +RANK = get_rank() +WORLD_SIZE = get_world_size() +LEVEL = "DEBUG" if RANK == 0 else "CRITICAL" + +WANDB_MODE = os.environ.get('WANDB_MODE', None) +DISABLE_WANDB = ( + WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' +) + +if RANK == 0 and not DISABLE_WANDB: + project_name = ( + os.environ.get( + 'WB_PROJECT', + os.environ.get( + 'WANDB_PROJECT', + 'AuroraGPT' + ), + ) + ) + print('--------------------------------------------------') + print(f"Setting up W&B from: {RANK} with {project_name}") + print('--------------------------------------------------') + setup_wandb(project_name=project_name) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + print_rank_0('building GPT model ...') + see_memory_usage("Before Building Model", force=True) + args = get_args() + config = core_transformer_config_from_args(args) + if wandb.run is not None: + print(f"Updating WandB run: [{wandb.run.name}]({wandb.run.url})") + wandb.run.config.update({"args": vars(args)}, allow_val_change=True) + if RANK == 0: + git_ds_info() + if hasattr(mpu, 'get_sequence_parallel_group'): + dpg = mpu.get_sequence_parallel_group() + elif hasattr(mpu, 'get_data_parallel_group'): + dpg = mpu.get_data_parallel_group() + else: + dpg = None + if wandb is not None and wandb.run is not None: + assert wandb is not None and wandb.run is not 
None + print(f'Updating {wandb.run.name=} at {wandb.run.url=}') + wandb.run.config.update({'args': vars(args)}, allow_val_change=True) + with deepspeed.zero.Init( + data_parallel_group=dpg, + remote_device=( + None if args.remote_device == 'none' else args.remote_device + ), + config_dict_or_path=args.deepspeed_config_dict, + enabled=args.zero_stage == 3, + mpu=mpu + ): + if args.deepspeed and not args.no_pipeline_parallel: + model = GPTModelPipe( + config=config, + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to + # get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. + # This avoids having to pipeline it + # as an activation during training. + # The mask is constant, and thus we can reuse it. + attention_mask = torch.tril( + torch.ones( + (1, args.seq_length, args.seq_length), + device=get_accelerator().current_device_name() + ) + ).view(1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + # For prertaining, since sequence length is fixed, + # cache rotary embedding in args, to avoid communicating around + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.seq_length) + + else: + print(f'Building model check..') + model = GPTModel( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + # print_rank_0('\n ------------------------ ') + # print_rank_0(f'num of parameters {num_params}') + # print_rank_0('------------------------\n ') + print_rank_0(80 * '-') + print_rank_0(f"Number of parameters in model: {num_params}") + print_rank_0(80 * '-') + see_memory_usage("After Building Model", force=True) + if wandb.run is not None: + wandb.run.config.update({'num_params': num_params}, allow_val_change=True) + # wandb.run.watch( + # model, + # log='all', + # log_graph=True, + # ) + # wandb.run.config.update({'num_params': num_params}) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + # print(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + data = next(data_iterator) if data_iterator is not None else None + # # Broadcast data. + # if data_iterator is not None: + # data = next(data_iterator) + # else: + # data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
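# A tiny concrete illustration of the one-token shift just above, with made-up ids:
import torch
tokens_ = torch.tensor([[7, 8, 9, 10, 11]])
tokens = tokens_[:, :-1].contiguous()   # [[ 7,  8,  9, 10]] -> model input
labels = tokens_[:, 1:].contiguous()    # [[ 8,  9, 10, 11]] -> next-token targets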
+ skip_mask = args.use_flash_attn or args.use_flash_attn_triton + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + skip_mask) + + # For DS's sequence parallel + seq_parallel_world_size = mpu.get_sequence_parallel_world_size() + seq_parallel_world_rank = mpu.get_sequence_parallel_rank() + + # For Megatron's sequence parallel + if args.sequence_parallel: + seq_parallel_world_size = mpu.get_tensor_model_parallel_world_size() + seq_parallel_world_rank = mpu.get_tensor_model_parallel_rank() + seq_length = tokens.size(1) + + assert seq_length % seq_parallel_world_size == 0 + sub_seq_length = seq_length // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_length + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length + + tokens = tokens[:, sub_seq_start:sub_seq_end] + position_ids = position_ids[:, sub_seq_start:sub_seq_end] + # For DS's sequence parallel + if mpu.get_sequence_parallel_world_size() > 1: + labels = labels[:, sub_seq_start:sub_seq_end] + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def data_post_process(data, data_sampler_state_dict): + args = get_args() + if args.data_efficiency_curriculum_learning: + if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if current_seqlen < args.seq_length: + data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() + elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + if current_seqlen < args.seq_length: + orig_num_token = torch.numel(data['text']) + reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) + data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), + data['text'][:, -(current_seqlen+1):]), 0).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen+1)) + num_row = min(num_row, data['text'].size()[0]) + if num_row > 1 and num_row % 2 != 0: + num_row -= 1 + data['text'] = data['text'][:num_row, :].contiguous() + else: + args.data_efficiency_curriculum_learning_seqlen_type = None + return data + + +def get_batch_pipe(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` + instead of `data_iterator` + """ + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
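# Worked example of the sub-sequence slicing in get_batch above, with illustrative numbers:
seq_length, sp_world_size, sp_rank = 4096, 4, 1
sub_seq_length = seq_length // sp_world_size   # 1024
sub_seq_start = sp_rank * sub_seq_length       # 1024
sub_seq_end = (sp_rank + 1) * sub_seq_length   # 2048
# rank 1 keeps tokens[:, 1024:2048]; labels are sliced the same way only on the
# DeepSpeed sequence-parallel path, exactly as in the code above.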
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < tokens.size()[1] + ): + # seqlen-based curriculum learning + # tokens, position_ids, labels, loss_mask + # have size [batch size, seqlen] + tokens = tokens[:, :args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + if labels is not None: + labels = labels[:, :args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + else: + if max(args.num_experts) <= 1: + return loss, {'lm loss': averaged_loss[0]} + loss = loss + moe_loss + return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + +def dpo_loss_func(loss_mask, dpo_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. 
+ averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + # else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + # loss = loss + moe_loss + # return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + loss = dpo_loss + return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss} + +def batch_seq_logprobs(logits, labels): + """ Function to compute a batch of sequence log probabilities """ + + logits = logits[:-1, :, :] # skip last logit + logits_logsoftmax = logits.log_softmax(-1) # compute log softmax of logits + + labels = labels[1:, :].clone() # clone labels + + # # Loss mask to avoid padded tokens while computing loss + # loss_mask = labels != tokenizer.pad_token_id + + # print(f'Labels shape: {labels.shape}') + # print(f'loss_mask shape: {loss_mask.shape}') + # print(f'loss_mask dtype: {loss_mask.dtype}') + + # Gather logps and squeeze last dimension + logprobs = torch.gather(logits_logsoftmax, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # print(f'seq_logprobs shape: {logprobs.shape}') + + # Weighted sum over logprobs using loss mask + # seq_logprobs = (logprobs * loss_mask).sum(-1) + seq_logprobs = logprobs.sum(-1) + + return seq_logprobs + + +def calculate_mos_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + tea_output, tea_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == tea_output.size(), ( + 'teacher and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Teacher: {tea_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. 
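+        # (Roughly, the term computed below is the usual temperature-scaled
+        #  distillation loss:
+        #    mos_loss = beta * T^2 * KL(softmax(teacher/T) || softmax(student/T)) / seq_length.)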
+ tea_logits = F.softmax(tea_output / kd_temp, dim=2) + + mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( + student_logits, + tea_logits + ) + + mos_loss = mos_loss.div(args.seq_length) * beta + return mos_loss + +def calculate_dpo_loss( + args, + stu_output, + teacher_model, + logprobs_p, + logprobs_u, + ref_logprobs_p, + ref_logprobs_u, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + kd_temp = 1.0 + beta = 0.1 # add to cmdline args + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + ref_output, ref_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == ref_output.size(), ( + 'ref and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Reference: {ref_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # Labels ? + logprobs = torch.gather(student_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. + + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p + + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u + + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) + + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # preferred dpo rewards + pref_dpo_rewards = (beta * logprob_ratio_p).detach() + + # unpreferred dpo rewards + unpref_dpo_rewards = (beta * logprob_ratio_u).detach() + + # Implicit DPO rewards + implicit_dpo_rewards = (pref_dpo_rewards > unpref_dpo_rewards).float() + rewards = implicit_dpo_rewards.cpu().mean() + + # Compute mean loss + dpo_loss = losses.mean() + # print(f'Loss dtype: {loss.dtype}') + + return dpo_loss, rewards + +def compute_dp_loss(logprobs_p, ref_logprobs_p, + logprobs_u, ref_logprobs_u, + beta=0.1): + + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p + + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u + + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) + + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # Compute mean loss + dp_loss = losses.mean() + + return dp_loss + + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
+ timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + if args.data_efficiency_curriculum_learning: + args.curriculum_seqlen = tokens.size()[1] + if ( + hasattr( + args, + 'data_efficiency_curriculum_learning_seqlen_type') + and ( + args.data_efficiency_curriculum_learning_seqlen_type + == 'seqlen_reshape' + ) + ): + args.data_efficiency_curriculum_learning_numel = ( + torch.numel(tokens) + ) + + if args.mos or args.kd: + # The forward func can return either the loss or the logits, + # depending on whether passing in the labels or not. + stu_output, other_losses = model(tokens, position_ids, attention_mask) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + labels = labels[:, :args.curriculum_seqlen].contiguous() + output_tensor = tensor_parallel.vocab_parallel_cross_entropy( + stu_output.contiguous().float(), + labels + ) + else: + output_tensor, other_losses = model( + tokens, + position_ids, + attention_mask, + labels=labels + ) + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + moe_losses = [] + for moe_loss in other_losses: + if moe_loss is not None: + moe_losses.append(moe_loss) + moe_loss = sum(moe_losses) * args.moe_loss_coeff + + mos_loss = 0 + if args.mos or args.kd: + assert model.training + if args.teacher_forward and args.teacher_model is not None: + mos_loss = calculate_mos_loss( + args, + stu_output, + args.teacher_model[0], + tokens, + position_ids, + attention_mask + ) + + # Output_tensor stores the standard loss, + # loss_func calculates the total loss. 
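+    # (loss_mask and the auxiliary losses are bound here via functools.partial;
+    #  the training loop is expected to call the returned closure with just the
+    #  model output, e.g. loss, loss_dict = loss_fn(output_tensor), which then
+    #  reduces the masked LM loss across the data-parallel group.)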
+ return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for GPT ...') + files = [] + if args.data_file_list is not None: + with open(args.data_file_list, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files.append(float(w)) + files.append(fname) + elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): + path = args.data_path[0] + "/" + for f in os.listdir(path): + if (os.path.isfile(path + f) and f.find(".bin") != -1): + files.append(1) + files.append(path + f.split(".bin")[0]) + else: + files = args.data_path + print_rank_0(f"file list {files}") + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=files, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen( + f'type {cmd}', + stdout=subprocess.PIPE, + shell=True + ) + return result.wait() == 0 + + +def git_ds_info(): + if RANK != 0: + return + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print( + f'**** Git info for Megatron: ' + f'git_hash={git_hash} git_branch={git_branch} ****' + ) + + +def main(): + # if RANK == 0: + # setup_wandb() + if os.getenv('TORCH_PROFILER_ENABLED') == '1': + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. + if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + # model = model_provider() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + + prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") + else: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. 
+ if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + if args.deepspeed: + args.deepspeed_config_dict = _create_ds_config_dict() + if "curriculum_learning" in args.deepspeed_config_dict and \ + "enabled" in args.deepspeed_config_dict["curriculum_learning"]: + args.curriculum_learning_legacy = args.deepspeed_config_dict[ \ + "curriculum_learning"]["enabled"] + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + from deepspeed.runtime.data_pipeline.curriculum_scheduler \ + import CurriculumScheduler + args.curriculum_scheduler = CurriculumScheduler( \ + args.deepspeed_config_dict["curriculum_learning"]) + if "compression_training" in args.deepspeed_config_dict: + args.compression_training = True + + # model = model_provider() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + + # ---------- Reference model ------------- + # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error + model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? + # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0) + opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2) + model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( + model=model_ref[0], + optimizer=optimizer_2, + args=args, + lr_scheduler=opt_param_scheduler_2, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + if isinstance(model_ref, deepspeed.PipelineEngine): + print(f'Doing assertion checks on model_ref..') + # hack to get batch_fn from pretrain_gpt.py + model_ref.set_batch_fn(model_ref.module._megatron_batch_fn) + + assert model_ref.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank() + assert model_ref.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank() + assert model_ref.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank() + + model_ref = [model_ref] + iteration2 = load_checkpoint(model_ref, optimizer_2, opt_param_scheduler_2) # THIS WORKED!! After commenting out assert args.consumed_train_samples == 0 in load_checkpoint() + + # THINGS THAT DID NOT WORK FOR LOADING FROM CHECKPOINT + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only(model_provider) # DID NOT WORK - train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 32 != 8 * 1 * 8 + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only_modified(model_provider) # DID NOT WORK - optimizer = FusedAdam(TypeError: FusedAdam.__init__() got an unexpected keyword argument 'beta1' + # ---------------------------------------- + + if args.data_file_list_u is not None: + print(f'data files list unpreferred: {args.data_file_list_u}') + + # Number of train/valid/test samples. 
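+    # (Worked example of the arithmetic below, assuming train_iters=100,
+    #  global_batch_size=32, eval_interval=50, eval_iters=10:
+    #    train = 100 * 32 = 3200 samples,
+    #    valid = (100 // 50 + 1) * 10 * 32 = 960 samples,
+    #    test  = 10 * 32 = 320 samples.)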
+ if args.train_samples: + print(f'args.train_samples: {args.train_samples}') + train_samples = args.train_samples + else: + print(f'args.train_iters: {args.train_iters}') + print(f'args.global_batch_size: {args.global_batch_size}') + train_samples = args.train_iters * args.global_batch_size + + print(f'args.eval_interval: {args.eval_interval}') + print(f'args.eval_iters: {args.eval_iters}') + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print(f'train_val_test_num_samples: {train_val_test_num_samples}') + # print(f'args.data_impl: {args.data_impl}') + # print(f'args.split: {args.split}') + # print(f'args.seq_length: {args.seq_length}') + # print(f'args.seed: {args.seed}') + # print(f'args.train_data_path: {args.train_data_path}') + # print(f'args.valid_data_path: {args.valid_data_path}') + # print(f'args.test_data_path: {args.test_data_path}') + # print(f'args.data_cache_path: {args.data_cache_path}') + + files_u = [] + with open(args.data_file_list_u, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_u.append(float(w)) + files_u.append(fname) + train_ds_u, valid_ds_u, test_ds_u = build_train_valid_test_datasets( + data_prefix=files_u, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating unpreferred GPT datasets ...") + + if args.data_file_list_p is not None: + print(f'data files list preferred: {args.data_file_list_p}') + + files_p = [] + with open(args.data_file_list_p, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_p.append(float(w)) + files_p.append(fname) + train_ds_p, valid_ds_p, test_ds_p = build_train_valid_test_datasets( + data_prefix=files_p, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating preferred GPT datasets ...") + + # Data loaders + print(f'args.consumed_train_samples: {args.consumed_train_samples}') + print(f'args.dataloader_type: {args.dataloader_type}') + train_dataloader_u = build_pretraining_data_loader( + train_ds_u, args.consumed_train_samples) + train_dataloader_p = build_pretraining_data_loader( + train_ds_p, args.consumed_train_samples) + + # Build train iterators + dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader_u is not None: + print(f'unpreferred train_dataloader is not None..') + train_data_iterator_u = iter(train_dataloader_u) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader_u)) + print_rank_0("> finished creating unpreferred train_data_iterator...") + if train_dataloader_p is not None: + print(f'preferred train_dataloader is not None..') + train_data_iterator_p = iter(train_dataloader_p) if dl_type == 'single' \ + else 
iter(cyclic_iter(train_dataloader_p)) + print_rank_0("> finished creating preferred train_data_iterator...") + + + iteration = 0 + print_rank_0(f'args.train_iters: {args.train_iters}') + + for i in range(args.train_iters): + # Get batch + timers = get_timers() + timers('batch-generator-unpreferred', log_level=2).start() + tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( + train_data_iterator_u) + timers('batch-generator-unpreferred').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for unpref train_data_iterator ...") + + timers('batch-generator-preferred', log_level=2).start() + tokens_p, labels_p, loss_mask_p, attention_mask_p, position_ids_p = get_batch( + train_data_iterator_p) + timers('batch-generator-preferred').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for pref train_data_iterator ...") + + # Model forward + # output_tensor, other_losses = model[0]( + # tokens_u, + # position_ids_u, + # attention_mask_u, + # labels=labels_u + # ) # OUT OF MEMORY ERROR even with 4 nodes + + # Computing logits and logps for preferred and unpreferred data batches + # output_u, other_losses_u = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model + # print_rank_0("> finished a forward pass to get unpref logits ...") + + + tokens_c = torch.cat((tokens_p,tokens_u), 0) + position_ids_c = torch.cat((position_ids_p,position_ids_u), 0) + labels_c = torch.cat((labels_p,labels_u), 0) + loss_mask_c = torch.cat((loss_mask_p,loss_mask_u), 0) + + output_c, other_losses_c = model[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u + ) + + loss_c = tensor_parallel.vocab_parallel_cross_entropy( + output_c.contiguous().float(), + labels_c + ) + + with torch.no_grad(): + routput_c, rother_losses_c = model_ref[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u + ) + rloss_c = tensor_parallel.vocab_parallel_cross_entropy( + routput_c.contiguous().float(), + labels_c + ) + + # print(f'tokens_p: {tokens_p}') + # print(f'tokens_u: {tokens_u}') + # # print(f'output_p[0]: {output_p[0]}') + # # print(f'output_u[0]: {output_u[0]}') + # print(f'output_c[0]: {output_c[0]}') + # print(f'tokens_p shape: {tokens_p.size()}, tokens_u shape: {tokens_u.size()}') + # print(f'tokens_c shape: {tokens_c.size()}') + # print(f'position_ids_p shape: {position_ids_p.size()}, position_ids_u shape: {position_ids_u.size()}') + # print(f'position_ids_c shape: {position_ids_c.size()}') + # print(f'output_c shape: {output_c.size()}') + # print(f'loss_c shape: {loss_c.size()}') + # print(f'routput_c shape: {routput_c.size()}') + # print(f'rloss_c shape: {rloss_c.size()}') + # print(f'loss_mask_p shape: {loss_mask_p.size()}') + # print(f'loss_mask_u shape: {loss_mask_u.size()}') + # print(f'loss_mask_c shape: {loss_mask_c.size()}') + # print(f'attention_mask_u: {attention_mask_u}') + + seq_logps_p = torch.sum(loss_c[:8,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + seq_logps_u = torch.sum(loss_c[8:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + rseq_logps_p = torch.sum(rloss_c[:8,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + rseq_logps_u = torch.sum(rloss_c[8:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + # print(f'seq_logps_p shape: {seq_logps_p.size()}') + # print(f'seq_logps_u shape: {seq_logps_u.size()}') + # print(f'rseq_logps_p shape: 
{rseq_logps_p.size()}') + # print(f'rseq_logps_u shape: {rseq_logps_u.size()}') + + pu_ratio = seq_logps_p - seq_logps_u + rpu_ratio = rseq_logps_p - rseq_logps_u + sdiff_ratio = 0.1*(pu_ratio - rpu_ratio) + # print(f'sdiff_ratio: {sdiff_ratio}') + final = -F.logsigmoid(sdiff_ratio) + # print(f'final: {final}') + + mos_loss = torch.sum(final) + print_rank_0(f'iteration: {iteration}, mos_loss: {mos_loss}') + # print(f'mos_loss shape: {mos_loss.size()}') + + # print(f'args.ds_pipeline_enabled: {args.ds_pipeline_enabled}') + # print(f'args.no_pipeline_parallel: {args.no_pipeline_parallel}') + # if args.deepspeed and args.ds_pipeline_enabled: + # print(f'In train step if args.deepspeed and args.ds_pipeline_enabled..') + + # print(f'loss_mask_p: {loss_mask_p}') + + # print(f'loss_mask_p sum: {torch.sum(loss_mask_p), 8*4096}')# print(f'loss_mask_p shape: {loss_mask_p.size()}') + + model[0].train() + model[0].backward(mos_loss) + + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + # print(f'increment: {increment}') + # model[0].step(lr_kwargs={'increment': increment}) + model[0].step() + update_successful = model[0].was_step_applied() + print_rank_0(f'update_successful: {update_successful}') + + iteration += 1 + args.iteration = iteration + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + args.consumed_train_samples += new_samples + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + + # logprobs_u = torch.exp(output_tensor_u) + # # print(f'Computed unpreferred output_tensor: {output_tensor_u}') + # print(f'Computed unpreferred logprobs: {logprobs_u}') + + # output_p, other_losses_p = model[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model + # print_rank_0("> finished a forward pass to get pref logits ...") + + # output_tensor_p = tensor_parallel.vocab_parallel_cross_entropy( + # output_p.contiguous().float(), + # labels_p + # ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) + + + + return model + +# def main(): +# # if RANK == 0: +# # setup_wandb() +# if os.getenv('TORCH_PROFILER_ENABLED') == '1': +# from torch.profiler import profile, record_function, ProfilerActivity +# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) + +# prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") +# else: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) +# return model + + +if __name__ == "__main__": + # git_ds_info() + # pretrain(train_valid_test_datasets_provider, + # model_provider, + # ModelType.encoder_or_decoder, + # forward_step, + # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # data_post_process=data_post_process) + import sys + import deepspeed.comm as dist + model = main() + dist.log_summary() + if wandb.run is not None: + print(f"wandb.run.name: {wandb.run.name}") + print(f"wandb.run.url: {wandb.run.url}") + wandb.finish() + sys.exit() 
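The preference objective used throughout these patches is the Direct Preference
Optimization (DPO) loss, as written out in compute_dp_loss above: the per-sequence
log-probability ratios of the policy and the frozen reference model on preferred (p)
and unpreferred (u) completions are differenced, scaled by beta, and passed through a
negative log-sigmoid. A minimal self-contained sketch, with beta = 0.1 as hard-coded
above and illustrative tensor names rather than the exact variables used in the patch:

    import torch
    import torch.nn.functional as F

    def dpo_loss(logps_p, logps_u, ref_logps_p, ref_logps_u, beta=0.1):
        # Log-ratio of policy vs. reference on preferred / unpreferred sequences.
        ratio_p = logps_p - ref_logps_p
        ratio_u = logps_u - ref_logps_u
        # DPO loss: -log sigmoid(beta * (ratio_p - ratio_u)), averaged over the batch.
        losses = -F.logsigmoid(beta * (ratio_p - ratio_u))
        # Implicit rewards and the fraction of pairs the policy ranks correctly.
        rewards_p = beta * ratio_p.detach()
        rewards_u = beta * ratio_u.detach()
        accuracy = (rewards_p > rewards_u).float().mean()
        return losses.mean(), accuracy

Here logps_p / logps_u stand for per-sequence log-likelihoods of the preferred and
unpreferred completions under the policy, and ref_logps_* for the same quantities under
the frozen reference model; in the training loop above they are obtained as masked,
length-normalized reductions of the vocab_parallel_cross_entropy output.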
From de885213f4a64fd4364ad6b07c5cb0ae482cf58f Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 12:05:24 +0000 Subject: [PATCH 26/50] training log func --- megatron/training.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/megatron/training.py b/megatron/training.py index 04adf5551c..eda8e380e3 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -919,6 +919,36 @@ def train_step(forward_step_func, data_iterator, return {}, skipped_iter, grad_norm, num_zeros_in_grad +def training_log_dpo(loss_dict, iteration, report_memory_flag): + """Log training information such as losses ....""" + + if wandb is not None and getattr(wandb, 'run', None) is not None: + assert wandb.run is not None + # wandb_metrics = { + # 'throughput/iteration-time': elapsed_time_per_iteration, # 1000 ms / s + # 'throughput/samples_per_sec': samples_per_sec, + # 'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica, + # 'throughput/tokens_per_sec': tokens_per_sec, + # 'throughput/tokens_per_sec_per_replica': tokens_per_sec_per_replica, + # 'throughput/tokens_per_gpu_per_sec': tokens_per_gpu_per_second, + # 'throughput/tokens_per_gpu_per_sec_per_replica': tokens_per_gpu_per_second_per_replica, + # 'throughput/tflops': tflops, + # 'throughput/approx_params_in_billions': approx_parameters_in_billions, + # 'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration, + # 'throughput/iteration': iteration, + # } + if loss_dict is not None: + wandb_metrics = { + 'loss/iteration': iteration, + 'loss': loss_dict['loss'] + } + + wandb.log(wandb_metrics) + + return report_memory_flag + + + def training_log(loss_dict, total_loss_dict, learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, From b50e140d289df6ff7e51c830f778848ba95694a7 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Wed, 27 Mar 2024 12:05:42 +0000 Subject: [PATCH 27/50] loggin training metrics --- dpo_training.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/dpo_training.py b/dpo_training.py index 025b30f542..9ba7412f4d 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -53,7 +53,7 @@ from megatron.training import load_model_weights_only_modified from megatron.training import get_optimizer_param_scheduler, cyclic_iter from megatron.training import train, train_step -from megatron.training import train_step_dpo +from megatron.training import train_step_dpo, training_log_dpo from megatron.optimizer import get_megatron_optimizer from megatron.checkpointing import load_checkpoint from megatron.data.data_samplers import build_pretraining_data_loader @@ -734,6 +734,7 @@ def git_ds_info(): def main(): # if RANK == 0: # setup_wandb() + if os.getenv('TORCH_PROFILER_ENABLED') == '1': from torch.profiler import profile, record_function, ProfilerActivity with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: @@ -917,7 +918,10 @@ def main(): iteration = 0 print_rank_0(f'args.train_iters: {args.train_iters}') - + report_memory_flag = True + if torch.distributed.get_rank() == 0: + averaged_loss_iter = [] + averaged_rewards_iter = [] for i in range(args.train_iters): # Get batch timers = get_timers() @@ -1011,8 +1015,8 @@ def main(): final = -F.logsigmoid(sdiff_ratio) # print(f'final: {final}') - mos_loss = torch.sum(final) - print_rank_0(f'iteration: {iteration}, mos_loss: {mos_loss}') + dloss = torch.sum(final) + # print_rank_0(f'iteration: {iteration}, 
mos_loss: {dloss}') # print(f'mos_loss shape: {mos_loss.size()}') # print(f'args.ds_pipeline_enabled: {args.ds_pipeline_enabled}') @@ -1025,7 +1029,7 @@ def main(): # print(f'loss_mask_p sum: {torch.sum(loss_mask_p), 8*4096}')# print(f'loss_mask_p shape: {loss_mask_p.size()}') model[0].train() - model[0].backward(mos_loss) + model[0].backward(dloss) increment = get_num_microbatches() * \ args.micro_batch_size * \ @@ -1046,6 +1050,21 @@ def main(): args.consumed_train_samples += new_samples # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([dloss]) + loss_dict = {'loss': averaged_loss} + print_rank_0(f'iteration: {iteration}, dloss: {averaged_loss.detach().cpu().tolist()}') + psrewards_p = (0.1 * (seq_logps_p - rseq_logps_p)).detach() + psrewards_u = (0.1 * (seq_logps_u - rseq_logps_u)).detach() + psrewards = (psrewards_p > psrewards_u).float() + rewards = psrewards.cpu().mean() + print_rank_0(f'iteration: {iteration}, rewards: {rewards}') + if torch.distributed.get_rank() == 0: + averaged_loss_iter.append(averaged_loss.detach().cpu().tolist()[0]) + averaged_rewards_iter.append(rewards.tolist()) + + # report_memory_flag = training_log_dpo(loss_dict, iteration, report_memory_flag) + # logprobs_u = torch.exp(output_tensor_u) # # print(f'Computed unpreferred output_tensor: {output_tensor_u}') # print(f'Computed unpreferred logprobs: {logprobs_u}') @@ -1058,7 +1077,10 @@ def main(): # labels_p # ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(averaged_loss_iter) + print(averaged_rewards_iter) return model From 24737c04695155791fa7f727b23a3bad1694cbe3 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Tue, 2 Apr 2024 17:52:52 +0000 Subject: [PATCH 28/50] text-sequence data files --- ALCF/helpers.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index b2b6aeb33e..ce84346881 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -260,8 +260,14 @@ setData() { # ---- [dfl: abbrv. 
for DATA_FILE_LIST] ------------------------- elif [[ $(hostname) == x3* ]]; then # dfl_fallback="/eagle/datasets/dolma/data_file_list_reweighted.txt" dfl_fallback="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list.txt" - dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_u.txt" - dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_p.txt" + # dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_u.txt" + # dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_p.txt" + # dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_yeast_file_list_u.txt" + # dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_yeast_file_list_p.txt" + # dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_textseq_yeast_file_list_u.txt" + # dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_textseq_yeast_file_list_p.txt" + dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_textseq_proteingym_indels_file_list_u.txt" + dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_textseq_proteingym_indels_file_list_p.txt" else echo "Unknown hostname. Must manually specify DATA_FILE_LIST." fi From cbbd2a6279ab726aa8253ef0deeae03f90a017f4 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:31:13 +0000 Subject: [PATCH 29/50] generate func post model training --- generate_utils.py | 354 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 generate_utils.py diff --git a/generate_utils.py b/generate_utils.py new file mode 100644 index 0000000000..efb87d576e --- /dev/null +++ b/generate_utils.py @@ -0,0 +1,354 @@ + + +"""Generate function post training""" + +import os +from rich import print +import torch +import math +import numpy as np +import time +from datetime import datetime +import threading + +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import ( + report_memory, + throughput_calculator, + checkpoint_throughput_calculator +) +from pathlib import Path + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator +import subprocess +import wandb + +from torch import nn +import torch.nn.functional as F + +# from ezpz import get_logger +from ezpz.dist import get_world_size, setup_wandb, get_rank + +# More imports +from megatron.initialize import initialize_megatron +from megatron.initialize import set_jit_fusion_options +from megatron.training import print_datetime, _create_ds_config_dict +from megatron.training import 
setup_model_and_optimizer +from megatron.training import load_model_weights_only, get_model +from megatron.training import get_optimizer_param_scheduler, cyclic_iter +from megatron.optimizer import get_megatron_optimizer +from megatron.checkpointing import load_checkpoint +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.arguments import core_transformer_config_from_args +from megatron import update_num_microbatches +from megatron import get_num_microbatches +from megatron.utils import throughput_calculator, get_parameters_in_billions +from megatron.text_generation import generate_and_post_process, beam_search_and_post_process +from megatron.text_generation.forward_step import ForwardStep, InferenceParams +from megatron.text_generation.sampling import sample +from megatron.text_generation.tokenization import detokenize_generations +from megatron.text_generation.communication import ( + copy_from_last_to_first_pipeline_stage, + broadcast_from_last_pipeline_stage, + broadcast_from_last_to_first_pipeline_stage) +from megatron.checkpointing import save_checkpoint +from megatron.utils import get_ltor_masks_and_position_ids + + +def generate_post_training( + model, prompts, tokens_to_generate, + top_k = 0, + top_p = 1.0, + temperature = 1.0, + top_p_decay=0.0, + top_p_bound=0.0, + add_BOS=False, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False, + prevent_newline_after_colon=False, + random_seed=42, + return_output_log_probs = False, + fprint=True + ): + + print_rank_0(f'Generation mode..') + model[0].eval() + + args = get_args() + print_rank_0(f'Seq length in args: {args.seq_length}') + + tokenizer = get_tokenizer() + print_rank_0(f'Number of elements in tokenizer vocab: {len(tokenizer.vocab)}') + # prompts=["A sequence", "A sequence","A sequence", "A sequence", "A sequence"] + # tokens_to_generate = 64 + + # add_BOS = False + if add_BOS: + prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt) + for prompt in prompts] + else: + prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts] + + if fprint: print_rank_0(f'prompts_tokens: {prompts_tokens}') + + # Make all tokenized prompts to be of same length as max length of the prompts + prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens] + max_prompt_len = max(prompts_length) + samples_length = max_prompt_len + tokens_to_generate + for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length): + padding_size = samples_length - prompt_length + prompt_tokens.extend([tokenizer.eod] * padding_size) + + # Now we are in a structured format, we can convert to tensors + prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) + prompts_length_tensor = torch.cuda.LongTensor(prompts_length) + if fprint: + print_rank_0(f'prompts_tokens_tensor: {prompts_tokens_tensor}') + print_rank_0(f'prompts_length_tensor: {prompts_length_tensor}') + + # Getting attributes to set inference_params + batch_size = prompts_tokens_tensor.size(0) + min_prompt_length = prompts_length_tensor.min().item() + max_sequence_length = prompts_tokens_tensor.size(1) + + if fprint: + print_rank_0(f'batch_size: {batch_size}') + print_rank_0(f'min_prompt_length: {min_prompt_length}') + print_rank_0(f'max_sequence_length: {max_sequence_length}') + print_rank_0(f'max_position_embeddings: {args.max_position_embeddings}') + print_rank_0(f'args.max_tokens_to_oom: {args.max_tokens_to_oom}') + + if max_sequence_length > args.max_position_embeddings: + raise 
ValueError("Length of prompt + tokens_to_generate longer than allowed") + + if max_sequence_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) + + # INSTANTIATING FORWARD_STEP ? + # model_fwd = ForwardStep(model[0], batch_size, max_sequence_length) + inference_params = InferenceParams(batch_size, + max_sequence_length) + + if hasattr(args, 'eos_id'): + termination_id = args.eos_id + print_rank_0(f'args.eos_id: {args.eos_id}') + else: + termination_id = tokenizer.eod + print_rank_0(f'tokenizer.eod: {tokenizer.eod}') + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + # Lengths of generated seuquence including including prompts. + generated_sequence_lengths = None + + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + if fprint: print_rank_0(f'On mpu.is_pipeline_last_stage branch and output_log_probs is set: {output_log_probs}') + generated_sequence_lengths = torch.ones( + batch_size, dtype=torch.int64, + device=torch.cuda.current_device()) * max_sequence_length + if fprint: print_rank_0(f'On mpu.is_pipeline_last_stage branch and generated_sequence_lengths: {generated_sequence_lengths}') + + # Whether we have reached a termination id. + is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, + device=torch.cuda.current_device()) + + + with torch.no_grad(): + prompts_attention_mask, _, prompts_position_ids = get_ltor_masks_and_position_ids( + data=prompts_tokens_tensor, + eod_token=None, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False + ) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): + # Pick the slice that we need to pass through the network. + tokens2use = prompts_tokens_tensor[:, prev_context_length:context_length] + positions2use = prompts_position_ids[:, prev_context_length:context_length] + attention_mask2use = prompts_attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # #logits will be meanigful only in the last pipeline stage. 
+ if fprint: + print_rank_0(f'tokens2use shape: {tokens2use.size()}') + print_rank_0(f'positions2use shape: {positions2use.size()}') + print_rank_0(f'attention_mask2use shape: {attention_mask2use.size()}') + print_rank_0(f'prompts_tokens_tensor shape: {prompts_tokens_tensor.size()}') + print_rank_0(f'prompts_position_ids shape: {prompts_position_ids.size()}') + print_rank_0(f'prompts_attention_mask shape: {prompts_attention_mask.size()}') + + # ------ + # plogits = forward_step(tokens2use, positions2use, attention_mask2use) + # plogits = plogits[0] + # print_rank_0(f'context_length: {context_length}, plogits: {plogits}') + + # plogits = model[0](prompts_tokens_tensor, + # prompts_position_ids, + # prompts_attention_mask, + # inference_params=inference_params + # ) + # print_rank_0(f'logits: {plogits}') + #------- + + # Changing seq length in inference params dynamically + inference_params = InferenceParams(batch_size, + tokens2use.size(1)) + plogits = model[0](tokens2use, + positions2use, + attention_mask2use, + inference_params=inference_params + ) + plogits = plogits[0] + # plogits = torch.cuda.FloatTensor(plogits) + if fprint: + print_rank_0(f'plogits: {plogits.size()}') + print_rank_0(f'plogits type: {plogits.dtype}') + + if mpu.is_pipeline_last_stage(): + if prevent_newline_after_colon: + plogits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" + # Always the last stage should have an output. + assert plogits is not None + + # Sample. + last_token_logits = plogits[:, -1, :] + new_sample = sample(last_token_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + vocab_size=tokenizer.vocab_size) + if top_p > 0.0 and top_p_decay > 0.0: + top_p = top_p * top_p_decay + if top_p_bound > 0.0: + top_p = max(top_p, top_p_bound) + + if fprint: + print_rank_0(f'new_sample: {new_sample}') + for nidx, ns in enumerate(new_sample.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, new_sample[{nidx}]: {tokenizer.detokenize(ns)}') + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = prompts_length_tensor <= context_length + # Update the tokens. + if fprint: + print_rank_0(f'started: {started}') + # print_rank_0(f'prompts_tokens_tensor before copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor before[{nidx}]: {tokenizer.detokenize(ns)}') + + prompts_tokens_tensor[started, context_length] = new_sample[started] + if fprint: + # print_rank_0(f'prompts_tokens_tensor after copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after[{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the tokens on the first stage so the next input to + # the network is correct. + copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, + prompts_tokens_tensor[:, context_length]) + # for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + # print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after copy_from_last_to_first_pipeline_stage [{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the context length for the next token generation. + prev_context_length = context_length + if fprint: print_rank_0(f'prev_context_length: {prev_context_length}') + + # Check if all the sequences have hit the termination_id. 
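+            # (Early stopping works per batch: done_token marks sequences that
+            #  have emitted the termination id after their prompt, done becomes
+            #  true only when every sequence has finished, and it is broadcast
+            #  from the last pipeline stage so all ranks leave the loop together.)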
+ done = None + if mpu.is_pipeline_last_stage(): + # These stopping methods are tokenizer dependent + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + just_finished = (done_token & ~is_generation_done).bool() + generated_sequence_lengths[just_finished.view(-1)] = \ + context_length + 1 + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + done = broadcast_from_last_pipeline_stage(1, torch.uint8, + tensor=done) + if use_eod_token_for_early_termination and done: + print_rank_0(f'done: {done}') + break + + # =================================================== + # Update the length of based on max generated length. + # =================================================== + # for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + # print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop [{nidx}]: {tokenizer.detokenize(ns)}') + prompts_tokens_tensor = prompts_tokens_tensor[:, :(context_length + 1)] + # for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + # print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and slicing with ctx length[{nidx}]: {tokenizer.detokenize(ns)}') + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = output_log_probs[:, :context_length] + + # ====================================== + # Broadcast to the first pipeline stage. + # ====================================== + + generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage( + batch_size, torch.int64, generated_sequence_lengths) + if return_output_log_probs: + output_log_probs_size = (batch_size, context_length) + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) + + # if fprint: + # for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + # print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and befoer final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + # Only post-process on first stage. 
+ if mpu.is_pipeline_first_stage(): + prompts_plus_generations = [] + + if fprint: + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and after final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + + rtokens = prompts_tokens_tensor.cpu().numpy().tolist() + rlengths = prompts_length_tensor.cpu().numpy().tolist() + if fprint: print_rank_0(f'rlengths: {rlengths}') + # for sequence_tokens, slength in zip(rtokens, rlengths): + for sequence_tokens in rtokens: + # sequence_tokens = sequence_tokens[:slength] + prompts_plus_generations.append( + tokenizer.detokenize(sequence_tokens)) + # _, prompts_plus_generations, prompts_plus_generations_segments = \ + # detokenize_generations(prompts_tokens_tensor, prompts_length_tensor, True) + + for prompt, prompt_response in zip(prompts, prompts_plus_generations): + print_rank_0(f'------------------') + print_rank_0(f'prompt: {prompt}') + print_rank_0(f'prompt and response: {prompt_response}') + + return prompts_plus_generations From e1d3e6ff79eafcf11503020d1b8debf99423a48f Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:31:49 +0000 Subject: [PATCH 30/50] flops profiler enabled --- ALCF/helpers.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index ce84346881..e42fe27c09 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -259,7 +259,8 @@ setData() { # ---- [dfl: abbrv. for DATA_FILE_LIST] ------------------------- dfl_fallback="/gila/Aurora_deployment/AuroraGPT/datasets/dolma/data_file_list_reweighted.txt" elif [[ $(hostname) == x3* ]]; then # dfl_fallback="/eagle/datasets/dolma/data_file_list_reweighted.txt" - dfl_fallback="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list.txt" + dfl_fallback="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/convergence_debug_small.txt" + # dfl_fallback="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list.txt" # dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_u.txt" # dfl_fallback_p="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_file_list_p.txt" # dfl_fallback_u="/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/data_yeast_file_list_u.txt" From a976fd74b8524043ee1349d3a620d48b135ea75d Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:33:07 +0000 Subject: [PATCH 31/50] flops profiler enabled --- generate_config.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate_config.sh b/generate_config.sh index 6bea420a2a..a300db85d7 100644 --- a/generate_config.sh +++ b/generate_config.sh @@ -48,8 +48,8 @@ common="\ flops_profiler="\ \"flops_profiler\": { - \"enabled\": false, - \"profile_step\": 45, + \"enabled\": true, + \"profile_step\": 5, \"module_depth\": -1, \"top_modules\": 1, \"detailed\": true, From b964d58d0d6fce81c9df570b91f28c68f04094ff Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:34:35 +0000 Subject: [PATCH 32/50] removed dtype assertion that throws error --- megatron/text_generation/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/text_generation/sampling.py b/megatron/text_generation/sampling.py index 370773a36c..06ef504778 100644 --- a/megatron/text_generation/sampling.py +++ 
b/megatron/text_generation/sampling.py @@ -53,8 +53,8 @@ def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): # Check logits for consistency. assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' - assert logits.type() == 'torch.cuda.FloatTensor', \ - 'input logits should be floats.' + # assert logits.type() == 'torch.cuda.FloatTensor', \ + # 'input logits should be floats.' # Greedy is just simple argmax. From ad780cb1a31e1b861732501c5856877f2674f5cd Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:35:15 +0000 Subject: [PATCH 33/50] changed to generation after dpo --- dpo_training.py | 708 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 545 insertions(+), 163 deletions(-) diff --git a/dpo_training.py b/dpo_training.py index 9ba7412f4d..dd02b537f3 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -6,10 +6,12 @@ from rich import print import torch import math +import numpy as np # The earliest we can measure the start time. import time from datetime import datetime +import threading from functools import partial from megatron import get_args @@ -61,6 +63,18 @@ from megatron.arguments import core_transformer_config_from_args from megatron import update_num_microbatches from megatron import get_num_microbatches +from megatron.utils import throughput_calculator, get_parameters_in_billions +from megatron.text_generation import generate_and_post_process, beam_search_and_post_process +from megatron.text_generation.forward_step import ForwardStep, InferenceParams +from megatron.text_generation.sampling import sample +from megatron.text_generation.tokenization import detokenize_generations +from megatron.text_generation.communication import ( + copy_from_last_to_first_pipeline_stage, + broadcast_from_last_pipeline_stage, + broadcast_from_last_to_first_pipeline_stage) +from megatron.checkpointing import save_checkpoint +from megatron.utils import get_ltor_masks_and_position_ids +from generate_utils import generate_post_training # RANK = setup_torch( # backend='deepspeed', @@ -90,7 +104,6 @@ print('--------------------------------------------------') setup_wandb(project_name=project_name) - def model_provider(pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') @@ -185,6 +198,33 @@ def model_provider(pre_process=True, post_process=True): # wandb.run.config.update({'num_params': num_params}) return model +def throughput_flops(model, args, iteration_time, total_iterations): + batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size + approx_parameters_in_billions = None if (model is None) else get_parameters_in_billions(model) + elapsed_time_per_iter = iteration_time/total_iterations + samples_per_second = batch_size / elapsed_time_per_iter + + #flops calculator + hidden_size = args.hidden_size + num_layers = args.num_layers + vocab_size = args.padded_vocab_size + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3. 
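+    # (Written out, the estimate computed below is
+    #    flops_per_iteration = 24 * c * B * s * L * h^2 * (1 + s/(6h) + V/(16*L*h))
+    #  with c the activation-recompute factor (3 or 4), B the global batch size,
+    #  s the sequence length, L = num_layers, h = hidden_size and V the padded
+    #  vocab size; c defaults to 3 and is bumped to 4 when checkpointing is on.)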
+ checkpoint_activations_factor = 3 + if hasattr(args, 'checkpoint_activations') and args.checkpoint_activations: + checkpoint_activations_factor = 4 + if hasattr(args, 'recompute_granularity') and (args.recompute_granularity == 'selective' or args.recompute_granularity == 'full'): + checkpoint_activations_factor = 4 + seq_len = args.seq_length + if hasattr(args, 'actual_seq_length'): + seq_len = args.actual_seq_length + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12)) + + return tflops def get_batch(data_iterator): """Generate a batch""" @@ -784,7 +824,20 @@ def main(): args.compression_training = True # model = model_provider() - model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + # model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + model = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? + # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + optimizer = get_megatron_optimizer(model, None, None, 1.0) + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + model, optimizer, _, opt_param_scheduler = deepspeed.initialize( + model=model[0], + optimizer=optimizer, + args=args, + lr_scheduler=opt_param_scheduler, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + model = [model] # ---------- Reference model ------------- # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error @@ -800,6 +853,25 @@ def main(): mpu=mpu if args.no_pipeline_parallel else None, config=args.deepspeed_config_dict, ) + # model_ref, _, _, _ = deepspeed.initialize( + # model=model_ref[0], + # optimizer=None, + # args=args, + # lr_scheduler=None, + # mpu=mpu if args.no_pipeline_parallel else None, + # config=args.deepspeed_config_dict, + # ) + # engine = deepspeed.init_inference(model=model_ref[0], + # mp_size=args.tensor_model_parallel_size, + # tensor_parallel={"mpu": mpu}, + # dtype=torch.half, + # replace_with_kernel_inject=True, + # # moe_experts=args.num_experts, + # # moe_type=args.mlp_type + # ) + # model_ref = engine.module + + if isinstance(model_ref, deepspeed.PipelineEngine): print(f'Doing assertion checks on model_ref..') # hack to get batch_fn from pretrain_gpt.py @@ -837,7 +909,7 @@ def main(): train_val_test_num_samples = [train_samples, eval_iters * args.global_batch_size, test_iters * args.global_batch_size] - print(f'train_val_test_num_samples: {train_val_test_num_samples}') + print_rank_0(f'train_val_test_num_samples: {train_val_test_num_samples}') # print(f'args.data_impl: {args.data_impl}') # print(f'args.split: {args.split}') # print(f'args.seq_length: {args.seq_length}') @@ -869,7 +941,7 @@ def main(): print_rank_0("> finished creating unpreferred GPT datasets ...") if args.data_file_list_p is not None: - print(f'data files list preferred: {args.data_file_list_p}') + print_rank_0(f'data files list preferred: {args.data_file_list_p}') files_p = [] with open(args.data_file_list_p, 'r') as flist: @@ -893,8 +965,8 @@ def main(): print_rank_0("> finished creating preferred GPT datasets ...") # Data loaders - 
print(f'args.consumed_train_samples: {args.consumed_train_samples}') - print(f'args.dataloader_type: {args.dataloader_type}') + print_rank_0(f'args.consumed_train_samples: {args.consumed_train_samples}') + print_rank_0(f'args.dataloader_type: {args.dataloader_type}') train_dataloader_u = build_pretraining_data_loader( train_ds_u, args.consumed_train_samples) train_dataloader_p = build_pretraining_data_loader( @@ -905,182 +977,489 @@ def main(): assert dl_type in ['single', 'cyclic'] if train_dataloader_u is not None: - print(f'unpreferred train_dataloader is not None..') + print_rank_0(f'unpreferred train_dataloader is not None..') train_data_iterator_u = iter(train_dataloader_u) if dl_type == 'single' \ else iter(cyclic_iter(train_dataloader_u)) print_rank_0("> finished creating unpreferred train_data_iterator...") if train_dataloader_p is not None: - print(f'preferred train_dataloader is not None..') + print_rank_0(f'preferred train_dataloader is not None..') train_data_iterator_p = iter(train_dataloader_p) if dl_type == 'single' \ else iter(cyclic_iter(train_dataloader_p)) print_rank_0("> finished creating preferred train_data_iterator...") - iteration = 0 print_rank_0(f'args.train_iters: {args.train_iters}') + print_rank_0(f'args.save_interval: {args.save_interval}') report_memory_flag = True + + # Train model + model[0].train() + if torch.distributed.get_rank() == 0: averaged_loss_iter = [] averaged_rewards_iter = [] - for i in range(args.train_iters): - # Get batch - timers = get_timers() - timers('batch-generator-unpreferred', log_level=2).start() - tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( - train_data_iterator_u) - timers('batch-generator-unpreferred').stop() - # print(f'tokens shape: {tokens_u.shape}') - print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for unpref train_data_iterator ...") - - timers('batch-generator-preferred', log_level=2).start() - tokens_p, labels_p, loss_mask_p, attention_mask_p, position_ids_p = get_batch( - train_data_iterator_p) - timers('batch-generator-preferred').stop() - # print(f'tokens shape: {tokens_u.shape}') - print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. 
for pref train_data_iterator ...") - - # Model forward - # output_tensor, other_losses = model[0]( - # tokens_u, - # position_ids_u, - # attention_mask_u, - # labels=labels_u - # ) # OUT OF MEMORY ERROR even with 4 nodes - - # Computing logits and logps for preferred and unpreferred data batches - # output_u, other_losses_u = model[0](tokens_u, position_ids_u, attention_mask_u) # THIS WORKED with 4 nodes for 7B model - # print_rank_0("> finished a forward pass to get unpref logits ...") - - - tokens_c = torch.cat((tokens_p,tokens_u), 0) - position_ids_c = torch.cat((position_ids_p,position_ids_u), 0) - labels_c = torch.cat((labels_p,labels_u), 0) - loss_mask_c = torch.cat((loss_mask_p,loss_mask_u), 0) - - output_c, other_losses_c = model[0]( - tokens_c, - position_ids_c, - None, - # labels=labels_u + avg_loss_epoch = [] + avg_rewards_epoch = [] + + for epoch in range(1): + iteration = 0 + for i in range(args.train_iters): + # Get batch + timers = get_timers() + timers('batch-generator-unpreferred', log_level=2).start() + tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( + train_data_iterator_u) + timers('batch-generator-unpreferred').stop() + # print_rank_0(f'tokens_u[0].size(): {tokens_u[0].size()}') + # print_rank_0(f'tokens_u[0,400:1024]: {tokens_u[0,400:1024]}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for unpref train_data_iterator ...") + + timers('batch-generator-preferred', log_level=2).start() + tokens_p, labels_p, loss_mask_p, attention_mask_p, position_ids_p = get_batch( + train_data_iterator_p) + timers('batch-generator-preferred').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for pref train_data_iterator ...") + + # Model forward + # output_tensor, other_losses = model[0]( + # tokens_u, + # position_ids_u, + # attention_mask_u, + # labels=labels_u + # ) # OUT OF MEMORY ERROR even with 4 nodes + + # Model forward with concatenated inputs + tokens_c = torch.cat((tokens_p,tokens_u), 0) + position_ids_c = torch.cat((position_ids_p,position_ids_u), 0) + labels_c = torch.cat((labels_p,labels_u), 0) + loss_mask_c = torch.cat((loss_mask_p,loss_mask_u), 0) + + # Logits and loss + output_c, other_losses_c = model[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u + ) + + loss_c = tensor_parallel.vocab_parallel_cross_entropy( + output_c.contiguous().float(), + labels_c ) - loss_c = tensor_parallel.vocab_parallel_cross_entropy( - output_c.contiguous().float(), - labels_c + # Reference model forward with concatenated inputs + with torch.no_grad(): + # Logits and loss + routput_c, rother_losses_c = model_ref[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u ) - - with torch.no_grad(): - routput_c, rother_losses_c = model_ref[0]( - tokens_c, - position_ids_c, - None, - # labels=labels_u + rloss_c = tensor_parallel.vocab_parallel_cross_entropy( + routput_c.contiguous().float(), + labels_c ) - rloss_c = tensor_parallel.vocab_parallel_cross_entropy( - routput_c.contiguous().float(), - labels_c - ) - # print(f'tokens_p: {tokens_p}') - # print(f'tokens_u: {tokens_u}') - # # print(f'output_p[0]: {output_p[0]}') - # # print(f'output_u[0]: {output_u[0]}') - # print(f'output_c[0]: {output_c[0]}') - # print(f'tokens_p shape: {tokens_p.size()}, tokens_u shape: {tokens_u.size()}') - # print(f'tokens_c shape: {tokens_c.size()}') - # print(f'position_ids_p shape: {position_ids_p.size()}, position_ids_u shape: {position_ids_u.size()}') 
- # print(f'position_ids_c shape: {position_ids_c.size()}') - # print(f'output_c shape: {output_c.size()}') - # print(f'loss_c shape: {loss_c.size()}') - # print(f'routput_c shape: {routput_c.size()}') - # print(f'rloss_c shape: {rloss_c.size()}') - # print(f'loss_mask_p shape: {loss_mask_p.size()}') - # print(f'loss_mask_u shape: {loss_mask_u.size()}') - # print(f'loss_mask_c shape: {loss_mask_c.size()}') - # print(f'attention_mask_u: {attention_mask_u}') - - seq_logps_p = torch.sum(loss_c[:8,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) - seq_logps_u = torch.sum(loss_c[8:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) - rseq_logps_p = torch.sum(rloss_c[:8,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) - rseq_logps_u = torch.sum(rloss_c[8:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) - # print(f'seq_logps_p shape: {seq_logps_p.size()}') - # print(f'seq_logps_u shape: {seq_logps_u.size()}') - # print(f'rseq_logps_p shape: {rseq_logps_p.size()}') - # print(f'rseq_logps_u shape: {rseq_logps_u.size()}') - - pu_ratio = seq_logps_p - seq_logps_u - rpu_ratio = rseq_logps_p - rseq_logps_u - sdiff_ratio = 0.1*(pu_ratio - rpu_ratio) - # print(f'sdiff_ratio: {sdiff_ratio}') - final = -F.logsigmoid(sdiff_ratio) - # print(f'final: {final}') - - dloss = torch.sum(final) - # print_rank_0(f'iteration: {iteration}, mos_loss: {dloss}') - # print(f'mos_loss shape: {mos_loss.size()}') - - # print(f'args.ds_pipeline_enabled: {args.ds_pipeline_enabled}') - # print(f'args.no_pipeline_parallel: {args.no_pipeline_parallel}') - # if args.deepspeed and args.ds_pipeline_enabled: - # print(f'In train step if args.deepspeed and args.ds_pipeline_enabled..') - - # print(f'loss_mask_p: {loss_mask_p}') - - # print(f'loss_mask_p sum: {torch.sum(loss_mask_p), 8*4096}')# print(f'loss_mask_p shape: {loss_mask_p.size()}') - - model[0].train() - model[0].backward(dloss) - - increment = get_num_microbatches() * \ - args.micro_batch_size * \ - args.data_parallel_size - # print(f'increment: {increment}') - # model[0].step(lr_kwargs={'increment': increment}) - model[0].step() - update_successful = model[0].was_step_applied() - print_rank_0(f'update_successful: {update_successful}') - - iteration += 1 - args.iteration = iteration - new_samples = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - - # print(f'args.consumed_train_samples: {args.consumed_train_samples}') - args.consumed_train_samples += new_samples - # print(f'args.consumed_train_samples: {args.consumed_train_samples}') - - # Reduce loss for logging. 
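The removed hunk above hard-codes the preferred/unpreferred split of the concatenated batch at 8 rows; the replacement below slices at args.micro_batch_size, which is what actually determines how many preferred sequences sit at the front of the batch. A standalone sketch of the masked per-sequence reduction and the batch-size-aware split (shapes and values are illustrative only):

import torch

def sequence_scores(token_nll, loss_mask):
    # token_nll: [batch, seq] per-token negative log-likelihoods from the
    # parallel cross-entropy; loss_mask: 1.0 on tokens that should count.
    return torch.sum(token_nll * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)

micro_batch_size = 4                                 # illustrative value
token_nll = torch.rand(2 * micro_batch_size, 16)     # concatenated (pref + unpref)
mask_p = torch.ones(micro_batch_size, 16)
mask_u = torch.ones(micro_batch_size, 16)
scores_p = sequence_scores(token_nll[:micro_batch_size], mask_p)
scores_u = sequence_scores(token_nll[micro_batch_size:], mask_u)
print(scores_p.shape, scores_u.shape)

One caveat worth double-checking: vocab_parallel_cross_entropy returns per-token negative log-likelihoods, so if true sequence log-probabilities are wanted for the DPO ratio, the masked mean should be negated (or, equivalently, the sign of the ratio flipped).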
- averaged_loss = average_losses_across_data_parallel_group([dloss]) - loss_dict = {'loss': averaged_loss} - print_rank_0(f'iteration: {iteration}, dloss: {averaged_loss.detach().cpu().tolist()}') - psrewards_p = (0.1 * (seq_logps_p - rseq_logps_p)).detach() - psrewards_u = (0.1 * (seq_logps_u - rseq_logps_u)).detach() - psrewards = (psrewards_p > psrewards_u).float() - rewards = psrewards.cpu().mean() - print_rank_0(f'iteration: {iteration}, rewards: {rewards}') - if torch.distributed.get_rank() == 0: - averaged_loss_iter.append(averaged_loss.detach().cpu().tolist()[0]) - averaged_rewards_iter.append(rewards.tolist()) - - # report_memory_flag = training_log_dpo(loss_dict, iteration, report_memory_flag) - - # logprobs_u = torch.exp(output_tensor_u) - # # print(f'Computed unpreferred output_tensor: {output_tensor_u}') - # print(f'Computed unpreferred logprobs: {logprobs_u}') - - # output_p, other_losses_p = model[0](tokens_p, position_ids_p, attention_mask_p) # THIS WORKED with 4 nodes for 7B model - # print_rank_0("> finished a forward pass to get pref logits ...") - - # output_tensor_p = tensor_parallel.vocab_parallel_cross_entropy( - # output_p.contiguous().float(), - # labels_p - # ) # BUT THIS DID NOT WORK WITH 4 NODES - OOM ERROR for 7B model (but worked for 1B model on 2 nodes) - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(averaged_loss_iter) - print(averaged_rewards_iter) + # # Print statements for debugging + # print(f'tokens_p: {tokens_p}') + # print(f'tokens_u: {tokens_u}') + # # print(f'output_p[0]: {output_p[0]}') + # # print(f'output_u[0]: {output_u[0]}') + # print(f'output_c[0]: {output_c[0]}') + # print(f'tokens_p shape: {tokens_p.size()}, tokens_u shape: {tokens_u.size()}') + # print(f'tokens_c shape: {tokens_c.size()}') + # print(f'position_ids_p shape: {position_ids_p.size()}, position_ids_u shape: {position_ids_u.size()}') + # print(f'position_ids_c shape: {position_ids_c.size()}') + # print(f'output_c shape: {output_c.size()}') + # print(f'loss_c shape: {loss_c.size()}') + # print(f'routput_c shape: {routput_c.size()}') + # print(f'rloss_c shape: {rloss_c.size()}') + # print(f'loss_mask_p shape: {loss_mask_p.size()}') + # print(f'loss_mask_u shape: {loss_mask_u.size()}') + # print(f'loss_mask_c shape: {loss_mask_c.size()}') + # print(f'attention_mask_u: {attention_mask_u}') + # print(f'loss_mask_p sum: {torch.sum(loss_mask_p), 8*4096}')# print(f'loss_mask_p shape: {loss_mask_p.size()}') + + # Seq logprobs + print_rank_0(f'args.micro_batch_size: {args.micro_batch_size}') + seq_logps_p = torch.sum(loss_c[:args.micro_batch_size,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + seq_logps_u = torch.sum(loss_c[args.micro_batch_size:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + rseq_logps_p = torch.sum(rloss_c[:args.micro_batch_size,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + rseq_logps_u = torch.sum(rloss_c[args.micro_batch_size:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + + # # Print statements for debugging + # print(f'seq_logps_p shape: {seq_logps_p.size()}') + # print(f'seq_logps_u shape: {seq_logps_u.size()}') + # print(f'rseq_logps_p shape: {rseq_logps_p.size()}') + # print(f'rseq_logps_u shape: {rseq_logps_u.size()}') + + # Loss + pu_ratio = seq_logps_p - seq_logps_u + rpu_ratio = rseq_logps_p - rseq_logps_u + sdiff_ratio = 0.1*(pu_ratio - rpu_ratio) + # print(f'sdiff_ratio: {sdiff_ratio}') + final = -F.logsigmoid(sdiff_ratio) + # print(f'final: {final}') + # dloss = 
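For reference, a compact sketch of the DPO objective and the implicit reward accuracy that this loop logs, assuming the four inputs are per-sequence log-probabilities and beta = 0.1 as used in the patch:

import torch
import torch.nn.functional as F

def dpo_loss_and_rewards(logp_pref, logp_unpref, ref_logp_pref, ref_logp_unpref,
                         beta=0.1):
    # Log-ratio of preferred vs. unpreferred under the policy and the reference.
    policy_ratio = logp_pref - logp_unpref
    reference_ratio = ref_logp_pref - ref_logp_unpref
    logits = beta * (policy_ratio - reference_ratio)
    loss = -F.logsigmoid(logits).mean()
    # Implicit rewards and the fraction of pairs the policy already ranks correctly.
    rewards_pref = beta * (logp_pref - ref_logp_pref).detach()
    rewards_unpref = beta * (logp_unpref - ref_logp_unpref).detach()
    reward_accuracy = (rewards_pref > rewards_unpref).float().mean()
    return loss, reward_accuracy

The replacement code right after this point also switches the reduction from torch.sum to torch.mean over the pairs in the micro-batch, which keeps the loss scale independent of the micro-batch size and makes the learning rate easier to reason about.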
torch.sum(final) + dloss = torch.mean(final) + + # Model backward and update + model[0].backward(dloss) + + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + # print(f'increment: {increment}') + # model[0].step(lr_kwargs={'increment': increment}) + model[0].step() + update_successful = model[0].was_step_applied() + print_rank_0(f'update_successful: {update_successful}') + + # Iteration updates + iteration += 1 + args.iteration = iteration + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + + + args.consumed_train_samples += new_samples + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([dloss]) + loss_dict = {'loss': averaged_loss} + print_rank_0(f'iteration: {iteration}, dloss: {averaged_loss.detach().cpu().tolist()}') + psrewards_p = (0.1 * (seq_logps_p - rseq_logps_p)).detach() + psrewards_u = (0.1 * (seq_logps_u - rseq_logps_u)).detach() + psrewards = (psrewards_p > psrewards_u).float() + rewards = psrewards.cpu().mean() + print_rank_0(f'iteration: {iteration}, rewards: {rewards}') + + # wandb logging + # report_memory_flag = training_log_dpo(loss_dict, iteration, report_memory_flag) + + if torch.distributed.get_rank() == 0: + averaged_loss_iter.append(averaged_loss.detach().cpu().tolist()[0]) + averaged_rewards_iter.append(rewards.tolist()) + + if (i % args.save_interval == 0) and (i > 0) and (torch.distributed.get_rank() == 0): + TPL = os.environ.get('TP') + GRAD_ACC = os.environ.get('GRAD_ACC_STEPS') + print(f'Checkpointing loss and rewards at iteration {i} ..') + np.savez(f'./runs/proteingym_indels/loss-rewards_indels_textseq_nranks-{WORLD_SIZE}_model-nlayers-{args.num_layers}_TP-{TPL}_zero-{args.zero_stage}_gradacc-{GRAD_ACC}_seq-{args.seq_length}_bs-{args.micro_batch_size}_iters-{args.train_iters}-chkpt-{i}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) + + # if torch.distributed.get_rank() == 0: + # avg_loss_epoch.append(np.array(averaged_loss_iter).mean()) + # avg_rewards_epoch.append(np.array(averaged_rewards_iter).mean()) + + # Aggregated loss and rewards + # torch.distributed.barrier() + # if torch.distributed.get_rank() == 0: + # print(averaged_loss_iter) + # print(averaged_rewards_iter) + # print(avg_loss_epoch) + # print(avg_rewards_epoch) + # np.savez(f'./runs/proteingym_indels/loss-rewards_iters-{args.train_iters}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) + + # Generate - NOT WORKING + if False: + model[0].eval() + print_rank_0(f'Generation mode..') + print_rank_0(f'args.seq_length: {args.seq_length}') + tokenizer = get_tokenizer() + print_rank_0(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + prompts=["A sequence", "A sequence","A sequence", "A sequence", "A sequence"] + tokens_to_generate = 64 + add_BOS = False + if add_BOS: + prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt) + for prompt in prompts] + else: + prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts] + + print_rank_0(f'prompts_tokens: {prompts_tokens}') + + # Now we have a list of list of tokens which each list has a different + # size. We want to extend this list to: + # - incorporate the tokens that need to be generated + # - make all the sequences equal length. + # Get the prompts length. 
+ prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens] + # Get the max prompts length. + max_prompt_len = max(prompts_length) + # Number of tokens in the each sample of the batch. + samples_length = max_prompt_len + tokens_to_generate + # Now update the list of list to be of the same size: samples_length. + for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length): + padding_size = samples_length - prompt_length + prompt_tokens.extend([tokenizer.eod] * padding_size) + + # Now we are in a structured format, we can convert to tensors. + prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) + prompts_length_tensor = torch.cuda.LongTensor(prompts_length) + print_rank_0(f'prompts_tokens_tensor: {prompts_tokens_tensor}') + print_rank_0(f'prompts_length_tensor: {prompts_length_tensor}') + + batch_size = prompts_tokens_tensor.size(0) + min_prompt_length = prompts_length_tensor.min().item() + max_sequence_length = prompts_tokens_tensor.size(1) + + print_rank_0(f'batch_size: {batch_size}') + print_rank_0(f'min_prompt_length: {min_prompt_length}') + print_rank_0(f'max_sequence_length: {max_sequence_length}') + print_rank_0(f'max_position_embeddings: {args.max_position_embeddings}') + print_rank_0(f'args.max_tokens_to_oom: {args.max_tokens_to_oom}') + if max_sequence_length > args.max_position_embeddings: + raise ValueError("Length of prompt + tokens_to_generate longer than allowed") + + if max_sequence_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) + + # INSTANTIATING FORWARD_STEP ? + model_fwd = ForwardStep(model[0], batch_size, max_sequence_length) + inference_params = InferenceParams(batch_size, + max_sequence_length) + + if hasattr(args, 'eos_id'): + termination_id = args.eos_id + print(f'args.eos_id: {args.eos_id}') + else: + termination_id = tokenizer.eod + print(f'tokenizer.eod: {tokenizer.eod}') + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + # Lengths of generated seuquence including including prompts. + generated_sequence_lengths = None + top_k = 0 + top_p = 1.0 + temperature = 1.0 + top_p_decay=0.0 + top_p_bound=0.0 + add_BOS=False + use_eod_token_for_early_termination=True + stop_on_double_eol=False + stop_on_eol=False + prevent_newline_after_colon=False + random_seed=42 + return_output_log_probs = False + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + print_rank_0(f'On mpu.is_pipeline_last_stage branch and output_log_probs is set: {output_log_probs}') + generated_sequence_lengths = torch.ones( + batch_size, dtype=torch.int64, + device=torch.cuda.current_device()) * max_sequence_length + print_rank_0(f'On mpu.is_pipeline_last_stage branch and generated_sequence_lengths: {generated_sequence_lengths}') + + # Whether we have reached a termination id. 
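The generation loop that follows builds its attention mask and position ids with get_ltor_masks_and_position_ids, called with the eod/reset options switched off. For that simple case the helper boils down to roughly the following sketch (an illustration of the convention, not the exact Megatron code; True in the attention mask marks future positions to be ignored):

import torch

def simple_ltor_masks_and_position_ids(tokens):
    b, s = tokens.size()
    # Lower-triangular causal mask, inverted so True = "mask this position out".
    attention_mask = (torch.tril(torch.ones((1, s, s), device=tokens.device))
                      .view(1, 1, s, s) < 0.5)
    loss_mask = torch.ones((b, s), dtype=torch.float, device=tokens.device)
    position_ids = (torch.arange(s, dtype=torch.long, device=tokens.device)
                    .unsqueeze(0).expand(b, s))
    return attention_mask, loss_mask, position_ids

tokens = torch.randint(0, 100, (2, 6))
am, lm, pid = simple_ltor_masks_and_position_ids(tokens)
print(am.shape, lm.shape, pid.shape)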
+ is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, + device=torch.cuda.current_device()) + + + with torch.no_grad(): + prompts_attention_mask, _, prompts_position_ids = get_ltor_masks_and_position_ids( + data=prompts_tokens_tensor, + eod_token=None, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False + ) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): + # Pick the slice that we need to pass through the network. + tokens2use = prompts_tokens_tensor[:, prev_context_length:context_length] + positions2use = prompts_position_ids[:, prev_context_length:context_length] + attention_mask2use = prompts_attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + # print_rank_0(f'tokens2use shape: {tokens2use.size()}') + # print_rank_0(f'positions2use shape: {positions2use.size()}') + # print_rank_0(f'attention_mask2use shape: {attention_mask2use.size()}') + # print_rank_0(f'prompts_tokens_tensor shape: {prompts_tokens_tensor.size()}') + # print_rank_0(f'prompts_position_ids shape: {prompts_position_ids.size()}') + # print_rank_0(f'prompts_attention_mask shape: {prompts_attention_mask.size()}') + + # ------ + # plogits = forward_step(tokens2use, positions2use, attention_mask2use) + # plogits = plogits[0] + # print_rank_0(f'context_length: {context_length}, plogits: {plogits}') + + # plogits = model[0](prompts_tokens_tensor, + # prompts_position_ids, + # prompts_attention_mask, + # inference_params=inference_params + # ) + # print_rank_0(f'logits: {plogits}') + #------- + inference_params = InferenceParams(batch_size, + tokens2use.size(1)) + plogits = model[0](tokens2use, + positions2use, + attention_mask2use, + inference_params=inference_params + ) + plogits = plogits[0] + # plogits = torch.cuda.FloatTensor(plogits) + # print_rank_0(f'plogits: {plogits.size()}') + # print_rank_0(f'plogits type: {plogits.dtype}') + + if mpu.is_pipeline_last_stage(): + if prevent_newline_after_colon: + plogits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" + # Always the last stage should have an output. + assert plogits is not None + + # Sample. + last_token_logits = plogits[:, -1, :] + new_sample = sample(last_token_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + vocab_size=tokenizer.vocab_size) + if top_p > 0.0 and top_p_decay > 0.0: + top_p = top_p * top_p_decay + if top_p_bound > 0.0: + top_p = max(top_p, top_p_bound) + print_rank_0(f'new_sample: {new_sample}') + for nidx, ns in enumerate(new_sample.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, new_sample[{nidx}]: {tokenizer.detokenize(ns)}') + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = prompts_length_tensor <= context_length + # Update the tokens. 
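sample() is called above with top_k=0, top_p=1.0 and temperature=1.0, which effectively reduces to drawing from the raw softmax over the last-token logits. For other settings, a typical temperature / top-k / nucleus filter looks roughly like this (an illustrative sketch, not the Megatron sampling kernel itself):

import torch

def sample_next_token(logits, top_k=0, top_p=0.0, temperature=1.0):
    # logits: [batch, vocab] for the last position only.
    logits = logits.float() / max(temperature, 1e-5)
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop tokens whose preceding cumulative probability already exceeds top_p.
        cutoff = torch.cumsum(sorted_probs, dim=-1) - sorted_probs > top_p
        sorted_logits = sorted_logits.masked_fill(cutoff, float('-inf'))
        logits = torch.full_like(logits, float('-inf')).scatter(-1, sorted_idx,
                                                                sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

logits = torch.randn(2, 32000)
print(sample_next_token(logits, top_k=50, top_p=0.9, temperature=0.8))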
+ print_rank_0(f'started: {started}') + # print_rank_0(f'prompts_tokens_tensor before copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor before[{nidx}]: {tokenizer.detokenize(ns)}') + prompts_tokens_tensor[started, context_length] = new_sample[started] + # print_rank_0(f'prompts_tokens_tensor after copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after[{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the tokens on the first stage so the next input to + # the network is correct. + copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, + prompts_tokens_tensor[:, context_length]) + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after copy_from_last_to_first_pipeline_stage [{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the context length for the next token generation. + prev_context_length = context_length + print_rank_0(f'prev_context_length: {prev_context_length}') + + # Check if all the sequences have hit the termination_id. + done = None + if mpu.is_pipeline_last_stage(): + # These stopping methods are tokenizer dependent + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + just_finished = (done_token & ~is_generation_done).bool() + generated_sequence_lengths[just_finished.view(-1)] = \ + context_length + 1 + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + done = broadcast_from_last_pipeline_stage(1, torch.uint8, + tensor=done) + if use_eod_token_for_early_termination and done: + print_rank_0(f'done: {done}') + break + + # =================================================== + # Update the length of based on max generated length. + # =================================================== + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop [{nidx}]: {tokenizer.detokenize(ns)}') + prompts_tokens_tensor = prompts_tokens_tensor[:, :(context_length + 1)] + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and slicing with ctx length[{nidx}]: {tokenizer.detokenize(ns)}') + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = output_log_probs[:, :context_length] + + # ====================================== + # Broadcast to the first pipeline stage. 
+ # ====================================== + + generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage( + batch_size, torch.int64, generated_sequence_lengths) + if return_output_log_probs: + output_log_probs_size = (batch_size, context_length) + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) + + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and befoer final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + # Only post-process on first stage. + if mpu.is_pipeline_first_stage(): + prompts_plus_generations = [] + + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and after final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + + rtokens = prompts_tokens_tensor.cpu().numpy().tolist() + rlengths = prompts_length_tensor.cpu().numpy().tolist() + print_rank_0(f'rlengths: {rlengths}') + for sequence_tokens, slength in zip(rtokens, rlengths): + sequence_tokens = sequence_tokens[:slength] + prompts_plus_generations.append( + tokenizer.detokenize(sequence_tokens)) + # _, prompts_plus_generations, prompts_plus_generations_segments = \ + # detokenize_generations(prompts_tokens_tensor, prompts_length_tensor, True) + + print_rank_0(f'prompts_plus_generations: {prompts_plus_generations}') + + if True: + prompts=["Pen is mightier than", "A sequence", "Pythagoras theorem", "A sequence", "Hello world"] + tokens_to_generate = 64 + generated_responses = generate_post_training(model, prompts, tokens_to_generate, fprint=False) + + if False: + print_rank_0(f'Generation mode..') + print_rank_0(f'args.seq_length: {args.seq_length}') + tokenizer = get_tokenizer() + print_rank_0(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + model[0].eval() + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + # if choice[0].item() == 0: + try: + tokens_to_generate_len = 1021 + response, _, _, _ = generate_and_post_process(model[0], prompts=["A sequence", "A sequence","A sequence", "A sequence", "A sequence", "A sequence","A sequence", "A sequence"], tokens_to_generate=tokens_to_generate_len) + print_rank_0(f'generation completed..\n response:{response}') + except ValueError as ve: + print_rank_0(f'ValueError: {ve}') + pass + # elif choice[0].item() == 1: + # try: + # response, _, _ = beam_search_and_post_process(model[0], prompts=["A sequence", "A sequence", "A sequence", "A sequence",], tokens_to_generate=32) + # print(f'generation completed..\n response:{response}') + # except ValueError as ve: + # print(f'ValueError: {ve}') + # pass + + # # Checkpointing + # if args.save and iteration != 0: + # save_checkpoint(iteration, model, optimizer, opt_param_scheduler) return model @@ -1122,7 +1501,10 @@ def main(): # data_post_process=data_post_process) import sys import deepspeed.comm as dist + + # Return trained model model = main() + dist.log_summary() if wandb.run is not None: print(f"wandb.run.name: {wandb.run.name}") From 6a3b03edc477f52c27d6ce4a7f4c978c911534ae Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:38:34 +0000 Subject: [PATCH 34/50] newer preprocess data script --- preprocess_data.py | 384 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 preprocess_data.py diff --git a/preprocess_data.py 
b/preprocess_data.py new file mode 100644 index 0000000000..bc76e324a9 --- /dev/null +++ b/preprocess_data.py @@ -0,0 +1,384 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Processing large data for pretraining.""" +import argparse +import math +import json +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) +import time +import gzip +import glob +import torch +import numpy as np +import multiprocessing +try: + import nltk + nltk_available = True +except ImportError: + nltk_available = False + +from megatron.tokenizer import build_tokenizer +from megatron.data import indexed_dataset + + +# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer +class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): + + _period_context_fmt = r""" + \S* # some word material + %(SentEndChars)s # a potential sentence ending + \s* # <-- THIS is what I changed + (?=(?P + %(NonWord)s # either other punctuation + | + (?P\S+) # <-- Normally you would have \s+ here + ))""" + +class IdentitySplitter(object): + def tokenize(self, *text): + return text + + +class Encoder(object): + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = build_tokenizer(self.args) + if self.args.split_sentences: + if not nltk_available: + print("NLTK is not available to split sentences.") + exit() + library = "tokenizers/punkt/{}.pickle".format(self.args.lang) + splitter = nltk.load(library) + if self.args.keep_newlines: + # this prevents punkt from eating newlines after sentences + Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( + train_text = splitter._params, + lang_vars = CustomLanguageVars()) + else: + Encoder.splitter = splitter + + else: + Encoder.splitter = IdentitySplitter() + + def split(self, json_line): + data = json.loads(json_line) + output = {} + for key in self.args.json_keys: + text = data[key] + max_len = 1000000 + tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)] + output[key] = [tokens for partial in tokens_list for tokens in partial] + return json.dumps(output), len(json_line) + + def encode(self, json_line): + data = json.loads(json_line) + ids = {} + lens = {} + for key in self.args.json_keys: + text = data[key] + if isinstance(text, list): + sentences = text + else: + sentences = [text] + doc_ids = [] + sentence_lens = [] + for sentence in sentences: + sentence_ids = Encoder.tokenizer.tokenize(sentence) + if len(sentence_ids) > 0: + doc_ids.extend(sentence_ids) + sentence_lens.append(len(sentence_ids)) + if len(doc_ids) > 0 and self.args.append_eod: + doc_ids.append(Encoder.tokenizer.eod) + ids[key] = doc_ids + lens[key] = sentence_lens + return ids, lens, len(json_line) + +import os +class Partition(object): + def __init__(self, args, workers): + self.args = args + self.workers = workers + + def print_processing_stats(self, count, proc_start, total_bytes_processed): + if count % self.args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed/elapsed/1024/1024 + print(f"Thread {os.getpid()}: Processed {count} documents", + f"({count/elapsed} docs/s, {mbs} MB/s).", + file=sys.stderr) + + def split_sentences(self, file_name): + input_file_name, output_file_name = file_name + print("Opening", input_file_name) + fin = open(input_file_name, 'r', encoding='utf-8') + fout = 
open(output_file_name, 'w') + + encoder = Encoder(self.args) + pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) + split_docs = pool.imap(encoder.split, fin, 32) + + proc_start = time.time() + total_bytes_processed = 0 + for i, (doc, bytes_processed) in enumerate(split_docs, start=1): + total_bytes_processed += bytes_processed + fout.write(doc + "\n") + self.print_processing_stats(i, proc_start, total_bytes_processed) + + fin.close() + fout.close() + + + def process_json_file(self, file_name): + input_file_name, output_prefix = file_name + print("Opening", input_file_name) + if input_file_name.endswith(".gz"): + fin = gzip.open(input_file_name, "rt") + else: + fin = open(input_file_name, 'r', encoding='utf-8') + + startup_start = time.time() + encoder = Encoder(self.args) + tokenizer = build_tokenizer(self.args) + pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, 32) + + level = "document" + if self.args.split_sentences: + level = "sentence" + + output_bin_files = {} + output_idx_files = {} + builders = {} + + for key in self.args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, + key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, + key, level) + builders[key] = indexed_dataset.make_builder(output_bin_files[key], + impl=self.args.dataset_impl, + vocab_size=tokenizer.vocab_size) + + startup_end = time.time() + proc_start = time.time() + total_bytes_processed = 0 + print("Time to startup:", startup_end - startup_start) + for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + for key in doc.keys(): + builders[key].add_doc(doc[key], sentence_lens[key]) + self.print_processing_stats(i, proc_start, total_bytes_processed) + + fin.close() + builders[key].finalize(output_idx_files[key]) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title='input data') + group.add_argument('--input', type=str, required=True, + help='Path to input JSON') + group.add_argument('--json-keys', nargs='+', default=['text'], + help='space separate listed of keys to extract from json') + group.add_argument('--split-sentences', action='store_true', + help='Split documents into sentences.') + group.add_argument('--keep-newlines', action='store_true', + help='Keep newlines between sentences when splitting.') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase','BertWordPieceCase', + 'GPT2BPETokenizer', 'SentencePieceTokenizer', + 'GPTSentencePieceTokenizer', 'NullTokenizer', 'Llama2Tokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--tokenizer-model', type=str, default=None, + help='YTTM tokenizer model.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--vocab-size', default=786, + help='size of vocab for use with NullTokenizer') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + group.add_argument('--append-eod', action='store_true', + help='Append an token to the end of a document.') + group.add_argument('--lang', type=str, default='english', + help='Language to use for NLTK-powered sentence splitting.') + group = parser.add_argument_group(title='output data') + group.add_argument('--output-prefix', 
type=str, required=True, + help='Path to binary output file without suffix') + group.add_argument('--dataset-impl', type=str, default='mmap', + choices=['lazy', 'cached', 'mmap']) + + group = parser.add_argument_group(title='runtime') + group.add_argument('--workers', type=int, required=True, + help=('Number of worker processes to launch.' + 'A good default for fast pre-processing ' + 'is: (workers * partitions) = available CPU cores.')) + group.add_argument('--partitions', type=int, default=1, + help='Number of file partitions') + group.add_argument('--log-interval', type=int, default=1000, + help='Interval between progress updates') + args = parser.parse_args() + args.keep_empty = False + + if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences: + print("Are you sure you don't want to split sentences?") + + # some default/dummy values for the tokenizer + args.rank = 1 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + args.vocab_extra_ids = 0 + + return args + + +def get_file_name(args, file_id): + file_name, extension = os.path.splitext(args.input) + input_file_name = file_name + "_" + str(file_id) + extension + sentence_split_file = file_name + "_ss_" + str(file_id) + extension + output_prefix = args.output_prefix + "_" + str(file_id) + file_names = { + 'partition': input_file_name, + 'sentence_split': sentence_split_file, + 'output_prefix': output_prefix} + return file_names + + +def check_files_exist(in_ss_out_names, key, num_partitions): + for i in range(num_partitions): + if not os.path.exists(in_ss_out_names[i][key]): + return False + return True + + +def main(): + args = get_args() + + if args.split_sentences: + if nltk_available: + nltk.download("punkt", quiet=True) + else: + raise Exception( + "nltk library required for sentence splitting is not available.") + + in_ss_out_names = [] + if args.partitions == 1: + file_name, extension = os.path.splitext(args.input) + sentence_split_file = file_name + "_ss" + extension + file_names = { + 'partition': args.input, + 'sentence_split': sentence_split_file, + 'output_prefix': args.output_prefix} + in_ss_out_names.append(file_names) + else: + in_file_names = glob.glob(args.input) + + # create .jsonl parition files + for idx in range(args.partitions): + in_ss_out_name = get_file_name(args, idx) + in_ss_out_names.append(in_ss_out_name) + + # check to see if paritions were already created + partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions) + + # check to see if paritions with split sentences already created + split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + + if not partitions_present and not split_sentences_present: + # populate .jsonl partition files from parent files + partitioned_input_files = [] + for idx in range(args.partitions): + partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w') + partitioned_input_files.append(partitioned_input_file) + + index = 0 + for in_file_name in in_file_names: + # support for gzip files + if in_file_name.endswith(".gz"): + fin = gzip.open(in_file_name, 'rt') + else: + fin = open(in_file_name, 'r', encoding='utf-8') + + for line in fin: + partitioned_input_files[index].write(line) + index = (index + 1)%args.partitions + + fin.close() + + for idx in range(args.partitions): + partitioned_input_files[idx].close() + + assert args.workers % args.partitions == 0 + partition = Partition(args, args.workers//args.partitions) + + # check to see if paritions 
with split sentences already created + split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + + # split sentences in partition files + if args.split_sentences and not split_sentences_present: + processes = [] + for name in in_ss_out_names: + p = multiprocessing.Process(target=partition.split_sentences, + args=((name['partition'], name['sentence_split']),)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if args.partitions == 1: + return + + + # encode partition files in parallel + processes = [] + input_key = 'sentence_split' if args.split_sentences else 'partition' + for name in in_ss_out_names: + p = multiprocessing.Process(target=partition.process_json_file, + args=((name[input_key], name['output_prefix']),)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if args.partitions == 1: + return + + # merge bin/idx partitions + level = "document" + if args.split_sentences: + level = "sentence" + + output_bin_files = {} + output_idx_files = {} + builders = {} + tokenizer = build_tokenizer(args) + + for key in args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, + key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, + key, level) + builders[key] = indexed_dataset.make_builder(output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size) + for name in in_ss_out_names: + parition_output_prefix = name['output_prefix'] + full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, + key, level) + builders[key].merge_file_(full_partition_output_prefix) + builders[key].finalize(output_idx_files[key]) + + +if __name__ == '__main__': + main() + From 9cb008941ce0b0c0e327e0c45002d9e01e12b862 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:38:51 +0000 Subject: [PATCH 35/50] batch preprocess --- preprocess_data_batch.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 preprocess_data_batch.py diff --git a/preprocess_data_batch.py b/preprocess_data_batch.py new file mode 100644 index 0000000000..b71a679984 --- /dev/null +++ b/preprocess_data_batch.py @@ -0,0 +1,26 @@ +# Preprocess batch of files to bin and idx format + +def main(): + + import subprocess + import glob + import os + from tqdm import tqdm + + nfiles = glob.glob('./protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds/*.json') + print(f'num files: {len(nfiles)}') + + for i in tqdm(range(len(nfiles))): + sname = nfiles[i].split('/')[-1].split('.')[0] + print(f'Input json filename: {sname}') + cmd = f'python preprocess_data.py --input ./protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds/{sname}.json --output-prefix ./protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds_bin-idx/{sname} --tokenizer-type Llama2Tokenizer --tokenizer-model /lus/eagle/projects/datasets/dolma/utils/tokenizer.model --workers 16' + returned_value = os.system(cmd) + +if __name__ == '__main__': + main() + + +# python preprocess_data.py --input ./protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds/HIS7_YEAST_Pokusaeva_2019_indels_multi_prop_fit_pref.json --output-prefix ./protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds_bin-idx/HIS7_YEAST_Pokusaeva_2019_indels_multi_prop_fit_pref --tokenizer-type Llama2Tokenizer --tokenizer-model /lus/eagle/projects/datasets/dolma/utils/tokenizer.model --workers 16 + + +'' \ No newline at end of file From a2aba88b3b274fde6baf1a9efc47527a7bc63fe2 Mon 
Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:40:05 +0000 Subject: [PATCH 36/50] convert to megatron format --- convert_to_megds_fmt.py | 72 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 convert_to_megds_fmt.py diff --git a/convert_to_megds_fmt.py b/convert_to_megds_fmt.py new file mode 100644 index 0000000000..e2c0015cdc --- /dev/null +++ b/convert_to_megds_fmt.py @@ -0,0 +1,72 @@ +# Convert ProteinGym text-sequences files to megatron-deepspeed format + +import glob +import os +from tqdm import tqdm + +def return_text_seq(jsonob): + + fit_text_seq = [] + unfit_text_seq = [] + fit_seq = [] + unfit_seq = [] + for i in tqdm(range(len(jsonob))): + if jsonob['context'][i]['fitness'] == 'fit': + fit_seq.append(jsonob['context'][i]['sequence']) + fit_text_seq.append(jsonob['text'][i]) + elif jsonob['context'][i]['fitness'] == 'unfit': + unfit_seq.append(jsonob['context'][i]['sequence']) + unfit_text_seq.append(jsonob['text'][i]) + + return fit_text_seq, unfit_text_seq, fit_seq, unfit_seq + + +def write_to_megds_fmt(outpath, fname, in_seqs, tag=None): + + # Convert to json for meg-ds + exm_message_ch = in_seqs[0] + message_text = "" + message_text += exm_message_ch + import json + d0 = {"id": f"{0}", "text": message_text} + st = json.dumps(d0) + # st = f'{d0}' + st + # print(f'st[0]: {st}') + + for i in tqdm(range(1,len(in_seqs))): + + exm_message_ch = in_seqs[i] + message_text = "" + message_text += exm_message_ch + di = {"id": f"{i}", "text": message_text} + st = st + '\n' + json.dumps(di) + + fname += f'_{tag}.json' + with open(os.path.join(outpath,fname), 'w') as f: + f.write(st) + +def convert_to_megds_fmt(inpath, outpath): + + import pandas as pd + jsonobjf = pd.read_json(path_or_buf=inpath, lines=True) + + f_text_seq, uf_text_seq, f_seq, uf_seq = return_text_seq(jsonobjf) + + fname = inpath.split('/')[-1].split('.jsonl')[0] + + write_to_megds_fmt(outpath, fname, f_text_seq, tag='pref') + write_to_megds_fmt(outpath, fname, uf_text_seq, tag='unpref') + + +def main(): + + outpath = '/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/protein_gym/substitutions/DMS_ProteinGym_substitutions_multi_prop_fit_meg-ds' + nfiles = glob.glob('/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/protein_gym/substitutions/DMS_ProteinGym_substitutions_multi_prop_fit/*.jsonl') + print(f'Number of Substition files: {nfiles}') + + for nf in tqdm(nfiles): + convert_to_megds_fmt(nf, outpath) + +if __name__ == '__main__': + main() \ No newline at end of file From 9e13d4ea8f3b8922ab67ed5558d902c7bc277967 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Fri, 5 Apr 2024 23:40:27 +0000 Subject: [PATCH 37/50] write megatron data file --- write_megds_data_file.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 write_megds_data_file.py diff --git a/write_megds_data_file.py b/write_megds_data_file.py new file mode 100644 index 0000000000..43e437132e --- /dev/null +++ b/write_megds_data_file.py @@ -0,0 +1,33 @@ +# Write data file with weights +# default weights set to 1 + +import sys +import glob +from tqdm import tqdm + +def main(): + + inpath = '/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/protein_gym/indels/DMS_ProteinGym_indels_multi_prop_fit_meg-ds_bin-idx/' + outpath = '/lus/eagle/projects/RL-fold/gdharuman/Megatron-DeepSpeed/ultrafeedback_dataset/' + + # tag = "pref" + tag = sys.argv[1] + inpath += '*_'+tag+'_*.bin' + print(f'inpath: {inpath}') + + nfiles = glob.glob(inpath) + 
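Taken together, the conversion and file-list scripts in this patch produce the two inputs the training job expects: a JSONL file with one {"id", "text"} record per line (which preprocess_data.py turns into .bin/.idx pairs), and a plain-text data file list where each line is a sampling weight followed by a path prefix without the .bin/.idx suffix. A minimal sketch of both writers, using hypothetical paths for illustration:

import json

# Meg-DS style JSONL: one {"id", "text"} record per line.
records = [{"id": "0", "text": "MKTAYIAKQR"}, {"id": "1", "text": "MKTAYIAKQQ"}]
with open("example_pref.json", "w") as f:
    f.write("\n".join(json.dumps(r) for r in records))

# Data file list: "<weight> <prefix-without-.bin/.idx>" per line.
# The prefix below is a hypothetical example, not a real dataset path.
prefixes = ["./bin-idx/example_pref_text_document"]
with open("data_file_list_p.txt", "w") as f:
    for p in prefixes:
        f.write(f"1.0 {p}\n")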
print(f'Number of files with the tag: {len(nfiles)}') + + lines = [] + # for nf in nfiles: + for i in tqdm(range(len(nfiles))): + lines.append('1.0 ' + nfiles[i].split('.bin')[0]) + + # print(lines) + + with open(outpath+f'data_textseq_proteingym_indels_file_list_{tag[0]}.txt', 'w') as f: + for line in lines: + f.write(f"{line}\n") + +if __name__ == '__main__': + main() From 702bb2653660352d661f089319eabd67616cac7f Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Thu, 11 Apr 2024 17:04:06 -0500 Subject: [PATCH 38/50] Add @hzheng s example `run_hzheng.sh` --- ds_config-gpt.json | 54 ++++++++++++++++++++++++++++++++++++++++++++++ run_hzheng.sh | 16 ++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 ds_config-gpt.json create mode 100755 run_hzheng.sh diff --git a/ds_config-gpt.json b/ds_config-gpt.json new file mode 100644 index 0000000000..5ef4ebe0d1 --- /dev/null +++ b/ds_config-gpt.json @@ -0,0 +1,54 @@ + { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 16, + "steps_per_print": 1, + "wall_clock_breakdown" : true, + "zero_force_ds_cpu_optimizer": false, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "reduce_scatter": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "contiguous_gradients": true, + "offload_param": { + "device": "cpu", + "nvme_path": "/raid/scratch", + "pin_memory": false + }, + "offload_optimizer": { + "device": "cpu", + "nvme_path": "/raid/scratch/" +} +}, +"scheduler": { +"type": "WarmupLR", +"params": { +"warmup_min_lr": 0, +"warmup_max_lr": 0.001, +"warmup_num_steps": 1000 +} +}, +"fp16": { +"enabled": true, +"initial_scale_power": 12 +}, +"flops_profiler": { +"enabled": true, +"profile_step": 1, +"module_depth": -1, +"top_modules": 3, +"detailed": true, +"output_file": null +}, +"comms_logger": { +"enabled": true, +"verbose": false, +"prof_all": false, +"debug": false +}, +"wandb": { +"enabled": true, +"project": "GenSLM-Megatron-DS" +} +} diff --git a/run_hzheng.sh b/run_hzheng.sh new file mode 100755 index 0000000000..b588c66565 --- /dev/null +++ b/run_hzheng.sh @@ -0,0 +1,16 @@ +#!/bin/bash --login +#PBS -l walltime=0:30:00 +#PBS -A datascience +#PBS -q debug +#PBS -l select=1 +#PBS -l filesystems=eagle:grand:home +cd ${PBS_O_WORKDIR} +export PPN=4 +# export MD=${HOME}/GB-Megatron-DeepSpeed +export MD="/eagle/FoundEpidem/foremans/Megatron-DeepSpeed" +export PYTHONPATH=$MD:$PYTHONPATH +source /eagle/argonne_tpc/soft/conda.sh +export TRITON_CACHE_DIR=/tmp/.cache/ +export PBS_JOBSIZE=$(cat $PBS_NODEFILE | uniq | wc -l) + +APRUN_PMI=pmix aprun -n $((PBS_JOBSIZE*PPN)) -N $PPN --cc depth -d 16 /eagle/argonne_tpc/soft/local_rank.sh python dpo_training.py --use-flash-attn-v2 --fp16 --num-workers 0 --split 100,0,0 --log-interval 1 --no-bias-gelu-fusion --lr-decay-style cosine --no-bias-dropout-fusion --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 --use-checkpoint-opt_param-scheduler --lr 0.0003 --seq-length 1024 --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --num-layers 6 --hidden-size 4096 --train-iters 30 --eval-iters 10 --distributed-backend nccl --num-attention-heads 32 --save-interval 200 --eval-interval 50000 --max-position-embeddings 1024 --micro-batch-size 4 --data-file-list-p ALCF/data_textseq_proteingym_indels_file_list_p.txt --data-file-list-u ALCF/data_textseq_proteingym_indels_file_list_u.txt 
--tensor-model-parallel-size 1 --global-batch-size 32 --pipeline-model-parallel-size 1 --num-key-value-heads 32 --data-cache-path ./index-cache --ffn-hidden-size 11008 --tokenizer-model /eagle/datasets/dolma/utils/tokenizer.model --no-query-key-layer-scaling --use-rotary-position-embeddings --untie-embeddings-and-output-weights --swiglu --normalization rmsnorm --disable-bias-linear --deepspeed-activation-checkpointing --zero-stage=2 --deepspeed_config=ds_config-gpt.json --no-pipeline-parallel --deepspeed --checkpoint-activations --checkpoint-num-layers 1 --optimizer adamw From 8d048ac2b8281a5c42f66650366cf047bb3889a1 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Thu, 11 Apr 2024 17:04:43 -0500 Subject: [PATCH 39/50] Prefer `torch.optim` for `--optimizer` --- megatron/arguments.py | 5 ++- megatron/optimizer/__init__.py | 65 +++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 270a886596..6674634dfc 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -913,8 +913,11 @@ def _add_training_args(parser): group.add_argument('--disable-bias-linear', action='store_false', help='Disable bias in the linear layers', dest='add_bias_linear') + # group.add_argument('--optimizer', type=str, default='adam', + # choices=['adam', 'sgd', 'adamw', ''], + # help='Optimizer function') group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd'], + choices=['adam', 'adamw', 'sgd', 'apex.adam', 'apex.sgd'], help='Optimizer function') group.add_argument('--dataloader-type', type=str, default=None, choices=['single', 'cyclic'], diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 12a458375d..cdcc344541 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -1,12 +1,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
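The refactor that follows moves optimizer construction away from unconditional apex imports to a string dispatch, importing apex only when an apex.* optimizer is explicitly requested. In spirit it is equivalent to the sketch below (an illustration, not the exact Megatron code); the practical benefit is that the plain torch.optim optimizers work on non-CUDA accelerators where apex is unavailable.

import torch

def build_optimizer(name, params, lr, weight_decay,
                    betas=(0.9, 0.95), eps=1e-8, momentum=0.0):
    name = str(name).lower()
    if name == "adamw":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay,
                                 betas=betas, eps=eps)
    if name == "adam":
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay,
                                betas=betas, eps=eps)
    if name == "sgd":
        return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay,
                               momentum=momentum)
    if name == "apex.adam":
        from apex.optimizers import FusedAdam   # imported lazily, CUDA only
        return FusedAdam(params, lr=lr, weight_decay=weight_decay,
                         betas=betas, eps=eps)
    if name == "apex.sgd":
        from apex.optimizers import FusedSGD
        return FusedSGD(params, lr=lr, weight_decay=weight_decay,
                        momentum=momentum)
    raise TypeError(f"{name} optimizer is not supported.")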
from deepspeed.accelerator import get_accelerator -if get_accelerator().device_name() == 'cuda': - from apex.optimizers import FusedAdam as Adam - from apex.optimizers import FusedSGD as SGD -else: - from torch.optim import Adam - from torch.optim import SGD +import torch from megatron import get_args @@ -93,25 +88,55 @@ def get_megatron_optimizer(model, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps) else: - if args.optimizer == 'adam': + if str(args.optimizer).lower() == 'apex.adam': + assert get_accelerator().device_name() == 'cuda' + from apex.optimizers import FusedAdam as Adam + optimizer = Adam( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps + ) + elif str(args.optimizer).lower() == 'apex.sgd': + from apex.optimizers import FusedSGD as SGD + optimizer = SGD( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum + ) + elif str(args.optimizer).lower() == 'adamw': + optimizer = torch.optim.AdamW( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps + ) + elif args.optimizer == 'adam': if args.ds_fused_adam: - global Adam + # global Adam from deepspeed.ops.adam import FusedAdam Adam = FusedAdam - optimizer = Adam(param_groups, - lr=args.lr, - weight_decay=args.weight_decay, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_eps) + else: + Adam = torch.optim.Adam + optimizer = Adam( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps + ) elif args.optimizer == 'sgd': - optimizer = SGD(param_groups, - lr=args.lr, - weight_decay=args.weight_decay, - momentum=args.sgd_momentum) + optimizer = torch.optim.SGD( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum + ) else: - raise Exception('{} optimizer is not supported.'.format( - args.optimizer)) - + raise TypeError(f'{args.optimizer} optimizer is not supported.') if args.deepspeed: return optimizer From 859ec88123cb620c861183d02a0e56a83af51b46 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Thu, 11 Apr 2024 17:37:42 -0500 Subject: [PATCH 40/50] Remove global batch from `run_hzheng.sh` --- run_hzheng.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_hzheng.sh b/run_hzheng.sh index b588c66565..015e5f2b3d 100755 --- a/run_hzheng.sh +++ b/run_hzheng.sh @@ -13,4 +13,4 @@ source /eagle/argonne_tpc/soft/conda.sh export TRITON_CACHE_DIR=/tmp/.cache/ export PBS_JOBSIZE=$(cat $PBS_NODEFILE | uniq | wc -l) -APRUN_PMI=pmix aprun -n $((PBS_JOBSIZE*PPN)) -N $PPN --cc depth -d 16 /eagle/argonne_tpc/soft/local_rank.sh python dpo_training.py --use-flash-attn-v2 --fp16 --num-workers 0 --split 100,0,0 --log-interval 1 --no-bias-gelu-fusion --lr-decay-style cosine --no-bias-dropout-fusion --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 --use-checkpoint-opt_param-scheduler --lr 0.0003 --seq-length 1024 --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --num-layers 6 --hidden-size 4096 --train-iters 30 --eval-iters 10 --distributed-backend nccl --num-attention-heads 32 --save-interval 200 --eval-interval 50000 --max-position-embeddings 1024 --micro-batch-size 4 --data-file-list-p ALCF/data_textseq_proteingym_indels_file_list_p.txt 
--data-file-list-u ALCF/data_textseq_proteingym_indels_file_list_u.txt --tensor-model-parallel-size 1 --global-batch-size 32 --pipeline-model-parallel-size 1 --num-key-value-heads 32 --data-cache-path ./index-cache --ffn-hidden-size 11008 --tokenizer-model /eagle/datasets/dolma/utils/tokenizer.model --no-query-key-layer-scaling --use-rotary-position-embeddings --untie-embeddings-and-output-weights --swiglu --normalization rmsnorm --disable-bias-linear --deepspeed-activation-checkpointing --zero-stage=2 --deepspeed_config=ds_config-gpt.json --no-pipeline-parallel --deepspeed --checkpoint-activations --checkpoint-num-layers 1 --optimizer adamw +APRUN_PMI=pmix aprun -n $((PBS_JOBSIZE*PPN)) -N $PPN --cc depth -d 16 /eagle/argonne_tpc/soft/local_rank.sh python dpo_training.py --use-flash-attn-v2 --fp16 --num-workers 0 --split 100,0,0 --log-interval 1 --no-bias-gelu-fusion --lr-decay-style cosine --no-bias-dropout-fusion --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 --use-checkpoint-opt_param-scheduler --lr 0.0003 --seq-length 1024 --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 --num-layers 16 --hidden-size 4096 --train-iters 30 --eval-iters 10 --distributed-backend nccl --num-attention-heads 32 --save-interval 200 --eval-interval 50000 --max-position-embeddings 1024 --micro-batch-size 4 --data-file-list-p ALCF/data_textseq_proteingym_indels_file_list_p.txt --data-file-list-u ALCF/data_textseq_proteingym_indels_file_list_u.txt --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 --num-key-value-heads 32 --data-cache-path ./index-cache --ffn-hidden-size 11008 --tokenizer-model /eagle/datasets/dolma/utils/tokenizer.model --no-query-key-layer-scaling --use-rotary-position-embeddings --untie-embeddings-and-output-weights --swiglu --normalization rmsnorm --disable-bias-linear --deepspeed-activation-checkpointing --zero-stage=2 --deepspeed_config=ds_config-gpt.json --no-pipeline-parallel --deepspeed --checkpoint-activations --checkpoint-num-layers 1 --optimizer adamw From 02961dc6a118a6da8292f19fb96c5d75c40d5e5e Mon Sep 17 00:00:00 2001 From: gdharuman Date: Fri, 19 Jul 2024 12:52:44 -0700 Subject: [PATCH 41/50] minor changes --- dpo_training.py | 8 +++++--- ds_config-gpt.json | 4 ++-- pretrain_gpt_alcf.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/dpo_training.py b/dpo_training.py index dd02b537f3..183dab787e 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -102,7 +102,7 @@ print('--------------------------------------------------') print(f"Setting up W&B from: {RANK} with {project_name}") print('--------------------------------------------------') - setup_wandb(project_name=project_name) + #setup_wandb(project_name=project_name) def model_provider(pre_process=True, post_process=True): """Build the model.""" @@ -838,6 +838,8 @@ def main(): config=args.deepspeed_config_dict, ) model = [model] + print_rank_0(get_parameters_in_billions(model)) + #exit() # ---------- Reference model ------------- # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error @@ -1147,10 +1149,10 @@ def main(): averaged_rewards_iter.append(rewards.tolist()) if (i % args.save_interval == 0) and (i > 0) and (torch.distributed.get_rank() == 0): - TPL = os.environ.get('TP') + TPL = args.tensor_model_parallel_size GRAD_ACC = 
os.environ.get('GRAD_ACC_STEPS') print(f'Checkpointing loss and rewards at iteration {i} ..') - np.savez(f'./runs/proteingym_indels/loss-rewards_indels_textseq_nranks-{WORLD_SIZE}_model-nlayers-{args.num_layers}_TP-{TPL}_zero-{args.zero_stage}_gradacc-{GRAD_ACC}_seq-{args.seq_length}_bs-{args.micro_batch_size}_iters-{args.train_iters}-chkpt-{i}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) + np.savez(f'./runs/loss-rewards_indels_textseq_nranks-{WORLD_SIZE}_model-nlayers-{args.num_layers}_TP-{TPL}_zero-{args.zero_stage}_gradacc-{GRAD_ACC}_lr-{args.lr}_seq-{args.seq_length}_bs-{args.micro_batch_size}_iters-{args.train_iters}-chkpt-{i}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) # if torch.distributed.get_rank() == 0: # avg_loss_epoch.append(np.array(averaged_loss_iter).mean()) diff --git a/ds_config-gpt.json b/ds_config-gpt.json index 5ef4ebe0d1..dc1b5a779d 100644 --- a/ds_config-gpt.json +++ b/ds_config-gpt.json @@ -1,6 +1,6 @@ { "train_micro_batch_size_per_gpu": 1, - "gradient_accumulation_steps": 16, + "gradient_accumulation_steps": 3, "steps_per_print": 1, "wall_clock_breakdown" : true, "zero_force_ds_cpu_optimizer": false, @@ -49,6 +49,6 @@ }, "wandb": { "enabled": true, -"project": "GenSLM-Megatron-DS" +"project": "Megatron-DS" } } diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py index 4fefef795f..1e220f0db9 100644 --- a/pretrain_gpt_alcf.py +++ b/pretrain_gpt_alcf.py @@ -183,7 +183,8 @@ def get_batch(data_iterator): tokens_ = data_b['text'].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() - + print_rank_0(f'tokens shape: {tokens.size()}') + print_rank_0(f'tokens[0] : {tokens[0]}') # Get the masks and postition ids. skip_mask = args.use_flash_attn or args.use_flash_attn_triton attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( From 984fcf0f6e7e317b1fa72e3040c82260cbb45190 Mon Sep 17 00:00:00 2001 From: gdharuman Date: Fri, 19 Jul 2024 12:53:34 -0700 Subject: [PATCH 42/50] scaling run launch script for dgx cluster --- run_dgxcluster_scaling.sh | 83 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 run_dgxcluster_scaling.sh diff --git a/run_dgxcluster_scaling.sh b/run_dgxcluster_scaling.sh new file mode 100644 index 0000000000..724154bc3a --- /dev/null +++ b/run_dgxcluster_scaling.sh @@ -0,0 +1,83 @@ +#!/bin/bash +#SBATCH --partition defq --nodes 31 +#SBATCH --exclusive +#SBATCH --job-name=example-mn-sbatch-job +#SBATCH --gpus-per-node=8 + +CONTAINER=${HOME}/enroot_images/megds2.sqsh +#srun --nodes 2 --mpi=pmix --gpus-per-node 8 --container-image=${CONTAINER} --ntasks-per-node=1 nvidia-smi -L +#exit 0 + +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_TLS=rc +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml +export NCCL_DEBUG=INFO +export NCCL_PROTO=LL,LL128,Simple +export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS +export MELLANOX_VISIBLE_DEVICES=all +export PMIX_MCA_gds=hash +export PMIX_MCA_psec=native + +export NHOSTS="${SLURM_NNODES}" +export NGPU_PER_HOST="${SLURM_GPUS_ON_NODE}" +export NGPUS="$(( NHOSTS * NGPU_PER_HOST ))" +export OMP_NUM_THREADS=1 +export WORLD_SIZE=$NGPUS +export RANK=$SLURM_PROCID +export LOCAL_RANK=$SLURM_LOCALID +export NCCL_DEBUG=warn + +echo "PATH=$PATH" > 
.deepspeed_env +echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> .deepspeed_env +echo "CPATH=$CPATH" >> .deepspeed_env +echo "TORCH_EXTENSIONS_DIR=$PWD/deepspeed" >> .deepspeed_env +echo "HF_HOME=$PWD/hfdata" >> .deepspeed_env + + +echo ${SLURM_GPUS_ON_NODE} + +if [ ! -z "${SLURM_JOB_ID}" ]; then + # check the original location through scontrol and $SLURM_JOB_ID + SCRIPT_PATH=$(scontrol show job $SLURM_JOBID | awk -F= '/Command=/{print $2}') + export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +else + # otherwise: started with bash. Get the real location. + SCRIPT_PATH=$(realpath $0) +fi + +export _basedir="$(cd "$(dirname "${SCRIPT_PATH}")" && pwd)" +cd ${_basedir} +echo ${_basedir} + +#cd $SCRIPT_PATH +echo $SCRIPT_PATH +echo $SLURM_NNODES + +#CONTAINER=${HOME}/enroot_images/megds2.sqsh +#source /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/deps/ezpz/src/ezpz/bin/savejobenv +srun --mpi=pmix --nodes $SLURM_NNODES --gpus-per-node 8 --ntasks-per-node=8 --container-workdir=${_basedir} --container-mounts="/lustre/fs0/scratch/gdharuman","/home/gdharuman" --container-image=${CONTAINER} python /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/dpo_training.py \ + --use-flash-attn-v2 --fp16 --split 100,0,0 \ + --log-interval 1 --no-bias-gelu-fusion \ + --lr-decay-style cosine --no-bias-dropout-fusion \ + --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer \ + --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 \ + --use-checkpoint-opt_param-scheduler --lr 5e-6 --seq-length 512 \ + --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --num-layers 32 --hidden-size 4096 --train-iters 5000 --eval-iters 10 \ + --distributed-backend nccl --num-attention-heads 32 --save-interval 10 \ + --eval-interval 50000 --max-position-embeddings 4096 --micro-batch-size 12 \ + --data-file-list-p ALCF/data_textseq_p.txt \ + --data-file-list-u ALCF/data_textseq_u.txt \ + --tensor-model-parallel-size 8 --pipeline-model-parallel-size 1 \ + --num-key-value-heads 32 --data-cache-path ./index-cache \ + --ffn-hidden-size 11008 --tokenizer-model ALCF/tokenizer.model \ + --no-query-key-layer-scaling --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights --swiglu \ + --normalization rmsnorm --disable-bias-linear \ + --zero-stage=1 --deepspeed_config=ds_config-gpt.json \ + --no-pipeline-parallel --deepspeed --optimizer adamw From 907d15833d1de2142be908a403a6b5b50bd42d7e Mon Sep 17 00:00:00 2001 From: gdharuman Date: Mon, 22 Jul 2024 07:42:09 -0700 Subject: [PATCH 43/50] minor changes --- dpo_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpo_training.py b/dpo_training.py index 183dab787e..bb6bd96773 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -1167,7 +1167,7 @@ def main(): # print(avg_rewards_epoch) # np.savez(f'./runs/proteingym_indels/loss-rewards_iters-{args.train_iters}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) - # Generate - NOT WORKING + # Generate if False: model[0].eval() print_rank_0(f'Generation mode..') @@ -1430,7 +1430,7 @@ def main(): print_rank_0(f'prompts_plus_generations: {prompts_plus_generations}') - if True: + if False: prompts=["Pen is mightier than", "A sequence", "Pythagoras theorem", "A sequence", "Hello world"] tokens_to_generate = 64 generated_responses = generate_post_training(model, prompts, tokens_to_generate, fprint=False) From 
f7f7da378799f4e795b53a6ae168314ef3e8b3f8 Mon Sep 17 00:00:00 2001 From: gdharuman Date: Thu, 1 Aug 2024 05:09:39 -0700 Subject: [PATCH 44/50] dsflops log to file --- dpo_training.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dpo_training.py b/dpo_training.py index bb6bd96773..28819bf3aa 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -823,6 +823,11 @@ def main(): if "compression_training" in args.deepspeed_config_dict: args.compression_training = True + from copy import deepcopy + ds_config_copy = deepcopy(args.deepspeed_config_dict) + ds_config_copy["flops_profiler"]["output_file"] = f"dsflops_nlayer{args.num_layers}_worldsize{WORLD_SIZE}_seq{args.seq_length}_mb{args.micro_batch_size}.log" + print_rank_0(f'Deepspeed config updated with out: {ds_config_copy["flops_profiler"]}') + # model = model_provider() # model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) model = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? From f91ba8c361b0268957cc41609ae90cc0dde309f2 Mon Sep 17 00:00:00 2001 From: gdharuman Date: Thu, 1 Aug 2024 05:10:11 -0700 Subject: [PATCH 45/50] all data --- run_dgxcluster_scaling.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/run_dgxcluster_scaling.sh b/run_dgxcluster_scaling.sh index 724154bc3a..a5e4f8f19b 100644 --- a/run_dgxcluster_scaling.sh +++ b/run_dgxcluster_scaling.sh @@ -68,12 +68,12 @@ srun --mpi=pmix --nodes $SLURM_NNODES --gpus-per-node 8 --ntasks-per-node=8 --co --use-checkpoint-opt_param-scheduler --lr 5e-6 --seq-length 512 \ --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ - --num-layers 32 --hidden-size 4096 --train-iters 5000 --eval-iters 10 \ - --distributed-backend nccl --num-attention-heads 32 --save-interval 10 \ - --eval-interval 50000 --max-position-embeddings 4096 --micro-batch-size 12 \ - --data-file-list-p ALCF/data_textseq_p.txt \ - --data-file-list-u ALCF/data_textseq_u.txt \ - --tensor-model-parallel-size 8 --pipeline-model-parallel-size 1 \ + --num-layers 32 --hidden-size 4096 --train-iters 100 --eval-iters 10 \ + --distributed-backend nccl --num-attention-heads 32 --save-interval 2000 \ + --eval-interval 50000 --max-position-embeddings 4096 --micro-batch-size 2 \ + --data-file-list-p ALCF/data_textseq_p_all.txt \ + --data-file-list-u ALCF/data_textseq_u_all.txt \ + --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 \ --num-key-value-heads 32 --data-cache-path ./index-cache \ --ffn-hidden-size 11008 --tokenizer-model ALCF/tokenizer.model \ --no-query-key-layer-scaling --use-rotary-position-embeddings \ From be896de95c235ca3912f095b2bb16f6004599abb Mon Sep 17 00:00:00 2001 From: gdharuman Date: Thu, 1 Aug 2024 06:35:10 -0700 Subject: [PATCH 46/50] reference optimizer removed --- dpo_training.py | 48 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/dpo_training.py b/dpo_training.py index 28819bf3aa..c7ca26d3eb 100644 --- a/dpo_training.py +++ b/dpo_training.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
-"""Pretrain GPT""" +"""Direct Preference Optimization""" import os from rich import print @@ -849,17 +849,17 @@ def main(): # ---------- Reference model ------------- # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? - # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) - optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0) - opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2) - model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( - model=model_ref[0], - optimizer=optimizer_2, - args=args, - lr_scheduler=opt_param_scheduler_2, - mpu=mpu if args.no_pipeline_parallel else None, - config=args.deepspeed_config_dict, - ) + # # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + # optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0) + # opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2) + # model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( + # model=model_ref[0], + # optimizer=optimizer_2, + # args=args, + # lr_scheduler=opt_param_scheduler_2, + # mpu=mpu if args.no_pipeline_parallel else None, + # config=args.deepspeed_config_dict, + # ) # model_ref, _, _, _ = deepspeed.initialize( # model=model_ref[0], # optimizer=None, @@ -878,7 +878,29 @@ def main(): # ) # model_ref = engine.module - + # deepspeed initialization of reference model without optimizer + ds_config_ref_dict = args.deepspeed_config_dict.copy() + if 'zero_optimization' in ds_config_ref_dict: + print_rank_0(f'args.deepspeed_config_dict before: {args.deepspeed_config_dict}') + print_rank_0(f'ds_config_ref_dict before: {ds_config_ref_dict}') + if 'zero_optimization' in ds_config_ref_dict.keys(): + del ds_config_ref_dict['zero_optimization'] + if 'optimizer' in ds_config_ref_dict.keys(): + del ds_config_ref_dict['optimizer'] + if 'train_batch_size' in ds_config_ref_dict.keys(): + del ds_config_ref_dict['train_batch_size'] + print_rank_0(f'args.deepspeed_config_dict after: {args.deepspeed_config_dict}') + print_rank_0(f'ds_config_ref_dict after: {ds_config_ref_dict}') + + model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( + model=model_ref[0], + config=ds_config_ref_dict + ) + print_rank_0(f'ref optimizer: {optimizer_2}') + print_rank_0(f'ref param scheduler: {opt_param_scheduler_2}') + assert optimizer_2 == None, "Reference model optimizer is not None" + assert opt_param_scheduler_2 == None, "Reference param scheduler is not None" + if isinstance(model_ref, deepspeed.PipelineEngine): print(f'Doing assertion checks on model_ref..') # hack to get batch_fn from pretrain_gpt.py From 67ff707b42bc9c5402c3a98f01563380e067b52f Mon Sep 17 00:00:00 2001 From: gdharuman Date: Thu, 1 Aug 2024 06:35:33 -0700 Subject: [PATCH 47/50] with reference optimizer --- dpo_training_ref.py | 1521 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1521 insertions(+) create mode 100644 dpo_training_ref.py diff --git a/dpo_training_ref.py b/dpo_training_ref.py new file mode 100644 index 0000000000..79510ad39d --- /dev/null +++ b/dpo_training_ref.py @@ -0,0 +1,1521 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +"""Direct Preference Optimization""" + +import os +from rich import print +import torch +import math +import numpy as np + +# The earliest we can measure the start time. +import time +from datetime import datetime +import threading + +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import ( + report_memory, + throughput_calculator, + checkpoint_throughput_calculator +) +from pathlib import Path + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator +import subprocess +import wandb + +import time +from torch import nn +import torch.nn.functional as F + +# from ezpz import get_logger +from ezpz.dist import get_world_size, setup_wandb, get_rank + +# More imports +from megatron.initialize import initialize_megatron +from megatron.initialize import set_jit_fusion_options +from megatron.training import print_datetime, _create_ds_config_dict +from megatron.training import setup_model_and_optimizer +from megatron.training import load_model_weights_only, get_model +from megatron.training import load_model_weights_only_modified +from megatron.training import get_optimizer_param_scheduler, cyclic_iter +from megatron.training import train, train_step +from megatron.training import train_step_dpo, training_log_dpo +from megatron.optimizer import get_megatron_optimizer +from megatron.checkpointing import load_checkpoint +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.arguments import core_transformer_config_from_args +from megatron import update_num_microbatches +from megatron import get_num_microbatches +from megatron.utils import throughput_calculator, get_parameters_in_billions +from megatron.text_generation import generate_and_post_process, beam_search_and_post_process +from megatron.text_generation.forward_step import ForwardStep, InferenceParams +from megatron.text_generation.sampling import sample +from megatron.text_generation.tokenization import detokenize_generations +from megatron.text_generation.communication import ( + copy_from_last_to_first_pipeline_stage, + broadcast_from_last_pipeline_stage, + broadcast_from_last_to_first_pipeline_stage) +from megatron.checkpointing import save_checkpoint +from megatron.utils import get_ltor_masks_and_position_ids +from generate_utils import generate_post_training + +# RANK = setup_torch( +# backend='deepspeed', +# port='5432', +# ) +RANK = get_rank() +WORLD_SIZE = get_world_size() +LEVEL = "DEBUG" if RANK == 0 else "CRITICAL" + +WANDB_MODE = os.environ.get('WANDB_MODE', None) +DISABLE_WANDB = ( + WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' +) + +if RANK == 0 and not DISABLE_WANDB: + project_name = ( + os.environ.get( + 'WB_PROJECT', + os.environ.get( + 'WANDB_PROJECT', + 'AuroraGPT' + ), + ) + ) + 
print('--------------------------------------------------') + print(f"Setting up W&B from: {RANK} with {project_name}") + print('--------------------------------------------------') + #setup_wandb(project_name=project_name) + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + print_rank_0('building GPT model ...') + see_memory_usage("Before Building Model", force=True) + args = get_args() + config = core_transformer_config_from_args(args) + if wandb.run is not None: + print(f"Updating WandB run: [{wandb.run.name}]({wandb.run.url})") + wandb.run.config.update({"args": vars(args)}, allow_val_change=True) + if RANK == 0: + git_ds_info() + if hasattr(mpu, 'get_sequence_parallel_group'): + dpg = mpu.get_sequence_parallel_group() + elif hasattr(mpu, 'get_data_parallel_group'): + dpg = mpu.get_data_parallel_group() + else: + dpg = None + if wandb is not None and wandb.run is not None: + assert wandb is not None and wandb.run is not None + print(f'Updating {wandb.run.name=} at {wandb.run.url=}') + wandb.run.config.update({'args': vars(args)}, allow_val_change=True) + with deepspeed.zero.Init( + data_parallel_group=dpg, + remote_device=( + None if args.remote_device == 'none' else args.remote_device + ), + config_dict_or_path=args.deepspeed_config_dict, + enabled=args.zero_stage == 3, + mpu=mpu + ): + if args.deepspeed and not args.no_pipeline_parallel: + model = GPTModelPipe( + config=config, + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to + # get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. + # This avoids having to pipeline it + # as an activation during training. + # The mask is constant, and thus we can reuse it. + attention_mask = torch.tril( + torch.ones( + (1, args.seq_length, args.seq_length), + device=get_accelerator().current_device_name() + ) + ).view(1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. 
+ args.attn_mask = attention_mask.to(torch.bool) + + # For prertaining, since sequence length is fixed, + # cache rotary embedding in args, to avoid communicating around + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(args.seq_length) + + else: + print(f'Building model check..') + model = GPTModel( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + # print_rank_0('\n ------------------------ ') + # print_rank_0(f'num of parameters {num_params}') + # print_rank_0('------------------------\n ') + print_rank_0(80 * '-') + print_rank_0(f"Number of parameters in model: {num_params}") + print_rank_0(80 * '-') + see_memory_usage("After Building Model", force=True) + if wandb.run is not None: + wandb.run.config.update({'num_params': num_params}, allow_val_change=True) + # wandb.run.watch( + # model, + # log='all', + # log_graph=True, + # ) + # wandb.run.config.update({'num_params': num_params}) + return model + +def throughput_flops(model, args, iteration_time, total_iterations): + batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size + approx_parameters_in_billions = None if (model is None) else get_parameters_in_billions(model) + elapsed_time_per_iter = iteration_time/total_iterations + samples_per_second = batch_size / elapsed_time_per_iter + + #flops calculator + hidden_size = args.hidden_size + num_layers = args.num_layers + vocab_size = args.padded_vocab_size + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3. + checkpoint_activations_factor = 3 + if hasattr(args, 'checkpoint_activations') and args.checkpoint_activations: + checkpoint_activations_factor = 4 + if hasattr(args, 'recompute_granularity') and (args.recompute_granularity == 'selective' or args.recompute_granularity == 'full'): + checkpoint_activations_factor = 4 + seq_len = args.seq_length + if hasattr(args, 'actual_seq_length'): + seq_len = args.actual_seq_length + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12)) + + return tflops + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + # print(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + data = next(data_iterator) if data_iterator is not None else None + # # Broadcast data. + # if data_iterator is not None: + # data = next(data_iterator) + # else: + # data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
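# --- Editor's note (illustrative sketch, not part of the patch) -------------
# throughput_flops() above uses the Megatron-LM TFLOPS estimate (Eq. 3 of
# https://arxiv.org/pdf/2104.04473.pdf). The helper below evaluates the same
# expression for a hypothetical configuration; all default values are
# assumptions, not numbers from this run.
def tflops_example(batch_size=32, seq_len=1024, num_layers=32, hidden_size=4096,
                   vocab_size=32000, checkpoint_activations_factor=4,
                   elapsed_time_per_iter=10.0, world_size=8):
    flops_per_iteration = (
        24 * checkpoint_activations_factor * batch_size * seq_len
        * num_layers * hidden_size ** 2
    ) * (1.0 + seq_len / (6.0 * hidden_size)
         + vocab_size / (16.0 * num_layers * hidden_size))
    return flops_per_iteration / (elapsed_time_per_iter * world_size * 1e12)
# e.g. tflops_example() returns the per-GPU TFLOPS estimate for the assumed setup.
# ---------------------------------------------------------------------------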
+ skip_mask = args.use_flash_attn or args.use_flash_attn_triton + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + skip_mask) + + # For DS's sequence parallel + seq_parallel_world_size = mpu.get_sequence_parallel_world_size() + seq_parallel_world_rank = mpu.get_sequence_parallel_rank() + + # For Megatron's sequence parallel + if args.sequence_parallel: + seq_parallel_world_size = mpu.get_tensor_model_parallel_world_size() + seq_parallel_world_rank = mpu.get_tensor_model_parallel_rank() + seq_length = tokens.size(1) + + assert seq_length % seq_parallel_world_size == 0 + sub_seq_length = seq_length // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_length + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length + + tokens = tokens[:, sub_seq_start:sub_seq_end] + position_ids = position_ids[:, sub_seq_start:sub_seq_end] + # For DS's sequence parallel + if mpu.get_sequence_parallel_world_size() > 1: + labels = labels[:, sub_seq_start:sub_seq_end] + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def data_post_process(data, data_sampler_state_dict): + args = get_args() + if args.data_efficiency_curriculum_learning: + if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if current_seqlen < args.seq_length: + data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() + elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + if current_seqlen < args.seq_length: + orig_num_token = torch.numel(data['text']) + reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) + data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), + data['text'][:, -(current_seqlen+1):]), 0).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen+1)) + num_row = min(num_row, data['text'].size()[0]) + if num_row > 1 and num_row % 2 != 0: + num_row -= 1 + data['text'] = data['text'][:num_row, :].contiguous() + else: + args.data_efficiency_curriculum_learning_seqlen_type = None + return data + + +def get_batch_pipe(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` + instead of `data_iterator` + """ + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
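# --- Editor's note (illustrative sketch, not part of the patch) -------------
# get_batch() above splits each sequence evenly across sequence-parallel ranks.
# A stand-alone version of that index arithmetic (values are made up):
def sub_seq_bounds(seq_length, sp_world_size, sp_rank):
    assert seq_length % sp_world_size == 0
    sub_seq_length = seq_length // sp_world_size
    return sp_rank * sub_seq_length, (sp_rank + 1) * sub_seq_length

# e.g. sub_seq_bounds(1024, 4, 2) == (512, 768), so rank 2 keeps
# tokens[:, 512:768] and the matching position_ids slice.
# ---------------------------------------------------------------------------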
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < tokens.size()[1] + ): + # seqlen-based curriculum learning + # tokens, position_ids, labels, loss_mask + # have size [batch size, seqlen] + tokens = tokens[:, :args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + if labels is not None: + labels = labels[:, :args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + else: + if max(args.num_experts) <= 1: + return loss, {'lm loss': averaged_loss[0]} + loss = loss + moe_loss + return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + +def dpo_loss_func(loss_mask, dpo_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. 
+ averaged_loss = average_losses_across_data_parallel_group([loss]) + if args.mos or args.kd: + # assert max(args.num_experts) >= 1 + loss = loss + moe_loss + mos_loss + if args.mos: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'mos loss': mos_loss + } + elif args.kd: + return loss, { + 'total loss': loss, + 'lm loss': averaged_loss[0], + 'moe loss': moe_loss, + 'kd loss': mos_loss + } + print_rank_0( + f'>>> total loss: {loss}, ' + f'lm loss {averaged_loss[0]}, ' + f'kd loss {mos_loss}' + ) + # else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + # loss = loss + moe_loss + # return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + else: + # if max(args.num_experts) <= 1: + # return loss, {'lm loss': averaged_loss[0]} + loss = dpo_loss + return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss} + +def batch_seq_logprobs(logits, labels): + """ Function to compute a batch of sequence log probabilities """ + + logits = logits[:-1, :, :] # skip last logit + logits_logsoftmax = logits.log_softmax(-1) # compute log softmax of logits + + labels = labels[1:, :].clone() # clone labels + + # # Loss mask to avoid padded tokens while computing loss + # loss_mask = labels != tokenizer.pad_token_id + + # print(f'Labels shape: {labels.shape}') + # print(f'loss_mask shape: {loss_mask.shape}') + # print(f'loss_mask dtype: {loss_mask.dtype}') + + # Gather logps and squeeze last dimension + logprobs = torch.gather(logits_logsoftmax, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # print(f'seq_logprobs shape: {logprobs.shape}') + + # Weighted sum over logprobs using loss mask + # seq_logprobs = (logprobs * loss_mask).sum(-1) + seq_logprobs = logprobs.sum(-1) + + return seq_logprobs + + +def calculate_mos_loss( + args, + stu_output, + teacher_model, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + tea_output, tea_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == tea_output.size(), ( + 'teacher and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Teacher: {tea_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. 
+ tea_logits = F.softmax(tea_output / kd_temp, dim=2) + + mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( + student_logits, + tea_logits + ) + + mos_loss = mos_loss.div(args.seq_length) * beta + return mos_loss + +def calculate_dpo_loss( + args, + stu_output, + teacher_model, + logprobs_p, + logprobs_u, + ref_logprobs_p, + ref_logprobs_u, + tokens, + position_ids, + attention_mask +): + mos_loss = 0 + alpha = args.kd_alpha_ce + beta = args.kd_beta_ce + kd_temp = args.kd_temp + kd_temp = 1.0 + beta = 0.1 # add to cmdline args + + if teacher_model: + with torch.no_grad(): + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + curriculum_seqlen = args.curriculum_seqlen + tokens = tokens[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + csl = curriculum_seqlen + attention_mask = ( + attention_mask[:, :, :csl, :csl].contiguous() + ) + # No need to truncate labels + # as we do not need it for the teacher logits + ref_output, ref_other_losses = teacher_model( + tokens, + position_ids, + attention_mask + ) + assert stu_output.size() == ref_output.size(), ( + 'ref and student output should match in size. ' + f'Student: {stu_output.size()}, ' + f'Reference: {ref_output.size()}, ' + f'CL seq length {args.curriculum_seqlen}' + ) + + student_logits = F.log_softmax(stu_output / kd_temp, dim=2) + # Labels ? + logprobs = torch.gather(student_logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + # The target logits is expected to be probabilities. + # If we use log_softmax, + # then we need to set target_log to true + # when initializing the KLDivLoss. + + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p + + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u + + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) + + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # preferred dpo rewards + pref_dpo_rewards = (beta * logprob_ratio_p).detach() + + # unpreferred dpo rewards + unpref_dpo_rewards = (beta * logprob_ratio_u).detach() + + # Implicit DPO rewards + implicit_dpo_rewards = (pref_dpo_rewards > unpref_dpo_rewards).float() + rewards = implicit_dpo_rewards.cpu().mean() + + # Compute mean loss + dpo_loss = losses.mean() + # print(f'Loss dtype: {loss.dtype}') + + return dpo_loss, rewards + +def compute_dp_loss(logprobs_p, ref_logprobs_p, + logprobs_u, ref_logprobs_u, + beta=0.1): + + # Get ratios of preferred log probabilities from model and ref model + logprob_ratio_p = logprobs_p - ref_logprobs_p + + # Get ratios of unpreferred log probabilities from model and ref model + logprob_ratio_u = logprobs_u - ref_logprobs_u + + # Difference of logprobs ratios scaled by beta + scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u) + + # Losses computed as negative logsigmoid of scaled difference + losses = -F.logsigmoid(scaled_diff_logprob_ratios) + + # Compute mean loss + dp_loss = losses.mean() + + return dp_loss + + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
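# --- Editor's note (illustrative usage, not part of the patch) --------------
# A numeric sanity check for compute_dp_loss() above. The log-probabilities
# are made up: the policy improves the preferred sequence by 2 nats relative
# to the reference and leaves the unpreferred one unchanged, so the loss falls
# below -log(0.5) ~= 0.693.
import torch

logprobs_p     = torch.tensor([-12.0])   # policy,    preferred completion
ref_logprobs_p = torch.tensor([-14.0])   # reference, preferred completion
logprobs_u     = torch.tensor([-13.0])   # policy,    unpreferred completion
ref_logprobs_u = torch.tensor([-13.0])   # reference, unpreferred completion

# scaled difference = 0.1 * ((-12 - -14) - (-13 - -13)) = 0.2
# loss = -logsigmoid(0.2) ~= 0.598
example_dpo_loss = compute_dp_loss(logprobs_p, ref_logprobs_p,
                                   logprobs_u, ref_logprobs_u, beta=0.1)
# ---------------------------------------------------------------------------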
+ timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + if args.data_efficiency_curriculum_learning: + args.curriculum_seqlen = tokens.size()[1] + if ( + hasattr( + args, + 'data_efficiency_curriculum_learning_seqlen_type') + and ( + args.data_efficiency_curriculum_learning_seqlen_type + == 'seqlen_reshape' + ) + ): + args.data_efficiency_curriculum_learning_numel = ( + torch.numel(tokens) + ) + + if args.mos or args.kd: + # The forward func can return either the loss or the logits, + # depending on whether passing in the labels or not. + stu_output, other_losses = model(tokens, position_ids, attention_mask) + if ( + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length + ): + assert args.curriculum_seqlen is not None + labels = labels[:, :args.curriculum_seqlen].contiguous() + output_tensor = tensor_parallel.vocab_parallel_cross_entropy( + stu_output.contiguous().float(), + labels + ) + else: + output_tensor, other_losses = model( + tokens, + position_ids, + attention_mask, + labels=labels + ) + if ( + args.curriculum_learning_legacy and + args.curriculum_seqlen < args.seq_length + ): + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + moe_losses = [] + for moe_loss in other_losses: + if moe_loss is not None: + moe_losses.append(moe_loss) + moe_loss = sum(moe_losses) * args.moe_loss_coeff + + mos_loss = 0 + if args.mos or args.kd: + assert model.training + if args.teacher_forward and args.teacher_model is not None: + mos_loss = calculate_mos_loss( + args, + stu_output, + args.teacher_model[0], + tokens, + position_ids, + attention_mask + ) + + # Output_tensor stores the standard loss, + # loss_func calculates the total loss. 
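# --- Editor's note (illustrative sketch, not part of the patch) -------------
# forward_step() hands back the raw output tensor together with a partially
# applied loss function; Megatron's forward/backward helper later calls that
# closure on the output. A minimal, framework-free illustration of the pattern:
from functools import partial

def toy_loss(mask_total, output):
    return output / mask_total              # stand-in for the masked reduction

def toy_forward_step(batch):
    output = sum(batch)                     # stand-in for the model forward
    return output, partial(toy_loss, len(batch))

out, loss_closure = toy_forward_step([1.0, 2.0, 3.0])
assert loss_closure(out) == 2.0
# ---------------------------------------------------------------------------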
+ return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for GPT ...') + files = [] + if args.data_file_list is not None: + with open(args.data_file_list, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files.append(float(w)) + files.append(fname) + elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): + path = args.data_path[0] + "/" + for f in os.listdir(path): + if (os.path.isfile(path + f) and f.find(".bin") != -1): + files.append(1) + files.append(path + f.split(".bin")[0]) + else: + files = args.data_path + print_rank_0(f"file list {files}") + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=files, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen( + f'type {cmd}', + stdout=subprocess.PIPE, + shell=True + ) + return result.wait() == 0 + + +def git_ds_info(): + if RANK != 0: + return + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print( + f'**** Git info for Megatron: ' + f'git_hash={git_hash} git_branch={git_branch} ****' + ) + + +def main(): + # if RANK == 0: + # setup_wandb() + + if os.getenv('TORCH_PROFILER_ENABLED') == '1': + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. + if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + # model = model_provider() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + + prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") + else: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron( + # extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # external_args=external_args + ) + # Set pytorch JIT layer fusion options and warmup JIT functions. 
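# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The dataset provider above (and the --data-file-list-p / --data-file-list-u
# arguments) consume a plain-text list with one "weight path-prefix" pair per
# line. The example path below is hypothetical.
def read_data_file_list(list_path):
    files = []
    with open(list_path, 'r') as flist:
        for line in flist.readlines():
            w, fname = line.split()
            files.append(float(w))
            files.append(fname)
    # flat [weight0, prefix0, weight1, prefix1, ...], the layout expected by
    # build_train_valid_test_datasets(data_prefix=...)
    return files

# e.g. a list containing the single line
#   1.0 /path/to/proteingym_indels_textseq
# yields [1.0, '/path/to/proteingym_indels_textseq'].
# ---------------------------------------------------------------------------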
+ if get_accelerator().device_name() == 'cuda': + set_jit_fusion_options() + + args = get_args() + timers = get_timers() + + if args.deepspeed: + args.deepspeed_config_dict = _create_ds_config_dict() + if "curriculum_learning" in args.deepspeed_config_dict and \ + "enabled" in args.deepspeed_config_dict["curriculum_learning"]: + args.curriculum_learning_legacy = args.deepspeed_config_dict[ \ + "curriculum_learning"]["enabled"] + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + from deepspeed.runtime.data_pipeline.curriculum_scheduler \ + import CurriculumScheduler + args.curriculum_scheduler = CurriculumScheduler( \ + args.deepspeed_config_dict["curriculum_learning"]) + if "compression_training" in args.deepspeed_config_dict: + args.compression_training = True + + from copy import deepcopy + ds_config_copy = deepcopy(args.deepspeed_config_dict) + ds_config_copy["flops_profiler"]["output_file"] = f"dsflops_nlayer{args.num_layers}_worldsize{WORLD_SIZE}_seq{args.seq_length}_mb{args.micro_batch_size}.log" + print_rank_0(f'Deepspeed config updated with out: {ds_config_copy["flops_profiler"]}') + + # model = model_provider() + # model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) + model = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? + # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + optimizer = get_megatron_optimizer(model, None, None, 1.0) + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + model, optimizer, _, opt_param_scheduler = deepspeed.initialize( + model=model[0], + optimizer=optimizer, + args=args, + lr_scheduler=opt_param_scheduler, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + model = [model] + print_rank_0(get_parameters_in_billions(model)) + #exit() + + # ---------- Reference model ------------- + # model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error + model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes? 
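# --- Editor's note (illustrative sketch, not part of the patch) -------------
# This file keeps an optimizer for the reference model; the companion
# dpo_training.py (commit "reference optimizer removed") instead initializes
# the frozen reference model through DeepSpeed with no optimizer by stripping
# the optimizer-related keys from a copy of the config. A condensed version of
# that approach:
import deepspeed

def init_frozen_reference_model(model_ref, ds_config_dict):
    ref_config = ds_config_dict.copy()
    for key in ('zero_optimization', 'optimizer', 'train_batch_size'):
        ref_config.pop(key, None)                 # no optimizer state for the reference
    engine, opt, _, sched = deepspeed.initialize(model=model_ref[0],
                                                 config=ref_config)
    assert opt is None and sched is None          # reference model stays frozen
    return [engine]
# ---------------------------------------------------------------------------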
+ # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider) + optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0) + opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2) + model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize( + model=model_ref[0], + optimizer=optimizer_2, + args=args, + lr_scheduler=opt_param_scheduler_2, + mpu=mpu if args.no_pipeline_parallel else None, + config=args.deepspeed_config_dict, + ) + # model_ref, _, _, _ = deepspeed.initialize( + # model=model_ref[0], + # optimizer=None, + # args=args, + # lr_scheduler=None, + # mpu=mpu if args.no_pipeline_parallel else None, + # config=args.deepspeed_config_dict, + # ) + # engine = deepspeed.init_inference(model=model_ref[0], + # mp_size=args.tensor_model_parallel_size, + # tensor_parallel={"mpu": mpu}, + # dtype=torch.half, + # replace_with_kernel_inject=True, + # # moe_experts=args.num_experts, + # # moe_type=args.mlp_type + # ) + # model_ref = engine.module + + print_rank_0(f'optimizer_2: {optimizer_2}') + + if isinstance(model_ref, deepspeed.PipelineEngine): + print(f'Doing assertion checks on model_ref..') + # hack to get batch_fn from pretrain_gpt.py + model_ref.set_batch_fn(model_ref.module._megatron_batch_fn) + + assert model_ref.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank() + assert model_ref.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank() + assert model_ref.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank() + + model_ref = [model_ref] + iteration2 = load_checkpoint(model_ref, optimizer_2, opt_param_scheduler_2) # THIS WORKED!! After commenting out assert args.consumed_train_samples == 0 in load_checkpoint() + + # THINGS THAT DID NOT WORK FOR LOADING FROM CHECKPOINT + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only(model_provider) # DID NOT WORK - train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 32 != 8 * 1 * 8 + # model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only_modified(model_provider) # DID NOT WORK - optimizer = FusedAdam(TypeError: FusedAdam.__init__() got an unexpected keyword argument 'beta1' + # ---------------------------------------- + + if args.data_file_list_u is not None: + print(f'data files list unpreferred: {args.data_file_list_u}') + + # Number of train/valid/test samples. 
+ if args.train_samples: + print(f'args.train_samples: {args.train_samples}') + train_samples = args.train_samples + else: + print(f'args.train_iters: {args.train_iters}') + print(f'args.global_batch_size: {args.global_batch_size}') + train_samples = args.train_iters * args.global_batch_size + + print(f'args.eval_interval: {args.eval_interval}') + print(f'args.eval_iters: {args.eval_iters}') + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(f'train_val_test_num_samples: {train_val_test_num_samples}') + # print(f'args.data_impl: {args.data_impl}') + # print(f'args.split: {args.split}') + # print(f'args.seq_length: {args.seq_length}') + # print(f'args.seed: {args.seed}') + # print(f'args.train_data_path: {args.train_data_path}') + # print(f'args.valid_data_path: {args.valid_data_path}') + # print(f'args.test_data_path: {args.test_data_path}') + # print(f'args.data_cache_path: {args.data_cache_path}') + + files_u = [] + with open(args.data_file_list_u, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_u.append(float(w)) + files_u.append(fname) + train_ds_u, valid_ds_u, test_ds_u = build_train_valid_test_datasets( + data_prefix=files_u, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating unpreferred GPT datasets ...") + + if args.data_file_list_p is not None: + print_rank_0(f'data files list preferred: {args.data_file_list_p}') + + files_p = [] + with open(args.data_file_list_p, 'r') as flist: + for f in flist.readlines(): + w, fname = f.split() + files_p.append(float(w)) + files_p.append(fname) + train_ds_p, valid_ds_p, test_ds_p = build_train_valid_test_datasets( + data_prefix=files_p, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=True, + # skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating preferred GPT datasets ...") + + # Data loaders + print_rank_0(f'args.consumed_train_samples: {args.consumed_train_samples}') + print_rank_0(f'args.dataloader_type: {args.dataloader_type}') + train_dataloader_u = build_pretraining_data_loader( + train_ds_u, args.consumed_train_samples) + train_dataloader_p = build_pretraining_data_loader( + train_ds_p, args.consumed_train_samples) + + # Build train iterators + dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader_u is not None: + print_rank_0(f'unpreferred train_dataloader is not None..') + train_data_iterator_u = iter(train_dataloader_u) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader_u)) + print_rank_0("> finished creating unpreferred train_data_iterator...") + if train_dataloader_p is not None: + print_rank_0(f'preferred train_dataloader is not None..') + train_data_iterator_p = 
iter(train_dataloader_p) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader_p)) + print_rank_0("> finished creating preferred train_data_iterator...") + + + print_rank_0(f'args.train_iters: {args.train_iters}') + print_rank_0(f'args.save_interval: {args.save_interval}') + report_memory_flag = True + + # Train model + model[0].train() + + if torch.distributed.get_rank() == 0: + averaged_loss_iter = [] + averaged_rewards_iter = [] + avg_loss_epoch = [] + avg_rewards_epoch = [] + + for epoch in range(1): + iteration = 0 + for i in range(args.train_iters): + # Get batch + timers = get_timers() + timers('batch-generator-unpreferred', log_level=2).start() + tokens_u, labels_u, loss_mask_u, attention_mask_u, position_ids_u = get_batch( + train_data_iterator_u) + timers('batch-generator-unpreferred').stop() + # print_rank_0(f'tokens_u[0].size(): {tokens_u[0].size()}') + # print_rank_0(f'tokens_u[0,400:1024]: {tokens_u[0,400:1024]}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for unpref train_data_iterator ...") + + timers('batch-generator-preferred', log_level=2).start() + tokens_p, labels_p, loss_mask_p, attention_mask_p, position_ids_p = get_batch( + train_data_iterator_p) + timers('batch-generator-preferred').stop() + # print(f'tokens shape: {tokens_u.shape}') + print_rank_0("> finished extracting batch of tokens, labels, attn mask etc. for pref train_data_iterator ...") + + # Model forward + # output_tensor, other_losses = model[0]( + # tokens_u, + # position_ids_u, + # attention_mask_u, + # labels=labels_u + # ) # OUT OF MEMORY ERROR even with 4 nodes + + # Model forward with concatenated inputs + tokens_c = torch.cat((tokens_p,tokens_u), 0) + position_ids_c = torch.cat((position_ids_p,position_ids_u), 0) + labels_c = torch.cat((labels_p,labels_u), 0) + loss_mask_c = torch.cat((loss_mask_p,loss_mask_u), 0) + + # Logits and loss + output_c, other_losses_c = model[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u + ) + + loss_c = tensor_parallel.vocab_parallel_cross_entropy( + output_c.contiguous().float(), + labels_c + ) + + # Reference model forward with concatenated inputs + with torch.no_grad(): + # Logits and loss + routput_c, rother_losses_c = model_ref[0]( + tokens_c, + position_ids_c, + None, + # labels=labels_u + ) + rloss_c = tensor_parallel.vocab_parallel_cross_entropy( + routput_c.contiguous().float(), + labels_c + ) + + # # Print statements for debugging + # print(f'tokens_p: {tokens_p}') + # print(f'tokens_u: {tokens_u}') + # # print(f'output_p[0]: {output_p[0]}') + # # print(f'output_u[0]: {output_u[0]}') + # print(f'output_c[0]: {output_c[0]}') + # print(f'tokens_p shape: {tokens_p.size()}, tokens_u shape: {tokens_u.size()}') + # print(f'tokens_c shape: {tokens_c.size()}') + # print(f'position_ids_p shape: {position_ids_p.size()}, position_ids_u shape: {position_ids_u.size()}') + # print(f'position_ids_c shape: {position_ids_c.size()}') + # print(f'output_c shape: {output_c.size()}') + # print(f'loss_c shape: {loss_c.size()}') + # print(f'routput_c shape: {routput_c.size()}') + # print(f'rloss_c shape: {rloss_c.size()}') + # print(f'loss_mask_p shape: {loss_mask_p.size()}') + # print(f'loss_mask_u shape: {loss_mask_u.size()}') + # print(f'loss_mask_c shape: {loss_mask_c.size()}') + # print(f'attention_mask_u: {attention_mask_u}') + # print(f'loss_mask_p sum: {torch.sum(loss_mask_p), 8*4096}')# print(f'loss_mask_p shape: {loss_mask_p.size()}') + + # Seq logprobs + print_rank_0(f'args.micro_batch_size: 
{args.micro_batch_size}') + seq_logps_p = torch.sum(loss_c[:args.micro_batch_size,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + seq_logps_u = torch.sum(loss_c[args.micro_batch_size:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + rseq_logps_p = torch.sum(rloss_c[:args.micro_batch_size,:] * loss_mask_p, dim=-1) / torch.sum(loss_mask_p, dim=-1) + rseq_logps_u = torch.sum(rloss_c[args.micro_batch_size:,:] * loss_mask_u, dim=-1) / torch.sum(loss_mask_u, dim=-1) + + # # Print statements for debugging + # print(f'seq_logps_p shape: {seq_logps_p.size()}') + # print(f'seq_logps_u shape: {seq_logps_u.size()}') + # print(f'rseq_logps_p shape: {rseq_logps_p.size()}') + # print(f'rseq_logps_u shape: {rseq_logps_u.size()}') + + # Loss + pu_ratio = seq_logps_p - seq_logps_u + rpu_ratio = rseq_logps_p - rseq_logps_u + sdiff_ratio = 0.1*(pu_ratio - rpu_ratio) + # print(f'sdiff_ratio: {sdiff_ratio}') + final = -F.logsigmoid(sdiff_ratio) + # print(f'final: {final}') + # dloss = torch.sum(final) + dloss = torch.mean(final) + + # Model backward and update + model[0].backward(dloss) + + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + # print(f'increment: {increment}') + # model[0].step(lr_kwargs={'increment': increment}) + model[0].step() + update_successful = model[0].was_step_applied() + print_rank_0(f'update_successful: {update_successful}') + + # Iteration updates + iteration += 1 + args.iteration = iteration + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + + + args.consumed_train_samples += new_samples + # print(f'args.consumed_train_samples: {args.consumed_train_samples}') + + # Reduce loss for logging. 
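# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The loop above reduces per-token cross-entropy to one score per sequence by
# a loss-mask-weighted mean, and derives "implicit rewards" by comparing the
# preferred and unpreferred margins against the reference model. Stand-alone
# versions of those two reductions (shapes and values are made up):
import torch

def masked_mean_per_sequence(per_token, mask):
    # per_token, mask: [batch, seq]; mask is 1 for real tokens, 0 for padding
    return torch.sum(per_token * mask, dim=-1) / torch.sum(mask, dim=-1)

def implicit_reward_accuracy(seq_p, rseq_p, seq_u, rseq_u, beta=0.1):
    rewards_p = beta * (seq_p - rseq_p)            # preferred margin vs. reference
    rewards_u = beta * (seq_u - rseq_u)            # unpreferred margin vs. reference
    return (rewards_p > rewards_u).float().mean()  # fraction of "correct" pairs
# ---------------------------------------------------------------------------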
+ averaged_loss = average_losses_across_data_parallel_group([dloss]) + loss_dict = {'loss': averaged_loss} + print_rank_0(f'iteration: {iteration}, dloss: {averaged_loss.detach().cpu().tolist()}') + psrewards_p = (0.1 * (seq_logps_p - rseq_logps_p)).detach() + psrewards_u = (0.1 * (seq_logps_u - rseq_logps_u)).detach() + psrewards = (psrewards_p > psrewards_u).float() + rewards = psrewards.cpu().mean() + print_rank_0(f'iteration: {iteration}, rewards: {rewards}') + + # wandb logging + # report_memory_flag = training_log_dpo(loss_dict, iteration, report_memory_flag) + + if torch.distributed.get_rank() == 0: + averaged_loss_iter.append(averaged_loss.detach().cpu().tolist()[0]) + averaged_rewards_iter.append(rewards.tolist()) + + if (i % args.save_interval == 0) and (i > 0) and (torch.distributed.get_rank() == 0): + TPL = args.tensor_model_parallel_size + GRAD_ACC = os.environ.get('GRAD_ACC_STEPS') + print(f'Checkpointing loss and rewards at iteration {i} ..') + np.savez(f'./runs/loss-rewards_indels_textseq_nranks-{WORLD_SIZE}_model-nlayers-{args.num_layers}_TP-{TPL}_zero-{args.zero_stage}_gradacc-{GRAD_ACC}_lr-{args.lr}_seq-{args.seq_length}_bs-{args.micro_batch_size}_iters-{args.train_iters}-chkpt-{i}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) + + # if torch.distributed.get_rank() == 0: + # avg_loss_epoch.append(np.array(averaged_loss_iter).mean()) + # avg_rewards_epoch.append(np.array(averaged_rewards_iter).mean()) + + # Aggregated loss and rewards + # torch.distributed.barrier() + # if torch.distributed.get_rank() == 0: + # print(averaged_loss_iter) + # print(averaged_rewards_iter) + # print(avg_loss_epoch) + # print(avg_rewards_epoch) + # np.savez(f'./runs/proteingym_indels/loss-rewards_iters-{args.train_iters}.npz', loss=np.array(averaged_loss_iter), rewards=np.array(averaged_rewards_iter)) + + # Generate + if False: + model[0].eval() + print_rank_0(f'Generation mode..') + print_rank_0(f'args.seq_length: {args.seq_length}') + tokenizer = get_tokenizer() + print_rank_0(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + prompts=["A sequence", "A sequence","A sequence", "A sequence", "A sequence"] + tokens_to_generate = 64 + add_BOS = False + if add_BOS: + prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt) + for prompt in prompts] + else: + prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts] + + print_rank_0(f'prompts_tokens: {prompts_tokens}') + + # Now we have a list of list of tokens which each list has a different + # size. We want to extend this list to: + # - incorporate the tokens that need to be generated + # - make all the sequences equal length. + # Get the prompts length. + prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens] + # Get the max prompts length. + max_prompt_len = max(prompts_length) + # Number of tokens in the each sample of the batch. + samples_length = max_prompt_len + tokens_to_generate + # Now update the list of list to be of the same size: samples_length. + for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length): + padding_size = samples_length - prompt_length + prompt_tokens.extend([tokenizer.eod] * padding_size) + + # Now we are in a structured format, we can convert to tensors. 
+ prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) + prompts_length_tensor = torch.cuda.LongTensor(prompts_length) + print_rank_0(f'prompts_tokens_tensor: {prompts_tokens_tensor}') + print_rank_0(f'prompts_length_tensor: {prompts_length_tensor}') + + batch_size = prompts_tokens_tensor.size(0) + min_prompt_length = prompts_length_tensor.min().item() + max_sequence_length = prompts_tokens_tensor.size(1) + + print_rank_0(f'batch_size: {batch_size}') + print_rank_0(f'min_prompt_length: {min_prompt_length}') + print_rank_0(f'max_sequence_length: {max_sequence_length}') + print_rank_0(f'max_position_embeddings: {args.max_position_embeddings}') + print_rank_0(f'args.max_tokens_to_oom: {args.max_tokens_to_oom}') + if max_sequence_length > args.max_position_embeddings: + raise ValueError("Length of prompt + tokens_to_generate longer than allowed") + + if max_sequence_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) + + # INSTANTIATING FORWARD_STEP ? + model_fwd = ForwardStep(model[0], batch_size, max_sequence_length) + inference_params = InferenceParams(batch_size, + max_sequence_length) + + if hasattr(args, 'eos_id'): + termination_id = args.eos_id + print(f'args.eos_id: {args.eos_id}') + else: + termination_id = tokenizer.eod + print(f'tokenizer.eod: {tokenizer.eod}') + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + # Lengths of generated seuquence including including prompts. + generated_sequence_lengths = None + top_k = 0 + top_p = 1.0 + temperature = 1.0 + top_p_decay=0.0 + top_p_bound=0.0 + add_BOS=False + use_eod_token_for_early_termination=True + stop_on_double_eol=False + stop_on_eol=False + prevent_newline_after_colon=False + random_seed=42 + return_output_log_probs = False + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + print_rank_0(f'On mpu.is_pipeline_last_stage branch and output_log_probs is set: {output_log_probs}') + generated_sequence_lengths = torch.ones( + batch_size, dtype=torch.int64, + device=torch.cuda.current_device()) * max_sequence_length + print_rank_0(f'On mpu.is_pipeline_last_stage branch and generated_sequence_lengths: {generated_sequence_lengths}') + + # Whether we have reached a termination id. + is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, + device=torch.cuda.current_device()) + + + with torch.no_grad(): + prompts_attention_mask, _, prompts_position_ids = get_ltor_masks_and_position_ids( + data=prompts_tokens_tensor, + eod_token=None, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False + ) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): + # Pick the slice that we need to pass through the network. + tokens2use = prompts_tokens_tensor[:, prev_context_length:context_length] + positions2use = prompts_position_ids[:, prev_context_length:context_length] + attention_mask2use = prompts_attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. 
+ # print_rank_0(f'tokens2use shape: {tokens2use.size()}') + # print_rank_0(f'positions2use shape: {positions2use.size()}') + # print_rank_0(f'attention_mask2use shape: {attention_mask2use.size()}') + # print_rank_0(f'prompts_tokens_tensor shape: {prompts_tokens_tensor.size()}') + # print_rank_0(f'prompts_position_ids shape: {prompts_position_ids.size()}') + # print_rank_0(f'prompts_attention_mask shape: {prompts_attention_mask.size()}') + + # ------ + # plogits = forward_step(tokens2use, positions2use, attention_mask2use) + # plogits = plogits[0] + # print_rank_0(f'context_length: {context_length}, plogits: {plogits}') + + # plogits = model[0](prompts_tokens_tensor, + # prompts_position_ids, + # prompts_attention_mask, + # inference_params=inference_params + # ) + # print_rank_0(f'logits: {plogits}') + #------- + inference_params = InferenceParams(batch_size, + tokens2use.size(1)) + plogits = model[0](tokens2use, + positions2use, + attention_mask2use, + inference_params=inference_params + ) + plogits = plogits[0] + # plogits = torch.cuda.FloatTensor(plogits) + # print_rank_0(f'plogits: {plogits.size()}') + # print_rank_0(f'plogits type: {plogits.dtype}') + + if mpu.is_pipeline_last_stage(): + if prevent_newline_after_colon: + plogits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" + # Always the last stage should have an output. + assert plogits is not None + + # Sample. + last_token_logits = plogits[:, -1, :] + new_sample = sample(last_token_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + vocab_size=tokenizer.vocab_size) + if top_p > 0.0 and top_p_decay > 0.0: + top_p = top_p * top_p_decay + if top_p_bound > 0.0: + top_p = max(top_p, top_p_bound) + print_rank_0(f'new_sample: {new_sample}') + for nidx, ns in enumerate(new_sample.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, new_sample[{nidx}]: {tokenizer.detokenize(ns)}') + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = prompts_length_tensor <= context_length + # Update the tokens. + print_rank_0(f'started: {started}') + # print_rank_0(f'prompts_tokens_tensor before copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor before[{nidx}]: {tokenizer.detokenize(ns)}') + prompts_tokens_tensor[started, context_length] = new_sample[started] + # print_rank_0(f'prompts_tokens_tensor after copying new_sample: {prompts_tokens_tensor}') + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after[{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the tokens on the first stage so the next input to + # the network is correct. + copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, + prompts_tokens_tensor[:, context_length]) + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, prompts_tokens_tensor after copy_from_last_to_first_pipeline_stage [{nidx}]: {tokenizer.detokenize(ns)}') + + # Update the context length for the next token generation. + prev_context_length = context_length + print_rank_0(f'prev_context_length: {prev_context_length}') + + # Check if all the sequences have hit the termination_id. 
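+ # (Per-sequence termination is detected on the last pipeline stage and the
+ # aggregated done flag is broadcast to every stage, so all ranks exit the
+ # generation loop together.)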
+ done = None + if mpu.is_pipeline_last_stage(): + # These stopping methods are tokenizer dependent + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + just_finished = (done_token & ~is_generation_done).bool() + generated_sequence_lengths[just_finished.view(-1)] = \ + context_length + 1 + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + done = broadcast_from_last_pipeline_stage(1, torch.uint8, + tensor=done) + if use_eod_token_for_early_termination and done: + print_rank_0(f'done: {done}') + break + + # =================================================== + # Update the length of based on max generated length. + # =================================================== + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop [{nidx}]: {tokenizer.detokenize(ns)}') + prompts_tokens_tensor = prompts_tokens_tensor[:, :(context_length + 1)] + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and slicing with ctx length[{nidx}]: {tokenizer.detokenize(ns)}') + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = output_log_probs[:, :context_length] + + # ====================================== + # Broadcast to the first pipeline stage. + # ====================================== + + generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage( + batch_size, torch.int64, generated_sequence_lengths) + if return_output_log_probs: + output_log_probs_size = (batch_size, context_length) + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) + + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and befoer final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + # Only post-process on first stage. 
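+ # (The newly generated tokens were copied back to the first pipeline stage
+ # inside the loop, so detokenization happens there; each sequence is sliced
+ # with its recorded length from prompts_length_tensor before detokenizing.)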
+ if mpu.is_pipeline_first_stage(): + prompts_plus_generations = [] + + for nidx, ns in enumerate(prompts_tokens_tensor.cpu().numpy().tolist()): + print_rank_0(f'nidx: {nidx}, detokenized prompts_tokens_tensor after the generate loop and after final post-process[{nidx}]: {tokenizer.detokenize(ns)}') + + rtokens = prompts_tokens_tensor.cpu().numpy().tolist() + rlengths = prompts_length_tensor.cpu().numpy().tolist() + print_rank_0(f'rlengths: {rlengths}') + for sequence_tokens, slength in zip(rtokens, rlengths): + sequence_tokens = sequence_tokens[:slength] + prompts_plus_generations.append( + tokenizer.detokenize(sequence_tokens)) + # _, prompts_plus_generations, prompts_plus_generations_segments = \ + # detokenize_generations(prompts_tokens_tensor, prompts_length_tensor, True) + + print_rank_0(f'prompts_plus_generations: {prompts_plus_generations}') + + if False: + prompts=["Pen is mightier than", "A sequence", "Pythagoras theorem", "A sequence", "Hello world"] + tokens_to_generate = 64 + generated_responses = generate_post_training(model, prompts, tokens_to_generate, fprint=False) + + if False: + print_rank_0(f'Generation mode..') + print_rank_0(f'args.seq_length: {args.seq_length}') + tokenizer = get_tokenizer() + print_rank_0(f'len(tokenizer.vocab): {len(tokenizer.vocab)}') + model[0].eval() + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + # if choice[0].item() == 0: + try: + tokens_to_generate_len = 1021 + response, _, _, _ = generate_and_post_process(model[0], prompts=["A sequence", "A sequence","A sequence", "A sequence", "A sequence", "A sequence","A sequence", "A sequence"], tokens_to_generate=tokens_to_generate_len) + print_rank_0(f'generation completed..\n response:{response}') + except ValueError as ve: + print_rank_0(f'ValueError: {ve}') + pass + # elif choice[0].item() == 1: + # try: + # response, _, _ = beam_search_and_post_process(model[0], prompts=["A sequence", "A sequence", "A sequence", "A sequence",], tokens_to_generate=32) + # print(f'generation completed..\n response:{response}') + # except ValueError as ve: + # print(f'ValueError: {ve}') + # pass + + # # Checkpointing + # if args.save and iteration != 0: + # save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + + return model + +# def main(): +# # if RANK == 0: +# # setup_wandb() +# if os.getenv('TORCH_PROFILER_ENABLED') == '1': +# from torch.profiler import profile, record_function, ProfilerActivity +# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) + +# prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json") +# else: +# model = pretrain( +# train_valid_test_datasets_provider, +# model_provider, +# ModelType.encoder_or_decoder, +# forward_step, +# args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, +# data_post_process=data_post_process +# ) +# return model + + +if __name__ == "__main__": + # git_ds_info() + # pretrain(train_valid_test_datasets_provider, + # model_provider, + # ModelType.encoder_or_decoder, + # forward_step, + # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + # data_post_process=data_post_process) + import sys + import deepspeed.comm as dist + + # Return trained model + model = main() + + dist.log_summary() + if wandb.run is not None: + 
print(f"wandb.run.name: {wandb.run.name}") + print(f"wandb.run.url: {wandb.run.url}") + wandb.finish() + sys.exit() From 4804eb72e4d2349ca56e54e0e2dcf976ab9544e8 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 1 Aug 2024 07:39:42 -0700 Subject: [PATCH 48/50] scaling 3.5B model --- run_dgxcluster_scaling_3p5B.sh | 83 ++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 run_dgxcluster_scaling_3p5B.sh diff --git a/run_dgxcluster_scaling_3p5B.sh b/run_dgxcluster_scaling_3p5B.sh new file mode 100644 index 0000000000..1188f17b37 --- /dev/null +++ b/run_dgxcluster_scaling_3p5B.sh @@ -0,0 +1,83 @@ +#!/bin/bash +#SBATCH --partition defq --nodes 31 +#SBATCH --exclusive +#SBATCH --job-name=example-mn-sbatch-job +#SBATCH --gpus-per-node=8 + +CONTAINER=${HOME}/enroot_images/megds2.sqsh +#srun --nodes 2 --mpi=pmix --gpus-per-node 8 --container-image=${CONTAINER} --ntasks-per-node=1 nvidia-smi -L +#exit 0 + +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_TLS=rc +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml +export NCCL_DEBUG=INFO +export NCCL_PROTO=LL,LL128,Simple +export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS +export MELLANOX_VISIBLE_DEVICES=all +export PMIX_MCA_gds=hash +export PMIX_MCA_psec=native + +export NHOSTS="${SLURM_NNODES}" +export NGPU_PER_HOST="${SLURM_GPUS_ON_NODE}" +export NGPUS="$(( NHOSTS * NGPU_PER_HOST ))" +export OMP_NUM_THREADS=1 +export WORLD_SIZE=$NGPUS +export RANK=$SLURM_PROCID +export LOCAL_RANK=$SLURM_LOCALID +export NCCL_DEBUG=warn + +echo "PATH=$PATH" > .deepspeed_env +echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> .deepspeed_env +echo "CPATH=$CPATH" >> .deepspeed_env +echo "TORCH_EXTENSIONS_DIR=$PWD/deepspeed" >> .deepspeed_env +echo "HF_HOME=$PWD/hfdata" >> .deepspeed_env + + +echo ${SLURM_GPUS_ON_NODE} + +if [ ! -z "${SLURM_JOB_ID}" ]; then + # check the original location through scontrol and $SLURM_JOB_ID + SCRIPT_PATH=$(scontrol show job $SLURM_JOBID | awk -F= '/Command=/{print $2}') + export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +else + # otherwise: started with bash. Get the real location. 
+ SCRIPT_PATH=$(realpath $0) +fi + +export _basedir="$(cd "$(dirname "${SCRIPT_PATH}")" && pwd)" +cd ${_basedir} +echo ${_basedir} + +#cd $SCRIPT_PATH +echo $SCRIPT_PATH +echo $SLURM_NNODES + +#CONTAINER=${HOME}/enroot_images/megds2.sqsh +#source /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/deps/ezpz/src/ezpz/bin/savejobenv +srun --mpi=pmix --nodes $SLURM_NNODES --gpus-per-node 8 --ntasks-per-node=8 --container-workdir=${_basedir} --container-mounts="/lustre/fs0/scratch/gdharuman","/home/gdharuman" --container-image=${CONTAINER} python /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/dpo_training_ref.py \ + --use-flash-attn-v2 --fp16 --split 100,0,0 \ + --log-interval 1 --no-bias-gelu-fusion \ + --lr-decay-style cosine --no-bias-dropout-fusion \ + --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer \ + --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 \ + --use-checkpoint-opt_param-scheduler --lr 5e-6 --seq-length 512 \ + --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --num-layers 16 --hidden-size 4096 --train-iters 100 --eval-iters 10 \ + --distributed-backend nccl --num-attention-heads 32 --save-interval 2000 \ + --eval-interval 50000 --max-position-embeddings 1024 --micro-batch-size 8 \ + --data-file-list-p ALCF/data_textseq_p_all.txt \ + --data-file-list-u ALCF/data_textseq_u_all.txt \ + --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 \ + --num-key-value-heads 32 --data-cache-path ./index-cache \ + --ffn-hidden-size 11008 --tokenizer-model ALCF/tokenizer.model \ + --no-query-key-layer-scaling --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights --swiglu \ + --normalization rmsnorm --disable-bias-linear \ + --zero-stage=1 --deepspeed_config=ds_config-gpt_nooffload.json \ + --no-pipeline-parallel --deepspeed --optimizer adamw From 14e2bf36d4ed42180c5d41b5e1e041f857f029c6 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 1 Aug 2024 07:39:59 -0700 Subject: [PATCH 49/50] scaling 70B model --- run_dgxcluster_scaling_70B.sh | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 run_dgxcluster_scaling_70B.sh diff --git a/run_dgxcluster_scaling_70B.sh b/run_dgxcluster_scaling_70B.sh new file mode 100644 index 0000000000..1a88b86a86 --- /dev/null +++ b/run_dgxcluster_scaling_70B.sh @@ -0,0 +1,83 @@ +#!/bin/bash +#SBATCH --partition defq --nodes 31 +#SBATCH --exclusive +#SBATCH --job-name=example-mn-sbatch-job +#SBATCH --gpus-per-node=8 + +CONTAINER=${HOME}/enroot_images/megds2.sqsh +#srun --nodes 2 --mpi=pmix --gpus-per-node 8 --container-image=${CONTAINER} --ntasks-per-node=1 nvidia-smi -L +#exit 0 + +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_TLS=rc +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml +export NCCL_DEBUG=INFO +export NCCL_PROTO=LL,LL128,Simple +export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS +export MELLANOX_VISIBLE_DEVICES=all +export PMIX_MCA_gds=hash +export PMIX_MCA_psec=native + +export NHOSTS="${SLURM_NNODES}" +export NGPU_PER_HOST="${SLURM_GPUS_ON_NODE}" +export NGPUS="$(( NHOSTS * NGPU_PER_HOST ))" +export OMP_NUM_THREADS=1 +export WORLD_SIZE=$NGPUS +export RANK=$SLURM_PROCID +export LOCAL_RANK=$SLURM_LOCALID +export NCCL_DEBUG=warn + +echo 
"PATH=$PATH" > .deepspeed_env +echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> .deepspeed_env +echo "CPATH=$CPATH" >> .deepspeed_env +echo "TORCH_EXTENSIONS_DIR=$PWD/deepspeed" >> .deepspeed_env +echo "HF_HOME=$PWD/hfdata" >> .deepspeed_env + + +echo ${SLURM_GPUS_ON_NODE} + +if [ ! -z "${SLURM_JOB_ID}" ]; then + # check the original location through scontrol and $SLURM_JOB_ID + SCRIPT_PATH=$(scontrol show job $SLURM_JOBID | awk -F= '/Command=/{print $2}') + export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +else + # otherwise: started with bash. Get the real location. + SCRIPT_PATH=$(realpath $0) +fi + +export _basedir="$(cd "$(dirname "${SCRIPT_PATH}")" && pwd)" +cd ${_basedir} +echo ${_basedir} + +#cd $SCRIPT_PATH +echo $SCRIPT_PATH +echo $SLURM_NNODES + +#CONTAINER=${HOME}/enroot_images/megds2.sqsh +#source /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/deps/ezpz/src/ezpz/bin/savejobenv +srun --mpi=pmix --nodes $SLURM_NNODES --gpus-per-node 8 --ntasks-per-node=8 --container-workdir=${_basedir} --container-mounts="/lustre/fs0/scratch/gdharuman","/home/gdharuman" --container-image=${CONTAINER} python /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/dpo_training.py \ + --use-flash-attn-v2 --fp16 --split 100,0,0 \ + --log-interval 1 --no-bias-gelu-fusion \ + --lr-decay-style cosine --no-bias-dropout-fusion \ + --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer \ + --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 \ + --use-checkpoint-opt_param-scheduler --lr 5e-6 --seq-length 128 \ + --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --num-layers 80 --hidden-size 8192 --train-iters 100 --eval-iters 10 \ + --distributed-backend nccl --num-attention-heads 64 --save-interval 2000 \ + --eval-interval 50000 --max-position-embeddings 128 --micro-batch-size 2 \ + --data-file-list-p ALCF/data_textseq_p_all.txt \ + --data-file-list-u ALCF/data_textseq_u_all.txt \ + --tensor-model-parallel-size 8 --pipeline-model-parallel-size 1 \ + --num-key-value-heads 8 --data-cache-path ./index-cache \ + --ffn-hidden-size 28672 --tokenizer-model ALCF/tokenizer.model \ + --no-query-key-layer-scaling --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights --swiglu \ + --normalization rmsnorm --disable-bias-linear \ + --zero-stage=2 --deepspeed_config=ds_config-gpt_nooffload_zero2.json \ + --no-pipeline-parallel --deepspeed --optimizer adamw From 71db627b0d6458ebf1d7779f3ff87a89ff6c3932 Mon Sep 17 00:00:00 2001 From: gdharuman11 Date: Thu, 1 Aug 2024 07:40:14 -0700 Subject: [PATCH 50/50] scaling 7B model --- run_dgxcluster_scaling_7B.sh | 83 ++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 run_dgxcluster_scaling_7B.sh diff --git a/run_dgxcluster_scaling_7B.sh b/run_dgxcluster_scaling_7B.sh new file mode 100644 index 0000000000..de716de64b --- /dev/null +++ b/run_dgxcluster_scaling_7B.sh @@ -0,0 +1,83 @@ +#!/bin/bash +#SBATCH --partition defq --nodes 31 +#SBATCH --exclusive +#SBATCH --job-name=example-mn-sbatch-job +#SBATCH --gpus-per-node=8 + +CONTAINER=${HOME}/enroot_images/megds2.sqsh +#srun --nodes 2 --mpi=pmix --gpus-per-node 8 --container-image=${CONTAINER} --ntasks-per-node=1 nvidia-smi -L +#exit 0 + +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_TLS=rc +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export 
NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml +export NCCL_DEBUG=INFO +export NCCL_PROTO=LL,LL128,Simple +export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS +export MELLANOX_VISIBLE_DEVICES=all +export PMIX_MCA_gds=hash +export PMIX_MCA_psec=native + +export NHOSTS="${SLURM_NNODES}" +export NGPU_PER_HOST="${SLURM_GPUS_ON_NODE}" +export NGPUS="$(( NHOSTS * NGPU_PER_HOST ))" +export OMP_NUM_THREADS=1 +export WORLD_SIZE=$NGPUS +export RANK=$SLURM_PROCID +export LOCAL_RANK=$SLURM_LOCALID +export NCCL_DEBUG=warn + +echo "PATH=$PATH" > .deepspeed_env +echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> .deepspeed_env +echo "CPATH=$CPATH" >> .deepspeed_env +echo "TORCH_EXTENSIONS_DIR=$PWD/deepspeed" >> .deepspeed_env +echo "HF_HOME=$PWD/hfdata" >> .deepspeed_env + + +echo ${SLURM_GPUS_ON_NODE} + +if [ ! -z "${SLURM_JOB_ID}" ]; then + # check the original location through scontrol and $SLURM_JOB_ID + SCRIPT_PATH=$(scontrol show job $SLURM_JOBID | awk -F= '/Command=/{print $2}') + export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +else + # otherwise: started with bash. Get the real location. + SCRIPT_PATH=$(realpath $0) +fi + +export _basedir="$(cd "$(dirname "${SCRIPT_PATH}")" && pwd)" +cd ${_basedir} +echo ${_basedir} + +#cd $SCRIPT_PATH +echo $SCRIPT_PATH +echo $SLURM_NNODES + +#CONTAINER=${HOME}/enroot_images/megds2.sqsh +#source /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/deps/ezpz/src/ezpz/bin/savejobenv +srun --mpi=pmix --nodes $SLURM_NNODES --gpus-per-node 8 --ntasks-per-node=8 --container-workdir=${_basedir} --container-mounts="/lustre/fs0/scratch/gdharuman","/home/gdharuman" --container-image=${CONTAINER} python /lustre/fs0/scratch/gdharuman/Megatron-DeepSpeed/dpo_training.py \ + --use-flash-attn-v2 --fp16 --split 100,0,0 \ + --log-interval 1 --no-bias-gelu-fusion \ + --lr-decay-style cosine --no-bias-dropout-fusion \ + --no-masked-softmax-fusion --tokenizer-type Llama2Tokenizer \ + --no-gradient-accumulation-fusion --accumulate-allreduce-grads-in-fp32 \ + --use-checkpoint-opt_param-scheduler --lr 5e-6 --seq-length 512 \ + --save checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --load checkpoints/ds_stage2_nl6_hs4096_mb24_seq1024_gb48_pp1_tp2_fp16 \ + --num-layers 32 --hidden-size 4096 --train-iters 100 --eval-iters 10 \ + --distributed-backend nccl --num-attention-heads 32 --save-interval 2000 \ + --eval-interval 50000 --max-position-embeddings 1024 --micro-batch-size 4 \ + --data-file-list-p ALCF/data_textseq_p_all.txt \ + --data-file-list-u ALCF/data_textseq_u_all.txt \ + --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 \ + --num-key-value-heads 32 --data-cache-path ./index-cache \ + --ffn-hidden-size 11008 --tokenizer-model ALCF/tokenizer.model \ + --no-query-key-layer-scaling --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights --swiglu \ + --normalization rmsnorm --disable-bias-linear \ + --zero-stage=1 --deepspeed_config=ds_config-gpt_nooffload.json \ + --no-pipeline-parallel --deepspeed --optimizer adamw
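
For reference, the quantities logged in the training loop above (dloss, the 0.1-scaled implicit rewards, and the chosen-vs-rejected reward accuracy) follow the standard DPO formulation. The sketch below is a minimal, self-contained illustration of that objective, not the exact code in dpo_training.py (which is not shown in full in this series): the function name dpo_loss is illustrative, beta = 0.1 mirrors the reward scaling used in the logging, "chosen"/"rejected" correspond to the _p/_u tensors, and the sequence log-probabilities are assumed to be summed over response tokens under the policy and the frozen reference model.

import torch
import torch.nn.functional as F

def dpo_loss(policy_logps_chosen, policy_logps_rejected,
             ref_logps_chosen, ref_logps_rejected, beta=0.1):
    """Minimal DPO sketch: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    # Log-ratios of policy vs. reference for each sequence in the batch.
    chosen_ratio = policy_logps_chosen - ref_logps_chosen
    rejected_ratio = policy_logps_rejected - ref_logps_rejected

    # DPO objective, averaged over the batch.
    loss = -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()

    # Implicit rewards and preference accuracy, matching the logging above.
    rewards_chosen = (beta * chosen_ratio).detach()
    rewards_rejected = (beta * rejected_ratio).detach()
    accuracy = (rewards_chosen > rewards_rejected).float().mean()
    return loss, rewards_chosen, rewards_rejected, accuracy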