mock model training demo【don't merge】 #497

Open · wants to merge 2 commits into base: main
projects/mock_gpt_train/configs/gpt.py (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
from projects.mock_transformers import init_env  # noqa

# Imported for its side effects: dist_infer_gpt patches the HuggingFace GPT-2
# classes with LiBai's parallel-aware layers (see the changes further below).
from projects.mock_transformers.dist_infer_gpt import *  # noqa
from libai.config import LazyCall
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config

cfg = LazyCall(GPT2Config)(vocab_size=50257)

gpt_model = LazyCall(GPT2Model)(config=cfg)

pretrain_model = LazyCall(GPT2LMHeadModel)(config=cfg)
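
The config above only declares the model lazily; nothing is built until a training entry point instantiates it. As a quick local sanity check, the lazy objects can be materialized directly. This is a minimal sketch that assumes libai.config exposes a detectron2-style instantiate helper for LazyCall nodes:

# Sketch only: assumes libai.config provides `instantiate` (detectron2-style lazy configs).
from libai.config import instantiate

# Importing the config module also applies the GPT-2 monkey patches via its wildcard import.
from projects.mock_gpt_train.configs.gpt import pretrain_model

model = instantiate(pretrain_model)  # builds GPT2Config(vocab_size=50257), then GPT2LMHeadModel
print(type(model).__name__)
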
projects/mock_gpt_train/configs/training.py (61 additions, 0 deletions)
@@ -0,0 +1,61 @@
from libai.config import LazyCall
from libai.evaluation import PPLEvaluator
from projects.mock_gpt_train.configs.gpt import pretrain_model as model
from projects.MagicPrompt.configs.gpt2_dataset import dataloader, tokenization
from configs.common.optim import optim

from libai.scheduler import WarmupExponentialLR

from configs.common.train import train
from configs.common.models.graph import graph

graph.global_mode.enabled = True
# graph.enabled = False
# MagicPrompt GPT-2 data assets; adjust these paths to the local environment.
vocab_file = "/data/home/magicprompt/vocab.json"
merge_files = "/data/home/magicprompt/merges.txt"
train_data_prefix = "/data/home/magicprompt/train/en_train_mmap_text_sentence"

tokenization.tokenizer.vocab_file = vocab_file
tokenization.tokenizer.merges_file = merge_files
dataloader.train.dataset[0].data_prefix = train_data_prefix
dataloader.train.dataset[0].indexed_dataset.data_prefix = train_data_prefix

# Should cover every transformer layer of the model (GPT-2 small has 12).
train.dist.pipeline_num_layers = 12

for ds in dataloader.train.dataset:
ds.max_seq_length = 1024

optim.lr = 5.0e-05

train.update(
dict(
output_dir="projects/MagicPrompt/oneflow_magicprompt",
train_micro_batch_size=4,
test_micro_batch_size=4,
train_epoch=33,
train_iter=10000,
log_period=50,
amp=dict(enabled=False),
warmup_ratio=0,
checkpointer=dict(period=8000, max_to_keep=20),
dist=dict(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
# pipeline_num_layers=model.cfg.hidden_layers,
),
scheduler=LazyCall(WarmupExponentialLR)(
warmup_factor=0.0,
gamma=1.0,
warmup_method="linear",
warmup_iter=0.0,
),
evaluation=dict(
enabled=False,
evaluator=LazyCall(PPLEvaluator)(),
eval_iter=250,
eval_period=4000,
),
rdma_enabled=False,
)
)
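
Before launching a run it can help to load the merged config and confirm that the overrides above landed where expected. The sketch below assumes libai.config provides a detectron2-style LazyConfig.load for Python-file configs, and the /data/home/magicprompt paths must exist wherever the dataloader is actually built:

# Sketch only: assumes libai.config exposes LazyConfig with a detectron2-style load().
from libai.config import LazyConfig

cfg = LazyConfig.load("projects/mock_gpt_train/configs/training.py")

# Spot-check a few of the values set above.
assert cfg.train.train_micro_batch_size == 4
assert cfg.graph.global_mode.enabled is True
assert cfg.optim.lr == 5.0e-05
print(cfg.train.output_dir)  # projects/MagicPrompt/oneflow_magicprompt

The loaded config is then typically handed to LiBai's training entry point (tools/train_net.py, usually launched through tools/train.sh) in the standard way.
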
projects/mock_transformers/dist_infer_gpt.py (90 additions, 8 deletions)
@@ -13,15 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import init_env # noqa
import oneflow as flow
from omegaconf import DictConfig
from oneflow.utils.global_view import global_mode
from transformers import AutoModelForCausalLM, AutoTokenizer, pytorch_utils
from transformers.models.gpt2 import modeling_gpt2

from libai.layers import Conv1D
from libai.layers import Embedding as LiBaiEmbedding
from libai.utils import distributed as dist
from projects.mock_transformers import init_env # noqa

# ------ replace Embedding with LiBai's ------
temp_class = modeling_gpt2.GPT2Model


class LiBaiGPTModel(temp_class):
def __init__(self, config):
super().__init__(config)
        # Swap the token and position embeddings for LiBai's Embedding layer.
        self.wte = LiBaiEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = LiBaiEmbedding(config.max_position_embeddings, self.embed_dim)


modeling_gpt2.GPT2Model = LiBaiGPTModel


# ------ replace Conv1D with LiBai's ------
@@ -65,26 +79,26 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
in_features=self.embed_dim,
out_features=2 * self.embed_dim,
parallel="col",
dtype=flow.float16,
dtype=flow.float32,
)
self.q_attn = Conv1D(
in_features=self.embed_dim,
out_features=self.embed_dim,
parallel="col",
dtype=flow.float16,
dtype=flow.float32,
)
else:
self.c_attn = Conv1D(
in_features=self.embed_dim,
out_features=3 * self.embed_dim,
parallel="col",
dtype=flow.float16,
dtype=flow.float32,
)
self.c_proj = Conv1D(
in_features=self.embed_dim,
out_features=self.embed_dim,
parallel="row",
dtype=flow.float16,
dtype=flow.float32,
)


@@ -103,16 +117,84 @@ def __init__(self, intermediate_size, config):
in_features=embed_dim,
out_features=intermediate_size,
parallel="col",
dtype=flow.float16,
dtype=flow.float32,
)
self.c_proj = Conv1D(
in_features=intermediate_size,
out_features=embed_dim,
parallel="row",
dtype=flow.float16,
dtype=flow.float32,
)


modeling_gpt2.GPT2MLP = LiBaiGPT2MLP


# ------ replace the loss function with a global-tensor-aware version ------


class GPT2Loss(flow.nn.Module):
def __init__(self) -> None:
super().__init__()

    def forward(self, logits, label):
        # Align the labels with the (possibly model-parallel) logits' placement
        # before computing the loss.
        label = label.to_global(placement=logits.placement)
        loss = flow._C.sparse_softmax_cross_entropy(logits, label)
        loss = loss.mean()
        return loss


modeling_gpt2.CrossEntropyLoss = GPT2Loss


# ------ adapt the model's return value to what LiBai's trainer expects ------
temp_class = modeling_gpt2.GPT2LMHeadModel


class LiBaiGPT2LMHeadModel(temp_class):
def __init__(self, config):
super().__init__(config)

def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
out = super().forward(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # LiBai's trainer expects a dict of named losses rather than a
        # transformers ModelOutput, so only the loss is returned.
        return {
            "loss": out.loss,
        }


modeling_gpt2.GPT2LMHeadModel = LiBaiGPT2LMHeadModel

if __name__ == "__main__":
# set dist config
parallel_config = DictConfig(
@@ -127,7 +209,7 @@ def __init__(self, intermediate_size, config):
dist.setup_dist_util(parallel_config)

# initialize and load the model
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float16)
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=flow.float32)
# set model to cuda
dist.set_device_type("cuda")
model._apply(dist.convert_to_distributed_default_setting)
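
The GPT2Loss swapped in above first moves the labels onto the logits' placement (needed once the model runs on global tensors) and then reduces a sparse softmax cross-entropy to its mean. On ordinary local tensors the same computation looks like the single-process sketch below; the shapes are illustrative, not taken from the PR:

# Sketch only: local (non-global) tensors, so no placement/SBP handling is needed.
import oneflow as flow

vocab_size = 50257
logits = flow.randn(8, vocab_size)           # [num_tokens, vocab_size]
labels = flow.randint(0, vocab_size, (8,))   # target token ids

# The same reduction GPT2Loss.forward performs after aligning placements.
loss = flow._C.sparse_softmax_cross_entropy(logits, labels).mean()
print(loss.numpy())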