From 737f18756b5a659b49ea3e9c0c0e15f3e408fd8a Mon Sep 17 00:00:00 2001 From: xiezipeng-ML Date: Tue, 18 Apr 2023 08:09:04 +0000 Subject: [PATCH 1/2] mock model training demo --- projects/mock_gpt_train/configs/gpt.py | 11 ++ projects/mock_gpt_train/configs/training.py | 63 ++++++++++ projects/mock_transformers/dist_infer_gpt.py | 117 +++++++++++-------- 3 files changed, 141 insertions(+), 50 deletions(-) create mode 100644 projects/mock_gpt_train/configs/gpt.py create mode 100644 projects/mock_gpt_train/configs/training.py diff --git a/projects/mock_gpt_train/configs/gpt.py b/projects/mock_gpt_train/configs/gpt.py new file mode 100644 index 000000000..73646d904 --- /dev/null +++ b/projects/mock_gpt_train/configs/gpt.py @@ -0,0 +1,11 @@ +from projects.mock_transformers import init_env # noqa +from projects.mock_transformers.dist_infer_gpt import * +from libai.config import LazyCall +from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config + +cfg = LazyCall(GPT2Config)(vocab_size=50257) + +gpt_model = LazyCall(GPT2Model)(config = cfg) + +pretrain_model = LazyCall(GPT2LMHeadModel)(config=cfg) + diff --git a/projects/mock_gpt_train/configs/training.py b/projects/mock_gpt_train/configs/training.py new file mode 100644 index 000000000..ea4ed8364 --- /dev/null +++ b/projects/mock_gpt_train/configs/training.py @@ -0,0 +1,63 @@ +from libai.config import LazyCall +from libai.evaluation import PPLEvaluator +from projects.mock_gpt_train.configs.gpt import pretrain_model as model +from projects.MagicPrompt.configs.gpt2_dataset import dataloader, tokenization +from configs.common.optim import optim + +from libai.scheduler import WarmupExponentialLR + +from configs.common.train import train +from configs.common.models.graph import graph + +graph.global_mode.enabled = True +# graph.enabled = False +vocab_file = "/data/home/magicprompt/vocab.json" +merge_files = "/data/home/magicprompt/merges.txt" +train_data_prefix = "/data/home/magicprompt/train/en_train_mmap_text_sentence" + +tokenization.tokenizer.vocab_file = vocab_file +tokenization.tokenizer.merges_file = merge_files +dataloader.train.dataset[0].data_prefix = train_data_prefix +dataloader.train.dataset[0].indexed_dataset.data_prefix = train_data_prefix + +train.input_placement_device = "cpu" + +train.dist.pipeline_num_layers = 12 + +for ds in dataloader.train.dataset: + ds.max_seq_length = 1024 + +optim.lr = 5.0e-05 + +train.update( + dict( + output_dir="projects/MagicPrompt/oneflow_magicprompt", + train_micro_batch_size=4, + test_micro_batch_size=4, + train_epoch=33, + train_iter=10000, + log_period=50, + amp=dict(enabled=False), + warmup_ratio=0, + checkpointer=dict(period=8000, max_to_keep=20), + dist=dict( + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=1, + # pipeline_num_layers=model.cfg.hidden_layers, + ), + scheduler=LazyCall(WarmupExponentialLR)( + warmup_factor=0.0, + gamma=1.0, + warmup_method="linear", + warmup_iter=0.0, + ), + evaluation=dict( + enabled=False, + evaluator=LazyCall(PPLEvaluator)(), + eval_iter=250, + eval_period=4000, + ), + rdma_enabled=False, + ) +) diff --git a/projects/mock_transformers/dist_infer_gpt.py b/projects/mock_transformers/dist_infer_gpt.py index 9299d98f3..0b79ed63b 100644 --- a/projects/mock_transformers/dist_infer_gpt.py +++ b/projects/mock_transformers/dist_infer_gpt.py @@ -13,15 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import init_env  # noqa
+from projects.mock_transformers import init_env  # noqa
 import oneflow as flow
 from omegaconf import DictConfig
 from oneflow.utils.global_view import global_mode
 from transformers import AutoModelForCausalLM, AutoTokenizer, pytorch_utils
 from transformers.models.gpt2 import modeling_gpt2
 
