From ea17ea6d7b4e2c68cfbfb68538b207bc46c6ed6a Mon Sep 17 00:00:00 2001
From: ghostplant
Date: Mon, 27 Dec 2021 06:02:53 +0000
Subject: [PATCH] support handling multi-gate options (#71)

---
 README.md                |  4 ++-
 tutel/impls/moe_layer.py | 56 ++++++++++++++++++++--------------------
 2 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index dde12c7e..429e70cb 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,9 @@ Usage of MOELayer:
 ```
 * Usage of MOELayer Args:
 
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
+        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'},
+                           or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}];
+                           the value of k in top-gating can also be negative, e.g. -2, which indicates that each GPU holds 1/(-k) of an expert's parameters
         model_dim        : the number of channels for MOE's input tensor
         experts          : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
         scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py
index abbb9dc4..98bfe5d2 100644
--- a/tutel/impls/moe_layer.py
+++ b/tutel/impls/moe_layer.py
@@ -148,7 +148,7 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):
         return result_output, l_loss
 
 
-class MegatronLMGate():
+class MegatronLMGate(torch.nn.Module):
     """Megatron-LM Tensor Parallel over MoE Gate Type
     """
 
@@ -157,6 +157,9 @@ def __init__(
         **kwargs,
     ):
         self.l_zero = None
+        self._modules = dict()
+        self._parameters = dict()
+        self._buffers = dict()
 
     def named_parameters(self):
         return []
@@ -173,15 +176,6 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):
 
 class MOELayer(torch.nn.Module):
     """Tutel optimized MOELayer
-
-    Args:
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
-        model_dim        : the number of channels for MOE's input tensor
-        experts          : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
-        scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
-        result_func      : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
-        group            : specify the explicit communication group of all_to_all
-        seeds            : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
     """
 
     def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None, **kwargs):
@@ -342,22 +336,28 @@ def to(self, *args, **kwargs):
             logging.warning(f"gate_type value `{gate_type}` in tutel.moe_layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.")
             gate_type = {'type': 'top', 'k': top_k}
 
-        if gate_type['type'] == 'top':
-            if seeds is not None and seeds[0] is not None:
-                torch.manual_seed(seeds[0])
-
-            if "fp32_gate" in kwargs:
-                logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
-                gate_type["fp32_gate"] = kwargs["fp32_gate"]
+        if not isinstance(gate_type, list):
+            gate_type = [gate_type]
+
+        self.gates = []
+        for gi, single_gate_type in enumerate(gate_type):
+            if single_gate_type['type'] == 'top':
+                if seeds is not None and seeds[0] is not None:
+                    torch.manual_seed(seeds[0] + gi)
+                if "fp32_gate" in kwargs:
+                    logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
+                    single_gate_type["fp32_gate"] = kwargs["fp32_gate"]
+
+                self.gates += [TopKGate(model_dim=model_dim, top_k=single_gate_type['k'], num_global_experts=self.num_global_experts, **single_gate_type)]
+            elif single_gate_type['type'] == 'megatron':
+                self.gates += [MegatronLMGate(**single_gate_type)]
+                assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
+                assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
+                assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
+            else:
+                raise Exception("Unrecognized gate_type: %s" % single_gate_type)
-            self.gate = TopKGate(model_dim=model_dim, top_k=gate_type['k'], num_global_experts=self.num_global_experts, **gate_type)
-        elif gate_type['type'] == 'megatron':
-            self.gate = MegatronLMGate(**gate_type)
-            assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
-            assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
-            assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
-        else:
-            raise Exception("Unrecognized gate_type: %s" % gate_type)
+
+        self.gates = ModuleList(self.gates)
 
         if seeds is not None and len(seeds) > 2 and seeds[2] is not None:
             torch.manual_seed(seeds[2])
 
@@ -375,13 +375,13 @@ def expert_fn(dispatched_input):
 
     def get_parameter_iterator(self, param_type):
         if param_type == 'gate':
-            return self.gate.named_parameters()
+            return self.gates.named_parameters()
         elif param_type == 'local_experts':
            return self.experts.named_parameters()
         else:
             raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." % param_type)
 
-    def forward(self, input: Tensor, **kwargs: Any):
+    def forward(self, input: Tensor, gate_index=0, **kwargs: Any):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
@@ -404,7 +404,7 @@ def forward(self, input: Tensor, **kwargs: Any):
             reshaped_input = pad_input
 
         reshaped_input = reshaped_input.to(next(iter(self.experts.parameters())).dtype)
-        result_output, l_aux = self.gate.apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)
+        result_output, l_aux = self.gates[gate_index].apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)
 
         result_output = result_output[:reshaped_input_samples, :]
         result_output = result_output.view(original_shape).to(original_dtype)
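
Usage note (illustration only, not part of the patch): a minimal sketch of how the multi-gate option added above might be called. The `MOELayer(gate_type=[...], ...)` constructor and the `forward(..., gate_index=...)` signature follow the diff; the builtin `ffn` expert keys other than `type` and `count_per_node` (e.g. `hidden_size_per_expert`) and the single-process setup are assumptions that may need adjusting for an actual tutel environment.

```python
# Illustration only -- not part of the patch. Assumes a single-process setup and
# guesses the builtin `ffn` expert config keys beyond `type` / `count_per_node`.
import torch
from tutel.impls.moe_layer import MOELayer

moe = MOELayer(
    # A list of gate descriptions builds one gate per entry (new in this patch);
    # per the README note, k may also be negative (e.g. -2) to shard each expert's parameters.
    gate_type=[{'type': 'top', 'k': 1}, {'type': 'top', 'k': 2}],
    model_dim=1024,
    experts={
        'type': 'ffn',                   # builtin feed-forward expert
        'count_per_node': 2,             # local experts per device
        'hidden_size_per_expert': 2048,  # assumed key name for the ffn hidden size
    },
)

x = torch.randn(4, 512, 1024)   # (batch, tokens, model_dim)
y0 = moe(x, gate_index=0)       # route through the first gate (top-1)
y1 = moe(x, gate_index=1)       # route through the second gate (top-2)
print(y0.shape, y0.l_aux, y1.l_aux)
```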