Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 5, 2023
1 parent 1adaf3c commit 9c887f1
Showing 1 changed file with 42 additions and 56 deletions.
98 changes: 42 additions & 56 deletions scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,47 +54,39 @@

import argparse
import os

import pytorch_lightning as pl
import torch
import yaml
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron import GPTModel

from omegaconf import OmegaConf
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.utils import logging


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input',
required=True,
type=str,
help='path to the two MPT-7B .bin weight files from HuggingFace')
parser.add_argument('-c', '--config',
required=True,
type=str,
help='the path to the megatron_gpt_config.yaml file')
parser.add_argument('-o', '--output',
required=False,
default=None,
type=str,
help='path to dir where to store output .nemo file')
parser.add_argument('--cuda',
action='store_true',
help='put Nemo model onto GPU prior to savedown')

parser.add_argument(
'-i', '--input', required=True, type=str, help='path to the two MPT-7B .bin weight files from HuggingFace'
)
parser.add_argument(
'-c', '--config', required=True, type=str, help='the path to the megatron_gpt_config.yaml file'
)
parser.add_argument(
'-o', '--output', required=False, default=None, type=str, help='path to dir where to store output .nemo file'
)
parser.add_argument('--cuda', action='store_true', help='put Nemo model onto GPU prior to savedown')

args = parser.parse_args()



if not os.path.exists(args.input):
logging.critical(f'Input directory [ {args.input} ] does not exist or cannot be found. Aborting.')
exit(255)

if not os.path.exists(args.config):
logging.critical(f'Path to config file [ {args.config} ] does not exist or cannot be found. Aborting.')
exit(255)

with open(args.config, 'r', encoding='utf_8') as fr:
orig_cfg = yaml.safe_load(fr)

Expand All @@ -103,7 +95,7 @@
del model_dict['tokenizer']
if 'data' in model_dict:
del model_dict['data']

override_model_dict = {
'micro_batch_size': 1,
'global_batch_size': 4,
Expand All @@ -119,7 +111,7 @@
'max_position_embeddings': 2048,
'num_layers': 32,
'num_attention_heads': 32,
'ffn_hidden_size': 4*4096,
'ffn_hidden_size': 4 * 4096,
'precision': 'bf16',
'layernorm_epsilon': 1e-5,
'pre_process': True,
Expand Down Expand Up @@ -154,11 +146,11 @@
'num_nodes': 1,
'accelerator': 'gpu' if args.cuda else 'cpu',
'precision': 'bf16',
'logger': False, # logger provided by exp_manager
'logger': False, # logger provided by exp_manager
'enable_checkpointing': False,
'replace_sampler_ddp': False,
'max_epochs': -1, # PTL default. In practice, max_steps will be reached first.
'max_steps': 100000, # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
'max_epochs': -1, # PTL default. In practice, max_steps will be reached first.
'max_steps': 100000, # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
'log_every_n_steps': 10,
'val_check_interval': 100,
'limit_val_batches': 50,
Expand All @@ -171,31 +163,27 @@

model_dict.update(override_model_dict)
model_dict['tokenizer'] = tokeniser_dict

omega_cfg = OmegaConf.create(model_dict)

trainer = pl.Trainer(**trainer_dict)



model = MegatronGPTModel(omega_cfg, trainer)



model_keys = list(model.state_dict().keys())
model_dtypes = list(set([model.state_dict()[x].dtype for x in model_keys]))

if not (len(model_dtypes) == 1 and model_dtypes[0] is torch.bfloat16):
model = model.bfloat16()

if args.cuda:
model = model.cuda()



mpt_1 = torch.load(os.path.join(args.input, 'pytorch_model-00001-of-00002.bin'), map_location="cpu")
mpt_2 = torch.load(os.path.join(args.input, 'pytorch_model-00002-of-00002.bin'), map_location="cpu")
mpt_dict = {**mpt_1, **mpt_2}
del mpt_1, mpt_2



def convert_state_dict(state_dict, amp=False):
def get_new_key(old_key):
if old_key == 'transformer.wte.weight':
Expand All @@ -210,39 +198,37 @@ def get_new_key(old_key):
p5 = p4.replace('norm_2.weight', 'post_attention_layernorm.weight')
p6 = p5.replace('ffn.up_proj.weight', 'mlp.dense_h_to_4h.weight')
p7 = p6.replace('ffn.down_proj.weight', 'mlp.dense_4h_to_h.weight')

return p7

new_dict = {}

for old_key, val in state_dict.items():
new_key = get_new_key(old_key)
if amp:
new_key = 'module.' + new_key

new_dict[new_key] = val

return new_dict



convert_dict = convert_state_dict(mpt_dict, amp=model_dict['megatron_amp_O2'])

if model_dict['megatron_amp_O2']:
missing_keys, unexpected_keys = model.model.load_state_dict(convert_dict, strict=True)
else:
missing_keys, unexpected_keys = super(GPTModel, model.model).load_state_dict(convert_dict, strict=True)

if len(missing_keys) > 0:
logging.critical('Missing keys were detected during the load, something has gone wrong. Aborting.')
logging.critical(f'Missing keys: \n{missing_keys}')
exit(255)



if len(unexpected_keys) > 0:
logging.warning('Unexpected keys were detected which should not happen. Please investigate.')
logging.warning(f'Unexpected keys: \n{unexpected_keys}')

if args.output is None:
args.output = os.path.dirname(os.path.abspath(__file__))

model.save_to(os.path.join(args.output, 'megatron_mpt_7b_base_tp1_pp1.nemo'))

0 comments on commit 9c887f1

Please sign in to comment.