+from libai.layers import ParallelCrossEntropyLoss
 from libai.layers import Conv1D
-from libai.utils import distributed as dist
+from libai.layers import Embedding as LiBaiEmbedding
+
+
+# ------replace Embedding with LiBai's------
+temp_class = modeling_gpt2.GPT2Model
+
+
+class LiBaiGPTModel(temp_class):
+    def __init__(self, config):
+        super().__init__(config)
+        self.wte = LiBaiEmbedding(config.vocab_size, self.embed_dim)
+        self.wpe = LiBaiEmbedding(config.max_position_embeddings, self.embed_dim)
+
+
+modeling_gpt2.GPT2Model = LiBaiGPTModel
 
 
 # ------replace Conv1D to libai------
@@ -65,26 +80,26 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                 in_features=self.embed_dim,
                 out_features=2 * self.embed_dim,
                 parallel="col",
-                dtype=flow.float16,
+                dtype=flow.float32,
             )
             self.q_attn = Conv1D(
                 in_features=self.embed_dim,
                 out_features=self.embed_dim,
                 parallel="col",
-                dtype=flow.float16,
+                dtype=flow.float32,
             )
         else:
             self.c_attn = Conv1D(
                 in_features=self.embed_dim,
                 out_features=3 * self.embed_dim,
                 parallel="col",
-                dtype=flow.float16,
+                dtype=flow.float32,
             )
         self.c_proj = Conv1D(
             in_features=self.embed_dim,
             out_features=self.embed_dim,
             parallel="row",
-            dtype=flow.float16,
+            dtype=flow.float32,
         )
@@ -103,54 +118,56 @@ def __init__(self, intermediate_size, config):
             in_features=embed_dim,
             out_features=intermediate_size,
             parallel="col",
-            dtype=flow.float16,
+            dtype=flow.float32,
         )
         self.c_proj = Conv1D(
             in_features=intermediate_size,
             out_features=embed_dim,
             parallel="row",
-            dtype=flow.float16,
+            dtype=flow.float32,
         )
 
-
-if __name__ == "__main__":
-    # set dist config
-    parallel_config = DictConfig(
-        dict(
-            data_parallel_size=1,
-            tensor_parallel_size=2,
-            pipeline_parallel_size=1,  # set to 1, unsupport pipeline parallel now
-            pipeline_num_layers=None,
-            device_type="cpu",
-        )
-    )
-    dist.setup_dist_util(parallel_config)
-
-    # initial and load model
-    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float16)
-    # set model to cuda
-    dist.set_device_type("cuda")
-    model._apply(dist.convert_to_distributed_default_setting)
-    # initial tokenizer
-    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
-
-    # get input_ids
-    prompt = "Hello, I'm a language model,"
-    input_ids = tokenizer(prompt, return_tensors="np").input_ids
-    input_ids = flow.from_numpy(input_ids)
-    input_ids = input_ids.to_global(
-        sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-        placement=dist.get_layer_placement(0),
-    )
-
-    # generate id
-    placement_sbp_dict = dict(
-        placement=flow.env.all_device_placement("cuda"),
-        sbp=flow.sbp.broadcast,
-    )
-    with global_mode(True, **placement_sbp_dict):
-        generated_ids = model.generate(input_ids, max_length=30)
-    out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
-    if dist.is_main_process():
-        print(out_put_ids)
+modeling_gpt2.CrossEntropyLoss = ParallelCrossEntropyLoss
+
+
+# if __name__ == "__main__":
+#     # set dist config
+#     parallel_config = DictConfig(
+#         dict(
+#             data_parallel_size=1,
+#             tensor_parallel_size=2,
+#             pipeline_parallel_size=1,  # set to 1; pipeline parallelism is not supported yet
+#             pipeline_num_layers=None,
+#             device_type="cpu",
+#         )
+#     )
+#     dist.setup_dist_util(parallel_config)
+
+#     # initialize and load the model
+#     model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float32)
+#     # move the model to cuda
+#     dist.set_device_type("cuda")
+#     model._apply(dist.convert_to_distributed_default_setting)
+#     # initialize the tokenizer
+#     tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
+
+#     # get input_ids
+#     prompt = "Hello, I'm a language model,"
+#     input_ids = tokenizer(prompt, return_tensors="np").input_ids
+#     input_ids = flow.from_numpy(input_ids)
+#     input_ids = input_ids.to_global(
+#         sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+#         placement=dist.get_layer_placement(0),
+#     )
+
+#     # generate ids
+#     placement_sbp_dict = dict(
+#         placement=flow.env.all_device_placement("cuda"),
+#         sbp=flow.sbp.broadcast,
+#     )
+#     with global_mode(True, **placement_sbp_dict):
+#         generated_ids = model.generate(input_ids, max_length=30)
+#     output_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+#     if dist.is_main_process():
+#         print(output_ids)

