[moe] update train script (hpcaitech#4959)
* update

* update ckpt

* update train

* update train
oahzxl committed Oct 26, 2023
1 parent 4a7bf29 commit c644b47
Showing 7 changed files with 225 additions and 176 deletions.
17 changes: 9 additions & 8 deletions colossalai/moe/checkpoint.py
@@ -48,14 +48,15 @@ def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
         """
         for name, param in state_dict.items():
             if ".experts." in name:
-                model_param = dict(model.named_parameters())[name]
-                if is_moe_tensor(model_param):
-                    ep_rank = get_ep_rank(model_param)
-                    ep_size = get_ep_size(model_param)
-                    expert_num = param.shape[0] // ep_size
-                    assert param.shape[0] % ep_size == 0
-                    param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
-                    state_dict[name] = param
+                if name in dict(model.named_parameters()):
+                    model_param = dict(model.named_parameters())[name]
+                    if is_moe_tensor(model_param):
+                        ep_rank = get_ep_rank(model_param)
+                        ep_size = get_ep_size(model_param)
+                        expert_num = param.shape[0] // ep_size
+                        assert param.shape[0] % ep_size == 0
+                        param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
+                        state_dict[name] = param
         dist.barrier()
         return state_dict
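The guarded branch above slices the stacked expert weights so that each expert-parallel rank loads only its own shard of the checkpoint. A standalone sketch of that slice, with an illustrative tensor shape and hard-coded rank values standing in for the real is_moe_tensor/get_ep_rank/get_ep_size helpers:

import torch

# Checkpoint tensor stacking all experts: [num_experts, d_model, d_ff] (illustrative shape).
full_experts = torch.randn(8, 16, 32)
ep_size, ep_rank = 4, 1  # pretend we are rank 1 of 4 expert-parallel ranks

assert full_experts.shape[0] % ep_size == 0
expert_num = full_experts.shape[0] // ep_size                       # experts held per rank
shard = full_experts[ep_rank * expert_num:(ep_rank + 1) * expert_num]
print(shard.shape)  # torch.Size([2, 16, 32]) -> this rank's 2 experts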

32 changes: 8 additions & 24 deletions examples/language/openmoe/benchmark/benchmark_cai.py
@@ -4,7 +4,7 @@
 import torch
 import torch.distributed as dist
 from huggingface_hub import snapshot_download
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -19,7 +19,7 @@
 from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

@@ -218,28 +218,12 @@ def main():
     # Build OpenMoe model
     repo_name = "hpcaitech/openmoe-" + args.model_name
     config = LlamaConfig.from_pretrained(repo_name)
-    moe_args = {
-        "num_experts": config.num_experts,
-        "moe_layer_interval": config.moe_layer_interval,
-        "router_topk": 2,
-        "router_capacity_factor_train": 1.25,
-        "router_capacity_factor_eval": 2.0,
-        "router_min_capacity": 4,
-        "router_noisy_policy": None,
-        "router_drop_tks": True,
-        "router_aux_loss_factor": 0.01,
-        "router_z_loss_factor": 0.01,
-        "mlp_gated": True,
-        "label_smoothing": 0.001,
-        "z_loss_factor": 0.01,
-        "enable_load_balance": args.load_balance,
-        "load_balance_tolerance": 0.1,
-        "load_balance_beam_width": 8,
-        "load_balance_group_swap_factor": 0.4,
-        "enable_kernel": args.use_kernel,
-        "enable_comm_overlap": args.overlap_alltoall,
-    }
-    set_moe_args(config, moe_args)
+    set_openmoe_args(config,
+                     num_experts=config.num_experts,
+                     moe_layer_interval=config.moe_layer_interval,
+                     enable_load_balance=args.load_balance,
+                     enable_kernel=args.use_kernel,
+                     enable_comm_overlap=args.overlap_alltoall)
     with skip_init():
         model = OpenMoeForCausalLM(config)
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
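The skip_init() context used here avoids spending time and memory initializing weights that the checkpoint load will overwrite anyway. A generic PyTorch sketch of the same idea — not ColossalAI's actual implementation — using the meta device:

import torch
from torch import nn

# Build on the meta device: shapes and dtypes are tracked, but no weight
# memory is allocated or initialized (requires PyTorch >= 2.0).
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)
print(layer.weight.device)  # meta

# Materialize empty (uninitialized) storage, to be filled by a checkpoint load.
layer = layer.to_empty(device="cpu")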
61 changes: 11 additions & 50 deletions examples/language/openmoe/infer.py
@@ -1,12 +1,10 @@
 from argparse import ArgumentParser
 
 import torch
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig
 
-from colossalai.moe.utils import set_moe_args
-
 
 def parse_args():
     parser = ArgumentParser()
@@ -15,59 +13,22 @@ def parse_args():
 
 
 def inference(args):
-
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=True)
         model = OpenMoeForCausalLM(config)
     else:
         config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=False)
         model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
-    model = model.eval().half()
+    model = model.eval().bfloat16()
     model = model.to(torch.cuda.current_device())
 
     input_str = """```
@@ -86,7 +47,7 @@ def inference(args):
     # print("model config: ", model.config)
     input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
     input_ids = input_ids.input_ids.to(torch.cuda.current_device())
-    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
     out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
     print(f"output: \n{out}\n")

78 changes: 75 additions & 3 deletions examples/language/openmoe/model/modeling_openmoe.py
@@ -39,7 +39,7 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.moe.utils import get_activation, set_moe_args
 
 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -49,6 +49,78 @@
 _CONFIG_FOR_DOC = "LlamaConfig"
 
 
+def set_openmoe_args(
+    config: LlamaConfig,
+    num_experts: int,
+    moe_layer_interval: int,
+    router_topk: int = 2,
+    router_capacity_factor_train: float = 1.25,
+    router_capacity_factor_eval: float = 2.0,
+    router_min_capacity: int = 4,
+    router_noisy_policy: str = None,
+    router_drop_tks: bool = True,
+    router_aux_loss_factor: float = 0.01,
+    router_z_loss_factor: float = 0.01,
+    mlp_gated: bool = True,
+    label_smoothing: float = 0.001,
+    z_loss_factor: float = 0.01,
+    enable_load_balance: bool = False,
+    load_balance_tolerance: float = 0.1,
+    load_balance_beam_width: int = 8,
+    load_balance_group_swap_factor: float = 0.4,
+    enable_kernel: bool = False,
+    enable_comm_overlap: bool = False,
+) -> None:
+    """
+    Insert the MoE-related arguments into a Llama config.
+
+    Args:
+        config (LlamaConfig): Transformers Llama config.
+        num_experts (int): Number of experts.
+        moe_layer_interval (int): Interval at which MoE layers replace dense MLP layers.
+        router_topk (int, optional): MoE router top-k. Defaults to 2.
+        router_capacity_factor_train (float, optional): MoE router max capacity factor for training. Defaults to 1.25.
+        router_capacity_factor_eval (float, optional): MoE router max capacity factor for evaluation. Defaults to 2.0.
+        router_min_capacity (int, optional): MoE router min capacity. Defaults to 4.
+        router_noisy_policy (str, optional): MoE router noisy policy; one of [Jitter, Gaussian, None]. Defaults to None.
+        router_drop_tks (bool, optional): Whether the router drops tokens that exceed max capacity. Defaults to True.
+        router_aux_loss_factor (float, optional): MoE router auxiliary loss factor; see the ST-MoE paper for details. Defaults to 0.01.
+        router_z_loss_factor (float, optional): MoE router z-loss factor; see the ST-MoE paper for details. Defaults to 0.01.
+        mlp_gated (bool, optional): Use a gated MLP. Defaults to True.
+        label_smoothing (float, optional): Label smoothing. Defaults to 0.001.
+        z_loss_factor (float, optional): Z-loss factor on the final classification outputs. Defaults to 0.01.
+        enable_load_balance (bool, optional): Enable expert load balancing. Defaults to False.
+        load_balance_tolerance (float, optional): Difference tolerance for the expert load balance search. Defaults to 0.1.
+        load_balance_beam_width (int, optional): Beam width for the expert load balance search. Defaults to 8.
+        load_balance_group_swap_factor (float, optional): Group swap factor for expert load balancing; a larger value discourages swaps. Defaults to 0.4.
+        enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
+        enable_comm_overlap (bool, optional): Use communication overlap for MoE; recommended for multi-node training. Defaults to False.
+    """
+    moe_args = dict(
+        num_experts=num_experts,
+        moe_layer_interval=moe_layer_interval,
+        router_topk=router_topk,
+        router_capacity_factor_train=router_capacity_factor_train,
+        router_capacity_factor_eval=router_capacity_factor_eval,
+        router_min_capacity=router_min_capacity,
+        router_noisy_policy=router_noisy_policy,
+        router_drop_tks=router_drop_tks,
+        router_aux_loss_factor=router_aux_loss_factor,
+        router_z_loss_factor=router_z_loss_factor,
+        mlp_gated=mlp_gated,
+        label_smoothing=label_smoothing,
+        z_loss_factor=z_loss_factor,
+        enable_load_balance=enable_load_balance,
+        load_balance_tolerance=load_balance_tolerance,
+        load_balance_beam_width=load_balance_beam_width,
+        load_balance_group_swap_factor=load_balance_group_swap_factor,
+        enable_kernel=enable_kernel,
+        enable_comm_overlap=enable_comm_overlap,
+    )
+    set_moe_args(config, moe_args)
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: torch.Size,
                       dtype: torch.dtype,
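For reference, a minimal usage sketch of the new helper, mirroring the call sites in the benchmark and inference scripts above (any knob not passed keeps the default listed in the docstring):

from transformers.models.llama import LlamaConfig

config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
set_openmoe_args(config,
                 num_experts=config.num_experts,
                 moe_layer_interval=config.moe_layer_interval,
                 enable_load_balance=True)  # everything else stays at its default
model = OpenMoeForCausalLM(config)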
@@ -96,7 +168,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc
     output_sin: a float32 Tensor with shape [length, features]
     output_cos: a float32 Tensor with shape [length, features]
     """
-    fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features
+    fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features
     timescale = min_timescale * (max_timescale / min_timescale)**fraction
     rotational_frequency = 1. / timescale
 
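The dtype change keeps the sin/cos tables in float32, matching the output dtypes documented in the docstring. A self-contained sketch of the timescale math in the patched lines — the final outer product and the duplicated layout are assumptions based only on the documented [length, features] output shapes, and .cuda() is dropped so it runs on CPU:

import torch

def fixed_pos_embedding_sketch(features, length, min_timescale=1.0, max_timescale=1e4):
    # Geometric ladder of frequencies, one per pair of feature dimensions.
    fraction = torch.arange(0, features, 2, dtype=torch.float32) / features
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    rotational_frequency = 1.0 / timescale
    # One angle per (position, frequency) pair, duplicated to fill `features`.
    angles = torch.outer(torch.arange(length, dtype=torch.float32), rotational_frequency)
    angles = torch.cat([angles, angles], dim=-1)  # [length, features] (layout assumed)
    return torch.sin(angles), torch.cos(angles)

sin, cos = fixed_pos_embedding_sketch(64, 2048)
print(sin.shape, sin.dtype)  # torch.Size([2048, 64]) torch.float32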
@@ -231,7 +303,7 @@ def __init__(self, config: LlamaConfig):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4)
+        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
8 changes: 7 additions & 1 deletion examples/language/openmoe/test_ci.sh
@@ -7,7 +7,13 @@ python infer.py --model "test"
 torchrun --standalone --nproc_per_node 4 train.py \
     --num_epoch 1 \
     --model_name "test" \
-    --plugin zero2_ep \
+    --plugin "ep" \
     --batch_size 1
 
+torchrun --standalone --nproc_per_node 4 train.py \
+    --num_epoch 1 \
+    --model_name "test" \
+    --plugin "ep_zero" \
+    --batch_size 1
+
 torchrun --standalone --nproc_per_node 4 train.py \
[Diffs for the remaining 2 changed files are not shown.]