Add PyTorch multi-gemm MOE implementation
Showing 3 changed files with 318 additions and 4 deletions.
deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py: 307 additions & 0 deletions
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional, Tuple

import torch

from deepspeed.accelerator import get_accelerator
from ....allocator import empty_from
from ....inference_utils import ActivationType, is_gated
from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
from ....kernels.ragged_ops import (
    MoEGather,
    MoEScatter,
    RaggedTopKGating,
)
from ....ragged import RaggedBatchWrapper

from ...interfaces import DSMoEBase, DSMoERegistry
from ...configs import DSMoEConfig
# from ....kernels.cutlass_ops import MoEGEMM
from ....inference_parameter import InferenceParameter

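# Pure-PyTorch reference implementations of the fused activation and gating kernels.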
def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
    if is_gated(act_type):
        act_func_map = {
            ActivationType.ReGLU: torch.nn.functional.relu,
            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
            ActivationType.SiGLU: torch.nn.functional.silu,
        }

        act_act = out_states[..., ::2]
        act_linear = out_states[..., 1::2]

        act_act = act_func_map[act_type](act_act)
        return act_act * act_linear
    else:
        act_func_map = {
            ActivationType.RELU: torch.nn.functional.relu,
            ActivationType.GELU: torch.nn.functional.gelu,
            ActivationType.SILU: torch.nn.functional.silu,
            ActivationType.IDENTITY: lambda x: x,
        }

        return act_func_map[act_type](out_states)

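# Top-1 gating reference (not called in the forward path below); a pure-PyTorch
# counterpart to the RaggedTopKGating kernel, kept for validation.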
def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reference gating code.
    """
    logits = logits.float()
    probs = torch.nn.functional.softmax(logits, dim=1)

    indices1_s = torch.argmax(probs, dim=-1)
    mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1])
    indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1
    indices1_s = torch.min(indices1_s, indices_mask)

    gates1_s = (probs * mask1).sum(dim=1)

    sorted_indices = indices1_s.sort()[1]
    original_indices = sorted_indices.sort()[1]

    exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long()
    exp_count_cumsum = exp_count.cumsum(dim=0)

    return sorted_indices, original_indices, exp_count_cumsum, gates1_s

@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
    """
    MoE implementation that mirrors the CUTLASS multi-GEMM module using a PyTorch
    loop over per-expert GEMMs.
    """

    @staticmethod
    def name():
        return 'pytorch_multi_gemm_moe'

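    # Only fp16/bf16 activations with matching input/output dtypes and top-1 or top-2
    # routing are supported.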
    @staticmethod
    def supports_config(config: DSMoEConfig) -> bool:
        if config.input_dtype != config.output_dtype:
            return False

        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
            return False

        if config.top_k != 1 and config.top_k != 2:
            return False

        return True

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

        # Convenience variables for frequently accessed items.
        self.max_tokens = self._config.max_tokens
        self.n_experts = self._config.n_experts
        self.n_top_k = self._config.top_k
        self.intermediate_dim = self._config.intermediate_features

        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation

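        # The fused CUTLASS MoE GEMM kernels are left disabled here; the forward pass
        # instead loops over experts with torch.nn.functional.linear.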
        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)

        if is_gated(self._config.activation):
            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
                                                   self._config.activation)
        else:
            self._activation = None

        self._gate_proj = BlasLibLinear(self._config.input_dtype)
        self._top_1_gate = RaggedTopKGating(config.input_dtype)
        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)

        self._create_buffers()

    def _create_buffers(self):

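        # All buffers are sized for max_tokens up front; the forward pass only takes
        # views of them via empty_from rather than allocating new tensors.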
        # Gating buffers
        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
                                   dtype=self._config.input_dtype,
                                   device=get_accelerator().current_device())
        self._expert_counts = torch.empty((self.n_experts, ),
                                          dtype=torch.int32,
                                          device=get_accelerator().current_device())
        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
                                   dtype=torch.float32,
                                   device=get_accelerator().current_device())
        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
                                        dtype=torch.int32,
                                        device=get_accelerator().current_device())
        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())

        # Scatter buffers
        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                      dtype=self._config.input_dtype,
                                      device=get_accelerator().current_device())
        self._expert_cumsum = torch.empty((self._config.n_experts, ),
                                          dtype=torch.int64,
                                          device=get_accelerator().current_device())
        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
                                         dtype=torch.int32,
                                         device=get_accelerator().current_device())

        # GEMM Buffers
        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
                                         dtype=self._config.output_dtype,
                                         device=get_accelerator().current_device())
        if self._activation is not None:
            self._gated_intermediate = torch.empty(
                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
                dtype=self._config.output_dtype,
                device=get_accelerator().current_device())

        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                             dtype=self._config.output_dtype,
                                             device=get_accelerator().current_device())

        # Gather buffer
        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
                                   dtype=self._config.output_dtype,
                                   device=get_accelerator().current_device())

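    # Parameter transformations. The per-expert F.linear calls below consume the MLP
    # weights in their native [n_experts, out_features, in_features] layout, so only a
    # dtype cast is applied; the permute needed by the CUTLASS kernels stays commented out.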
    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Ensures the gate param matches the activation data type.
        """
        param = param.to(self._config.input_dtype)
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.
        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.
        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    @property
    def output(self) -> torch.Tensor:
        return self._output

    def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper,
              gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to isolate the logic for gating. This will take the hidden states
        and produce the metadata + tensors for the per-expert GEMMs. If the input has
        been padded for CG, this will strip the padding for MoE.
        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
            batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states.
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens).
        """

        # Get views on the buffers for gating
        logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1]))
        scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k))
        assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k))
        offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k))
        mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k))
        moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1]))

        self._gate_proj(logits, hidden_states, gate_w)
        self._expert_counts.zero_()
        self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata)
        self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts,
                          assignments, offsets)

        return moe_input, self._expert_cumsum, scores, mapped_slots

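    # Forward flow: project to router logits and pick top-k experts, scatter tokens into
    # expert-contiguous order, run each expert's MLP as a dense GEMM pair, then gather the
    # results back to the original token order.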
    def forward(self,
                hidden_states: torch.Tensor,
                batch_metadata: RaggedBatchWrapper,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        MoE forward pass built on a PyTorch loop over per-expert GEMMs.
        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
        """
        print("Using DSPytorchMultiGemmMoE forward pass")

        moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)

        # Get views on the buffers for GEMM
        intermediate = empty_from(self._intermediate,
                                  (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1]))
        output_unordered = empty_from(self._output_unordered,
                                      (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1]))
        output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1]))

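        # Per-expert GEMM loop: after the scatter, rows [cumsum[i-1], cumsum[i]) of
        # moe_input hold the tokens routed to expert i, so each expert's MLP runs as a
        # pair of dense torch.nn.functional.linear calls on its contiguous slice.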
        for expert_idx in range(mlp_1_w.shape[0]):
            min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
            max_bound = expert_cumsum[expert_idx]

            input_slice = moe_input[min_bound:max_bound]
            intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx],
                                                      mlp_1_b[expert_idx] if mlp_1_b is not None else None)

            intermediate = _activation_reference(intermediate, self._config.activation)
            output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx],
                                                      mlp_2_b[expert_idx] if mlp_2_b is not None else None)

            output_unordered[min_bound:max_bound] = output_slice

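        # The fused CUTLASS multi-GEMM path that the loop above replaces, left commented
        # out for reference.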
        # if self._activation is not None:
        #     gated_intermediate = empty_from(
        #         self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1]))
        #     self._mlp_1(
        #         gated_intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )
        #     self._activation(intermediate, gated_intermediate)
        # else:
        #     self._mlp_1(
        #         intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )

        # self._mlp_2(
        #     output_unordered,
        #     intermediate,
        #     mlp_2_w,
        #     expert_cumsum,
        #     mlp_2_b,
        # )

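        # The gather restores the original token order (via mapped_slots) and applies the
        # gate scores to each expert output before writing the final result.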
        self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
        return output