From 324cb114fd0f23fdde7da16b1d0ab44bc9d34aec Mon Sep 17 00:00:00 2001
From: xiezipeng-ML
Date: Wed, 19 Apr 2023 02:45:44 +0000
Subject: [PATCH 2/2] refine mock GPT training demo

---
 projects/mock_gpt_train/configs/gpt.py       |   3 +-
 projects/mock_gpt_train/configs/training.py  |   2 -
 projects/mock_transformers/dist_infer_gpt.py | 159 +++++++++++++------
 3 files changed, 113 insertions(+), 51 deletions(-)

diff --git a/projects/mock_gpt_train/configs/gpt.py b/projects/mock_gpt_train/configs/gpt.py
index 73646d904..26b6430af 100644
--- a/projects/mock_gpt_train/configs/gpt.py
+++ b/projects/mock_gpt_train/configs/gpt.py
@@ -5,7 +5,6 @@
 
 cfg = LazyCall(GPT2Config)(vocab_size=50257)
 
-gpt_model = LazyCall(GPT2Model)(config = cfg)
+gpt_model = LazyCall(GPT2Model)(config=cfg)
 
 pretrain_model = LazyCall(GPT2LMHeadModel)(config=cfg)
-
diff --git a/projects/mock_gpt_train/configs/training.py b/projects/mock_gpt_train/configs/training.py
index ea4ed8364..7d9f2dcd0 100644
--- a/projects/mock_gpt_train/configs/training.py
+++ b/projects/mock_gpt_train/configs/training.py
@@ -20,8 +20,6 @@
 dataloader.train.dataset[0].data_prefix = train_data_prefix
 dataloader.train.dataset[0].indexed_dataset.data_prefix = train_data_prefix
 
-train.input_placement_device = "cpu"
-
 train.dist.pipeline_num_layers = 12
 
 for ds in dataloader.train.dataset:
diff --git a/projects/mock_transformers/dist_infer_gpt.py b/projects/mock_transformers/dist_infer_gpt.py
index 0b79ed63b..0973b38ec 100644
--- a/projects/mock_transformers/dist_infer_gpt.py
+++ b/projects/mock_transformers/dist_infer_gpt.py
@@ -13,17 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from projects.mock_transformers import init_env  # noqa
 import oneflow as flow
 from omegaconf import DictConfig
 from oneflow.utils.global_view import global_mode
 from transformers import AutoModelForCausalLM, AutoTokenizer, pytorch_utils
 from transformers.models.gpt2 import modeling_gpt2
 
-from libai.layers import ParallelCrossEntropyLoss
 from libai.layers import Conv1D
 from libai.layers import Embedding as LiBaiEmbedding
