diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py
index f719e299a4b2..38618a4b6018 100644
--- a/deepspeed/inference/v2/modules/heuristics.py
+++ b/deepspeed/inference/v2/modules/heuristics.py
@@ -32,6 +32,7 @@
     DSUnembedRegistry,
 )
 
+import torch
 
 def instantiate_attention(attention_config: DSSelfAttentionConfig,
                           engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase:
@@ -129,10 +130,15 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin
         "weight_dtype": moe_config.input_dtype,
     }
 
-    # Currently, we only have one implementation, so we just return it.
-    config = ConfigBundle(name="cutlass_multi_gemm_moe",
-                          config=moe_config,
-                          implementation_config=implementation_config)
+    # Use the PyTorch implementation on compute capability 9.0 (H100) or newer; otherwise use CUTLASS.
+    if torch.cuda.get_device_capability(0)[0] >= 9:
+        config = ConfigBundle(name="pytorch_multi_gemm_moe",
+                              config=moe_config,
+                              implementation_config=implementation_config)
+    else:
+        config = ConfigBundle(name="cutlass_multi_gemm_moe",
+                              config=moe_config,
+                              implementation_config=implementation_config)
 
     return DSMoERegistry.instantiate_config(config)
 
diff --git a/deepspeed/inference/v2/modules/implementations/moe/__init__.py b/deepspeed/inference/v2/modules/implementations/moe/__init__.py
index 053ad5da7746..57b08713be5a 100644
--- a/deepspeed/inference/v2/modules/implementations/moe/__init__.py
+++ b/deepspeed/inference/v2/modules/implementations/moe/__init__.py
@@ -4,3 +4,4 @@
 # DeepSpeed Team
 
 from .cutlass_multi_gemm import DSMultiGemmMoE
+from .pytorch_multi_gemm import DSPytorchMultiGemmMoE
diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py
new file mode 100644
index 000000000000..40058d3ffb92
--- /dev/null
+++ b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py
@@ -0,0 +1,307 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from deepspeed.accelerator import get_accelerator
+from ....allocator import empty_from
+from ....inference_utils import ActivationType, is_gated
+from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
+from ....kernels.ragged_ops import (
+    MoEGather,
+    MoEScatter,
+    RaggedTopKGating,
+)
+from ....ragged import RaggedBatchWrapper
+
+from ...interfaces import DSMoEBase, DSMoERegistry
+from ...configs import DSMoEConfig
+# from ....kernels.cutlass_ops import MoEGEMM
+from ....inference_parameter import InferenceParameter
+
+
+def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
+    # Reference activation: gated activations interleave activation and linear channels.
+    if is_gated(act_type):
+        act_func_map = {
+            ActivationType.ReGLU: torch.nn.functional.relu,
+            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+            ActivationType.SiGLU: torch.nn.functional.silu,
+        }
+
+        act_act = out_states[..., ::2]
+        act_linear = out_states[..., 1::2]
+
+        act_act = act_func_map[act_type](act_act)
+        return act_act * act_linear
+    else:
+        act_func_map = {
+            ActivationType.RELU: torch.nn.functional.relu,
+            ActivationType.GELU: torch.nn.functional.gelu,
+            ActivationType.SILU: torch.nn.functional.silu,
+            ActivationType.IDENTITY: lambda x: x,
+        }
+
+        return act_func_map[act_type](out_states)
+
+
+def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Reference top-1 gating code.
+    """
+    logits = logits.float()
+    probs = torch.nn.functional.softmax(logits, dim=1)
+
+    indices1_s = torch.argmax(probs, dim=-1)
+    mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1])
+    indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1
+    indices1_s = torch.min(indices1_s, indices_mask)
+
+    gates1_s = (probs * mask1).sum(dim=1)
+
+    sorted_indices = indices1_s.sort()[1]
+    original_indices = sorted_indices.sort()[1]
+
+    exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long()
+    exp_count_cumsum = exp_count.cumsum(dim=0)
+
+    return sorted_indices, original_indices, exp_count_cumsum, gates1_s
+
+
+@DSMoERegistry.register_module
+class DSPytorchMultiGemmMoE(DSMoEBase):
+    """
+    Pure PyTorch MoE implementation that runs the expert MLPs as per-expert GEMMs
+    (torch.nn.functional.linear), as an alternative to the CUTLASS multi-GEMM path.
+    """
+
+    @staticmethod
+    def name():
+        return 'pytorch_multi_gemm_moe'
+
+    @staticmethod
+    def supports_config(config: DSMoEConfig) -> bool:
+        if config.input_dtype != config.output_dtype:
+            return False
+
+        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
+            return False
+
+        if config.top_k != 1 and config.top_k != 2:
+            return False
+
+        return True
+
+    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
+        super().__init__(config, implementation_config)
+
+        # Convenience variables for frequently accessed items.
+        self.max_tokens = self._config.max_tokens
+        self.n_experts = self._config.n_experts
+        self.n_top_k = self._config.top_k
+        self.intermediate_dim = self._config.intermediate_features
+
+        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation
+
+        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
+        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)
+
+        if is_gated(self._config.activation):
+            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
+                                                   self._config.activation)
+        else:
+            self._activation = None
+
+        self._gate_proj = BlasLibLinear(self._config.input_dtype)
+        self._top_1_gate = RaggedTopKGating(config.input_dtype)
+        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
+        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)
+
+        self._create_buffers()
+
+    def _create_buffers(self):
+
+        # Gating buffers
+        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
+                                   dtype=self._config.input_dtype,
+                                   device=get_accelerator().current_device())
+        self._expert_counts = torch.empty((self.n_experts, ),
+                                          dtype=torch.int32,
+                                          device=get_accelerator().current_device())
+        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
+                                   dtype=torch.float32,
+                                   device=get_accelerator().current_device())
+        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
+                                        dtype=torch.int32,
+                                        device=get_accelerator().current_device())
+        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
+                                    dtype=torch.int32,
+                                    device=get_accelerator().current_device())
+
+        # Scatter buffers
+        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
+                                      dtype=self._config.input_dtype,
+                                      device=get_accelerator().current_device())
+        self._expert_cumsum = torch.empty((self._config.n_experts, ),
+                                          dtype=torch.int64,
+                                          device=get_accelerator().current_device())
+        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
+                                         dtype=torch.int32,
+                                         device=get_accelerator().current_device())
+
+        # GEMM Buffers
+        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
+                                         dtype=self._config.output_dtype,
+                                         device=get_accelerator().current_device())
+        if self._activation is not None:
+            self._gated_intermediate = torch.empty(
+                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
+                dtype=self._config.output_dtype,
+                device=get_accelerator().current_device())
+
+        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
+                                             dtype=self._config.output_dtype,
+                                             device=get_accelerator().current_device())
+
+        # Gather buffer
+        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
+                                   dtype=self._config.output_dtype,
+                                   device=get_accelerator().current_device())
+
+    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
+        """
+        Ensures the gate parameter matches the activation data type.
+        """
+        param = param.to(self._config.input_dtype)
+        return InferenceParameter.initialize(param)
+
+    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
+        """
+        Converts the parameter to the same data type as the input and output.
+
+        Parameters:
+            param (torch.Tensor): Weight or bias tensor.
+        """
+        param = param.to(self._config.input_dtype)
+
+        # if len(param.shape) == 3:
+        #     param = param.permute(0, 2, 1).contiguous()
+        return InferenceParameter.initialize(param)
+
+    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
+        """
+        Converts the parameter to the same data type as the input and output.
+
+        Parameters:
+            param (torch.Tensor): Weight or bias tensor.
+        """
+        param = param.to(self._config.input_dtype)
+
+        # if len(param.shape) == 3:
+        #     param = param.permute(0, 2, 1).contiguous()
+        return InferenceParameter.initialize(param)
+
+    @property
+    def output(self) -> torch.Tensor:
+        return self._output
+
+    def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper,
+              gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function that computes the gating logits and produces the metadata tensors
+        for the per-expert GEMMs. If the input has been padded for CUDA graphs, this will
+        strip the padding for MoE.
+
+        Parameters:
+            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
+            batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states.
+            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens).
+        """
+
+        # Get views on the buffers for gating
+        logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1]))
+        scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k))
+        assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k))
+        offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k))
+        mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k))
+        moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1]))
+
+        self._gate_proj(logits, hidden_states, gate_w)
+        self._expert_counts.zero_()
+        self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata)
+        self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts,
+                          assignments, offsets)
+
+        return moe_input, self._expert_cumsum, scores, mapped_slots
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                batch_metadata: RaggedBatchWrapper,
+                gate_w: torch.Tensor,
+                mlp_1_w: torch.Tensor,
+                mlp_2_w: torch.Tensor,
+                mlp_1_b: Optional[torch.Tensor] = None,
+                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        MoE forward pass implemented with per-expert PyTorch GEMMs rather than the CUTLASS multi-GEMM kernels.
+
+        Parameters:
+            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
+            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
+        """
+        moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)
+
+        # Get views on the buffers for GEMM
+        intermediate = empty_from(self._intermediate,
+                                  (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1]))
+        output_unordered = empty_from(self._output_unordered,
+                                      (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1]))
+        output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1]))
+
+        # Each expert's tokens occupy a contiguous slice of the scattered input, so the
+        # expert MLP can be applied slice-by-slice with standard dense GEMMs.
+        for expert_idx in range(mlp_1_w.shape[0]):
+            min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
+            max_bound = expert_cumsum[expert_idx]
+
+            input_slice = moe_input[min_bound:max_bound]
+            intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx],
+                                                      mlp_1_b[expert_idx] if mlp_1_b is not None else None)
+
+            intermediate = _activation_reference(intermediate, self._config.activation)
+            output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx],
+                                                      mlp_2_b[expert_idx] if mlp_2_b is not None else None)
+
+            output_unordered[min_bound:max_bound] = output_slice
+
+        # if self._activation is not None:
+        #     gated_intermediate = empty_from(
+        #         self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1]))
+        #     self._mlp_1(
+        #         gated_intermediate,
+        #         moe_input,
+        #         mlp_1_w,
+        #         expert_cumsum,
+        #         mlp_1_b,
+        #     )
+        #     self._activation(intermediate, gated_intermediate)
+        # else:
+        #     self._mlp_1(
+        #         intermediate,
+        #         moe_input,
+        #         mlp_1_w,
+        #         expert_cumsum,
+        #         mlp_1_b,
+        #     )
+
+        # self._mlp_2(
+        #     output_unordered,
+        #     intermediate,
+        #     mlp_2_w,
+        #     expert_cumsum,
+        #     mlp_2_b,
+        # )
+
+        self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
+        return output
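Note for reviewers: the sketch below is a minimal, self-contained illustration (not part of the diff) of the per-expert GEMM pattern that DSPytorchMultiGemmMoE.forward relies on. Tokens are grouped so each expert owns a contiguous slice, each slice is run through that expert's MLP with torch.nn.functional.linear, and the results are restored to the original token order. The helper name, the ReLU stand-in activation, and the tensor shapes are illustrative assumptions, not the module's real buffers, kernels, or gating.

import torch
import torch.nn.functional as F


def per_expert_mlp_reference(hidden, assignments, mlp_1_w, mlp_2_w):
    # hidden: [n_tokens, model_dim], assignments: [n_tokens] expert ids,
    # mlp_1_w: [n_experts, intermediate_dim, model_dim], mlp_2_w: [n_experts, model_dim, intermediate_dim].
    n_experts = mlp_1_w.shape[0]

    # Group tokens so each expert owns a contiguous slice (the role MoEScatter plays in the module).
    sorted_indices = assignments.argsort()
    expert_cumsum = torch.bincount(assignments, minlength=n_experts).cumsum(dim=0)
    grouped = hidden[sorted_indices]

    out = torch.empty_like(grouped)
    for expert_idx in range(n_experts):
        min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
        max_bound = expert_cumsum[expert_idx]
        tokens = grouped[min_bound:max_bound]
        # Two dense GEMMs per expert, with a ReLU stand-in for the configured activation.
        out[min_bound:max_bound] = F.linear(F.relu(F.linear(tokens, mlp_1_w[expert_idx])), mlp_2_w[expert_idx])

    # Restore the original token order (the role MoEGather plays, ignoring gate scores here).
    restored = torch.empty_like(out)
    restored[sorted_indices] = out
    return restored


if __name__ == "__main__":
    n_tokens, model_dim, intermediate_dim, n_experts = 16, 32, 64, 4
    hidden = torch.randn(n_tokens, model_dim)
    assignments = torch.randint(0, n_experts, (n_tokens, ))
    mlp_1_w = torch.randn(n_experts, intermediate_dim, model_dim)
    mlp_2_w = torch.randn(n_experts, model_dim, intermediate_dim)
    print(per_expert_mlp_reference(hidden, assignments, mlp_1_w, mlp_2_w).shape)  # torch.Size([16, 32])

The loop keeps memory proportional to the actual token count per expert, which matches the ragged layout the module already produces with MoEScatter; a padded torch.bmm over equal-sized expert batches would trade extra memory for fewer kernel launches.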