diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index 99e0ae811bbd..386fc2010805 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -48,14 +48,15 @@ def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
         """
         for name, param in state_dict.items():
             if ".experts." in name:
-                model_param = dict(model.named_parameters())[name]
-                if is_moe_tensor(model_param):
-                    ep_rank = get_ep_rank(model_param)
-                    ep_size = get_ep_size(model_param)
-                    expert_num = param.shape[0] // ep_size
-                    assert param.shape[0] % ep_size == 0
-                    param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
-                    state_dict[name] = param
+                if name in dict(model.named_parameters()):
+                    model_param = dict(model.named_parameters())[name]
+                    if is_moe_tensor(model_param):
+                        ep_rank = get_ep_rank(model_param)
+                        ep_size = get_ep_size(model_param)
+                        expert_num = param.shape[0] // ep_size
+                        assert param.shape[0] % ep_size == 0
+                        param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
+                        state_dict[name] = param
         dist.barrier()
         return state_dict
 
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 2f6bfa0f89a2..1a158eabc151 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -4,7 +4,7 @@
 import torch
 import torch.distributed as dist
 from huggingface_hub import snapshot_download
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -19,7 +19,7 @@
 from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
@@ -218,28 +218,12 @@ def main():
     # Build OpenMoe model
     repo_name = "hpcaitech/openmoe-" + args.model_name
     config = LlamaConfig.from_pretrained(repo_name)
-    moe_args = {
-        "num_experts": config.num_experts,
-        "moe_layer_interval": config.moe_layer_interval,
-        "router_topk": 2,
-        "router_capacity_factor_train": 1.25,
-        "router_capacity_factor_eval": 2.0,
-        "router_min_capacity": 4,
-        "router_noisy_policy": None,
-        "router_drop_tks": True,
-        "router_aux_loss_factor": 0.01,
-        "router_z_loss_factor": 0.01,
-        "mlp_gated": True,
-        "label_smoothing": 0.001,
-        "z_loss_factor": 0.01,
-        "enable_load_balance": args.load_balance,
-        "load_balance_tolerance": 0.1,
-        "load_balance_beam_width": 8,
-        "load_balance_group_swap_factor": 0.4,
-        "enable_kernel": args.use_kernel,
-        "enable_comm_overlap": args.overlap_alltoall,
-    }
-    set_moe_args(config, moe_args)
+    set_openmoe_args(config,
+                     num_experts=config.num_experts,
+                     moe_layer_interval=config.moe_layer_interval,
+                     enable_load_balance=args.load_balance,
+                     enable_kernel=args.use_kernel,
+                     enable_comm_overlap=args.overlap_alltoall)
     with skip_init():
         model = OpenMoeForCausalLM(config)
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py
index 1ad1456b9c56..db90c6e34507 100644
--- a/examples/language/openmoe/infer.py
+++ b/examples/language/openmoe/infer.py
@@ -1,12 +1,10 @@
 from argparse import ArgumentParser
 
 import torch
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig
 
-from colossalai.moe.utils import set_moe_args
-
 
 def parse_args():
     parser = ArgumentParser()
@@ -15,59 +13,22 @@
 
 
 def inference(args):
-    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=True)
         model = OpenMoeForCausalLM(config)
     else:
         config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=False)
         model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
-    model = model.eval().half()
+    model = model.eval().bfloat16()
     model = model.to(torch.cuda.current_device())
 
     input_str = """```
@@ -86,7 +47,7 @@ def inference(args):
     # print("model config: ", model.config)
     input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False)
     input_ids = input_ids.input_ids.to(torch.cuda.current_device())
-    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
     out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
     print(f"output: \n{out}\n")
 
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 6f9b668e4597..7d28de731407 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -39,7 +39,7 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.moe.utils import get_activation, set_moe_args
 
 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -49,6 +49,78 @@
 _CONFIG_FOR_DOC = "LlamaConfig"
 
 
+def set_openmoe_args(
+    config: LlamaConfig,
+    num_experts: int,
+    moe_layer_interval: int,
+    router_topk: int = 2,
+    router_capacity_factor_train: float = 1.25,
+    router_capacity_factor_eval: float = 2.0,
+    router_min_capacity: int = 4,
+    router_noisy_policy: str = None,
+    router_drop_tks: bool = True,
+    router_aux_loss_factor: float = 0.01,
+    router_z_loss_factor: float = 0.01,
+    mlp_gated: bool = True,
+    label_smoothing: float = 0.001,
+    z_loss_factor: float = 0.01,
+    enable_load_balance: bool = False,
+    load_balance_tolerance: float = 0.1,
+    load_balance_beam_width: int = 8,
+    load_balance_group_swap_factor: float = 0.4,
+    enable_kernel: bool = False,
+    enable_comm_overlap: bool = False,
+) -> None:
+    """
+    Set the MoE-related arguments.
+    It inserts the MoE arguments into the Llama config.
+
+    Args:
+        config (LlamaConfig): Transformers Llama config.
+        num_experts (int): Number of experts.
+        moe_layer_interval (int): The interval of the MoE layers.
+        router_topk (int, optional): Moe router top k. Defaults to 2.
+        router_capacity_factor_train (float, optional): Moe router max capacity for train. Defaults to 1.25.
+        router_capacity_factor_eval (float, optional): Moe router max capacity for eval. Defaults to 2.0.
+        router_min_capacity (int, optional): Moe router min capacity. Defaults to 4.
+        router_noisy_policy (str, optional): Moe router noisy policy. You can choose [Jitter, Gaussian, None]. Defaults to None.
+        router_drop_tks (bool, optional): Whether the moe router drops tokens which exceed max capacity. Defaults to True.
+        router_aux_loss_factor (float, optional): Moe router aux loss. You can refer to STMoE for details. Defaults to 0.01.
+        router_z_loss_factor (float, optional): Moe router z loss. You can refer to STMoE for details. Defaults to 0.01.
+        mlp_gated (bool, optional): Use gate in mlp. Defaults to True.
+        label_smoothing (float, optional): Label smoothing. Defaults to 0.001.
+        z_loss_factor (float, optional): The final outputs' classification z loss factor. Defaults to 0.01.
+        enable_load_balance (bool, optional): Expert load balance. Defaults to False.
+        load_balance_tolerance (float, optional): Expert load balance search's difference tolerance. Defaults to 0.1.
+        load_balance_beam_width (int, optional): Expert load balance search's beam width. Defaults to 8.
+        load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. A larger value encourages less swapping. Defaults to 0.4.
+        enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
+        enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for multi-node training. Defaults to False.
+ """ + moe_args = dict( + num_experts=num_experts, + moe_layer_interval=moe_layer_interval, + router_topk=router_topk, + router_capacity_factor_train=router_capacity_factor_train, + router_capacity_factor_eval=router_capacity_factor_eval, + router_min_capacity=router_min_capacity, + router_noisy_policy=router_noisy_policy, + router_drop_tks=router_drop_tks, + router_aux_loss_factor=router_aux_loss_factor, + router_z_loss_factor=router_z_loss_factor, + mlp_gated=mlp_gated, + label_smoothing=label_smoothing, + z_loss_factor=z_loss_factor, + enable_load_balance=enable_load_balance, + load_balance_tolerance=load_balance_tolerance, + load_balance_beam_width=load_balance_beam_width, + load_balance_group_swap_factor=load_balance_group_swap_factor, + enable_kernel=enable_kernel, + enable_comm_overlap=enable_comm_overlap, + ) + set_moe_args(config, moe_args) + + # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, @@ -96,7 +168,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc output_sin: a float32 Tensor with shape [length, features] output_cos: a float32 Tensor with shape [length, features] """ - fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features + fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features timescale = min_timescale * (max_timescale / min_timescale)**fraction rotational_frequency = 1. / timescale @@ -231,7 +303,7 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) + self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 0f68db4275f7..71198d8756d0 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -7,7 +7,13 @@ python infer.py --model "test" torchrun --standalone --nproc_per_node 4 train.py \ --num_epoch 1 \ --model_name "test" \ - --plugin zero2_ep \ + --plugin "ep" \ + --batch_size 1 + +torchrun --standalone --nproc_per_node 4 train.py \ + --num_epoch 1 \ + --model_name "test" \ + --plugin "ep_zero" \ --batch_size 1 torchrun --standalone --nproc_per_node 4 train.py \ diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index ec9ec21b55dc..19bc70e1c4f5 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -2,9 +2,10 @@ import datasets import torch +import torch.distributed as dist import transformers from huggingface_hub import snapshot_download -from model.modeling_openmoe import OpenMoeForCausalLM +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm @@ -17,9 +18,9 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator 
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.moe import MoeCheckpintIO
+from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
 
 
@@ -42,7 +43,31 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
 
 class RandomDataset(Dataset):
 
-    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
+        """
+        A random dataset.
+
+        You can use the tokenizer to process your own data.
+        Example:
+            self.input_ids = []
+            self.attention_mask = []
+            data = your_data()
+            data = shuffle(data)
+            for text in data:
+                # text is a str
+                encode = tokenizer(
+                    "" + text,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    max_length=max_length,
+                    truncation=True,
+                    padding="max_length")
+                self.input_ids.append(encode["input_ids"])
+                self.attention_mask.append(encode["attention_mask"])
+            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
+            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+        """
+        # TODO: use distributed sampler
         self.num_samples = num_samples
         self.max_length = max_length
         self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
@@ -88,20 +113,38 @@ def parse_args():
         type=str,
         default="hybrid",
         help="parallel plugin",
-        choices=["zero1_ep", "zero2_ep", "hybrid"],
+        choices=["ep", "ep_zero", "hybrid"],
     )
+
+    # optim
+    parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.")
+    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+
+    # zero stage for all plugins
+    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
+
+    # ep zero plugin
+    parser.add_argument("--extra_dp_size", type=int, default=1, help="ep zero's moe dp size")
+
     # hybrid plugin
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")
     parser.add_argument("--dp_size", type=int, default=1, help="dp size")
     parser.add_argument("--ep_size", type=int, default=2, help="ep size")
-    parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin")
     parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
+    # kernel
     parser.add_argument(
         "--use_kernel",
         action="store_true",
-        help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
+        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations.",
+    )
+    parser.add_argument(
+        "--use_layernorm_kernel",
+        action="store_true",
+        help="Use layernorm kernel. Need to install apex.",
     )
+
     # loss
     parser.add_argument(
         "--router_aux_loss_factor",
@@ -117,9 +160,13 @@ def parse_args():
     )
     parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.")
     parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.")
-    # optim
-    parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.")
-    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
+
+    # load balance
+    parser.add_argument("--load_balance", action="store_true", help="moe load balance")
+    parser.add_argument("--load_balance_interval", type=int, default=1000, help="moe load balance interval")
+
+    # overlap
+    parser.add_argument("--comm_overlap", action="store_true", help="moe comm overlap")
     args = parser.parse_args()
     return args
 
@@ -145,49 +192,57 @@ def main():
 
     # Set plugin
     booster_kwargs = {}
-    if args.plugin == "zero1_ep":
+    hybrid_dict = {
+        "tp_size": 1,
+        "custom_policy": OpenMoeForCausalLMPolicy(),
+        "enable_fused_normalization": args.use_layernorm_kernel,
+        "enable_jit_fused": args.use_kernel,
+        "precision": "bf16",
+        "zero_stage": args.zero_stage,
+    }
+    mgr_dict = {
+        "seed": 42,
+    }
+    if args.plugin == "ep":
+        dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
-            tp_size=1,
             pp_size=1,
-            zero_stage=1,
-            custom_policy=OpenMoeForCausalLMPolicy(),
-            enable_fused_normalization=args.use_kernel,
-            enable_jit_fused=args.use_kernel,
+            **hybrid_dict,
         )
         MOE_MANAGER.setup(
-            seed=42,
             parallel="EP",
+            max_ep_size=dp_size,
+            **mgr_dict,
         )
-    elif args.plugin == "zero2_ep":
+    elif args.plugin == "ep_zero":
+        dp_size = dist.get_world_size()
+        use_ep_inside = False
         plugin = MoeHybridParallelPlugin(
-            tp_size=1,
             pp_size=1,
-            zero_stage=2,
-            custom_policy=OpenMoeForCausalLMPolicy(),
-            enable_fused_normalization=args.use_kernel,
-            enable_jit_fused=args.use_kernel,
+            extra_dp_size=args.extra_dp_size,
+            use_ep_inside=use_ep_inside,
+            **hybrid_dict,
        )
         MOE_MANAGER.setup(
-            seed=42,
             parallel="EP",
+            max_ep_size=dp_size // args.extra_dp_size,
+            use_ep_inside=use_ep_inside,
+            **mgr_dict,
         )
     elif args.plugin == "hybrid":
+        dp_size = dist.get_world_size() // args.pp_size
         plugin = MoeHybridParallelPlugin(
-            tp_size=1,
             pp_size=args.pp_size,
-            zero_stage=args.zero_stage,
             microbatch_size=args.microbatch_size,
-            custom_policy=OpenMoeForCausalLMPolicy(),
-            enable_fused_normalization=args.use_kernel if not test_mode else False,
-            enable_jit_fused=args.use_kernel if not test_mode else False,
+            **hybrid_dict,
         )
         MOE_MANAGER.setup(
-            seed=42,
             parallel="EP",
             mode="fixed",
             fixed_dp_size=args.dp_size,
             fixed_ep_size=args.ep_size,
             fixed_pp_size=args.pp_size,
+            **mgr_dict,
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
@@ -202,28 +257,17 @@ def main():
     else:
         repo_name = "hpcaitech/openmoe-" + args.model_name
         config = LlamaConfig.from_pretrained(repo_name)
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(
+            config,
+            num_experts=config.num_experts,
+            moe_layer_interval=config.moe_layer_interval,
+            router_aux_loss_factor=args.router_aux_loss_factor,
+            router_z_loss_factor=args.router_z_loss_factor,
+            z_loss_factor=args.z_loss_factor,
+            enable_load_balance=args.load_balance,
+            enable_comm_overlap=args.comm_overlap,
+            enable_kernel=args.use_kernel,
+        )
     with skip_init():
         model = OpenMoeForCausalLM(config)
     logger.info(f"Finish init model with config:\n{config}", ranks=[0])
@@ -233,7 +277,7 @@ def main():
 
     # Prepare tokenizer and dataloader
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
-    dataset = RandomDataset(num_samples=1000 if not test_mode else 20)
+    dataset = RandomDataset(num_samples=1000 if not test_mode else 20, tokenizer=tokenizer)
     dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
 
     # Set optimizer
@@ -259,7 +303,7 @@ def main():
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
             disable=not coordinator.is_master(),
         ) as pbar:
-            for _ in pbar:
+            for step in pbar:
                 if use_pipeline:
                     # Forward pass
                     outputs = booster.execute_pipeline(
@@ -287,6 +331,11 @@ def main():
                 optimizer.step()
                 optimizer.zero_grad()
 
+                # Apply load balance
+                if args.load_balance and args.load_balance_interval > 0 and step % args.load_balance_interval == 0:
+                    coordinator.print_on_master("Apply load balance")
+                    apply_load_balance(model, optimizer)
+
     # Finish training and evaluate
     logger.info(f"Finish finetuning", ranks=[0])
     booster.save_model(model, args.output_path)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 40aae12f016a..b68eaec50fea 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -14,52 +14,28 @@
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
-sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe"))
+sys.path.append(os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+    "examples/language/openmoe",
+))
 OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
+set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
 OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
 
 
 def get_config():
     config = LlamaConfig(
         vocab_size=300,
-        hidden_size=32,
-        intermediate_size=64,
-        num_hidden_layers=2,
+        hidden_size=16,
+        intermediate_size=32,
+        num_hidden_layers=4,
         num_attention_heads=2,
+        head_dim=4,
+        dropout_rate=0.0,
+        hidden_act="swiglu",
     )
-    settings = {
-        "vocab_size": 300,
-        "intermediate_size": 32,
-        "hidden_size": 16,
-        "num_hidden_layers": 2,
-        "head_dim": 4,
-        "num_attention_heads": 4,
-        "dropout_rate": 0.0,
-        "hidden_act": "swiglu",
-        "num_experts": 16,
-        "capacity_factor_train": 1.25,
-        "capacity_factor_eval": 2.0,
-        "min_capacity": 4,
-        "noisy_policy": None,
-        "drop_tks": True,
-        "moe_layer_interval": 4,
-        "router_aux_loss_factor": 0.1,
-        "router_z_loss_factor": 0.1,
-        "label_smoothing": 0.1,
-        "z_loss_factor": 0.1,
-        "mlp_gated": True,
-        "label_smoothing": 0.001,
-        "z_loss_factor": 0.01,
-        "enable_load_balance": False,
-        "load_balance_tolerance": 0.1,
-        "load_balance_beam_width": 8,
-        "load_balance_group_swap_factor": 0.4,
-        "enable_kernel": False,
-        "enable_comm_overlap": False,
-    }
-    for key, value in settings.items():
-        setattr(config, key, value)
+    set_openmoe_args(config, num_experts=16, moe_layer_interval=1)
     return config
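For quick reference, a minimal sketch of how the new set_openmoe_args helper introduced in this patch is invoked when building an OpenMoE model outside the bundled scripts; the checkpoint name and the explicit flag values below are illustrative only (they mirror the defaults and the "hpcaitech/openmoe-base" repo already used in infer.py), not additional code from the patch:

    from transformers.models.llama import LlamaConfig

    from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args

    # Load the base config, inject the MoE arguments, then build the model.
    config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=False,  # defaults written out for illustration
        enable_kernel=False,
    )
    model = OpenMoeForCausalLM(config)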