-
+from libai.utils import distributed as dist
+from projects.mock_transformers import init_env  # noqa
 
 # ------replace Embedding with LiBai's------
 temp_class = modeling_gpt2.GPT2Model
@@ -127,47 +126,113 @@ def __init__(self, intermediate_size, config):
         dtype=flow.float32,
     )
 
-modeling_gpt2.CrossEntropyLoss = ParallelCrossEntropyLoss
-
-
-# if __name__ == "__main__":
-#     # set dist config
-#     parallel_config = DictConfig(
-#         dict(
-#             data_parallel_size=1,
-#             tensor_parallel_size=2,
-#             pipeline_parallel_size=1,  # set to 1; pipeline parallelism is not supported yet
-#             pipeline_num_layers=None,
-#             device_type="cpu",
-#         )
-#     )
-#     dist.setup_dist_util(parallel_config)
-
-#     # initialize and load the model
-#     model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float32)
-#     # move the model to cuda
-#     dist.set_device_type("cuda")
-#     model._apply(dist.convert_to_distributed_default_setting)
-#     # initialize the tokenizer
-#     tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
-
-#     # get input_ids
-#     prompt = "Hello, I'm a language model,"
-#     input_ids = tokenizer(prompt, return_tensors="np").input_ids
-#     input_ids = flow.from_numpy(input_ids)
-#     input_ids = input_ids.to_global(
-#         sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-#         placement=dist.get_layer_placement(0),
-#     )
-
-#     # generate ids
-#     placement_sbp_dict = dict(
-#         placement=flow.env.all_device_placement("cuda"),
-#         sbp=flow.sbp.broadcast,
-#     )
-#     with global_mode(True, **placement_sbp_dict):
-#         generated_ids = model.generate(input_ids, max_length=30)
-#     output_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
-#     if dist.is_main_process():
-#         print(output_ids)
+
+modeling_gpt2.GPT2MLP = LiBaiGPT2MLP
+
+
+# ------replace Loss Function with LiBai's------
+
+
+class GPT2Loss(flow.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, logits, label):
+        label = label.to_global(placement=logits.placement)
+        loss = flow._C.sparse_softmax_cross_entropy(logits, label)
+        loss = loss.mean()
+        return loss
+
+
+modeling_gpt2.CrossEntropyLoss = GPT2Loss
+
+
+# ------adapt model return type for LiBai------
+temp_class = modeling_gpt2.GPT2LMHeadModel
+
+
+class LiBaiGPT2LMHeadModel(temp_class):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        out = super().forward(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return {
+            "loss": out.loss,
+        }
+
+
+modeling_gpt2.GPT2LMHeadModel = LiBaiGPT2LMHeadModel
+
+if __name__ == "__main__":
+    # set dist config
+    parallel_config = DictConfig(
+        dict(
+            data_parallel_size=1,
+            tensor_parallel_size=2,
+            pipeline_parallel_size=1,  # set to 1; pipeline parallelism is not supported yet
+            pipeline_num_layers=None,
+            device_type="cpu",
+        )
+    )
+    dist.setup_dist_util(parallel_config)
+
+    # initialize and load the model
+    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float32)
+    # move the model to cuda
+    dist.set_device_type("cuda")
+    model._apply(dist.convert_to_distributed_default_setting)
+    # initialize the tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
+
+    # get input_ids
+    prompt = "Hello, I'm a language model,"
+    input_ids = tokenizer(prompt, return_tensors="np").input_ids
+    input_ids = flow.from_numpy(input_ids)
+    input_ids = input_ids.to_global(
+        sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+        placement=dist.get_layer_placement(0),
+    )
+
+    # generate ids
+    placement_sbp_dict = dict(
+        placement=flow.env.all_device_placement("cuda"),
+        sbp=flow.sbp.broadcast,
+    )
+    with global_mode(True, **placement_sbp_dict):
+        generated_ids = model.generate(input_ids, max_length=30)
+    output_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+    if dist.is_main_process():
+        print(output_ids)