FastGen H100 MoE support: Add PyTorch multi-gemm MOE implementation #5586
Open
HeyangQin wants to merge 10 commits into master from HeyangQin/fastgen_moe_h100
Commits (changes from all 10 commits)

7b8ba2b  Add PyTorch multi-gemm MOE implementation  (HeyangQin)
8274a01  Merge branch 'master' into HeyangQin/fastgen_moe_h100  (HeyangQin)
88d758e  fix format  (HeyangQin)
022a7c6  remove unused code  (HeyangQin)
c7db180  use accelerator abstract  (HeyangQin)
642e105  fix format  (HeyangQin)
ce6b694  revert the accelerator commit  (HeyangQin)
416fae4  fix format  (HeyangQin)
1bd24de  Merge branch 'master' into HeyangQin/fastgen_moe_h100  (loadams)
d6d2adf  Merge branch 'master' into HeyangQin/fastgen_moe_h100  (loadams)
deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py (287 additions, 0 deletions)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional, Tuple

import torch

from deepspeed.accelerator import get_accelerator
from ....allocator import empty_from
from ....inference_utils import ActivationType, is_gated
from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
from ....kernels.ragged_ops import (
    MoEGather,
    MoEScatter,
    RaggedTopKGating,
)
from ....ragged import RaggedBatchWrapper

from ...interfaces import DSMoEBase, DSMoERegistry
from ...configs import DSMoEConfig
# from ....kernels.cutlass_ops import MoEGEMM
from ....inference_parameter import InferenceParameter


def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
    if is_gated(act_type):
        act_func_map = {
            ActivationType.ReGLU: torch.nn.functional.relu,
            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
            ActivationType.SiGLU: torch.nn.functional.silu,
        }

        act_act = out_states[..., ::2]
        act_linear = out_states[..., 1::2]

        act_act = act_func_map[act_type](act_act)
        return act_act * act_linear
    else:
        act_func_map = {
            ActivationType.RELU: torch.nn.functional.relu,
            ActivationType.GELU: torch.nn.functional.gelu,
            ActivationType.SILU: torch.nn.functional.silu,
            ActivationType.IDENTITY: lambda x: x,
        }

        return act_func_map[act_type](out_states)


@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
    """
    Pure PyTorch MoE implementation that performs the multi-GEMM as a per-expert GEMM loop
    (an alternative to the CUTLASS multi-GEMM path).
    """

    @staticmethod
    def name():
        return 'pytorch_multi_gemm_moe'

    @staticmethod
    def supports_config(config: DSMoEConfig) -> bool:
        if config.input_dtype != config.output_dtype:
            return False

        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
            return False

        if config.top_k != 1 and config.top_k != 2:
            return False

        return True

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

        # Convenience variables for frequently accessed items.
        self.max_tokens = self._config.max_tokens
        self.n_experts = self._config.n_experts
        self.n_top_k = self._config.top_k
        self.intermediate_dim = self._config.intermediate_features

        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation

        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)

        if is_gated(self._config.activation):
            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
                                                   self._config.activation)
        else:
            self._activation = None

        self._gate_proj = BlasLibLinear(self._config.input_dtype)
        self._top_1_gate = RaggedTopKGating(config.input_dtype)
        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)

        self._create_buffers()

    def _create_buffers(self):

        # Gating buffers
        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
                                   dtype=self._config.input_dtype,
                                   device=get_accelerator().current_device())
        self._expert_counts = torch.empty((self.n_experts, ),
                                          dtype=torch.int32,
                                          device=get_accelerator().current_device())
        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
                                   dtype=torch.float32,
                                   device=get_accelerator().current_device())
        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
                                        dtype=torch.int32,
                                        device=get_accelerator().current_device())
        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())

        # Scatter buffers
        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                      dtype=self._config.input_dtype,
                                      device=get_accelerator().current_device())
        self._expert_cumsum = torch.empty((self._config.n_experts, ),
                                          dtype=torch.int64,
                                          device=get_accelerator().current_device())
        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
                                         dtype=torch.int32,
                                         device=get_accelerator().current_device())

        # GEMM buffers
        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
                                         dtype=self._config.output_dtype,
                                         device=get_accelerator().current_device())
        if self._activation is not None:
            self._gated_intermediate = torch.empty(
                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
                dtype=self._config.output_dtype,
                device=get_accelerator().current_device())

        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                             dtype=self._config.output_dtype,
                                             device=get_accelerator().current_device())

        # Gather buffer
        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
                                   dtype=self._config.output_dtype,
                                   device=get_accelerator().current_device())

    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Ensures the gate param matches the activation data type.
        """
        param = param.to(self._config.input_dtype)
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    @property
    def output(self) -> torch.Tensor:
        return self._output

    def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper,
              gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to isolate the logic for gating. This will take the hidden states
        and produce the metadata + tensors for the ragged MoE GEMMs. If the input has
        been padded for CG, this will strip the padding for MoE.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
            batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states.
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the
            offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the
            original order of the tokens).
        """

        # Get views on the buffers for gating
        logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1]))
        scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k))
        assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k))
        offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k))
        mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k))
        moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1]))

        self._gate_proj(logits, hidden_states, gate_w)
        self._expert_counts.zero_()
        self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata)
        self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts,
                          assignments, offsets)

        return moe_input, self._expert_cumsum, scores, mapped_slots

    def forward(self,
                hidden_states: torch.Tensor,
                batch_metadata: RaggedBatchWrapper,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        MoE forward pass built on a per-expert PyTorch GEMM loop.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim].
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
        """

        moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)

        # Get views on the buffers for GEMM
        intermediate = empty_from(self._intermediate,
                                  (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1]))
        output_unordered = empty_from(self._output_unordered,
                                      (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1]))
        output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1]))

        # Run both MLP GEMMs per expert over its contiguous slice of the scattered tokens.
        for expert_idx in range(mlp_1_w.shape[0]):
            min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
            max_bound = expert_cumsum[expert_idx]

            input_slice = moe_input[min_bound:max_bound]
            intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx],
                                                      mlp_1_b[expert_idx] if mlp_1_b is not None else None)

            intermediate = _activation_reference(intermediate, self._config.activation)
            output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx],
                                                      mlp_2_b[expert_idx] if mlp_2_b is not None else None)

            output_unordered[min_bound:max_bound] = output_slice

        # if self._activation is not None:
        #     gated_intermediate = empty_from(
        #         self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1]))
        #     self._mlp_1(
        #         gated_intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )
        #     self._activation(intermediate, gated_intermediate)
        # else:
        #     self._mlp_1(
        #         intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )

        # self._mlp_2(
        #     output_unordered,
        #     intermediate,
        #     mlp_2_w,
        #     expert_cumsum,
        #     mlp_2_b,
        # )

        self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
        return output
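
For reviewers skimming the diff, the following is a minimal standalone sketch (not part of the file above; toy shapes, plain PyTorch) of the per-expert GEMM loop and the interleaved gated activation that forward() and _activation_reference() implement:

import torch

# Tokens are assumed to be pre-scattered so each expert owns a contiguous slice, and
# expert_cumsum[i] is the inclusive cumulative token count for experts 0..i.
n_experts, model_dim, intermediate_dim = 4, 8, 16
tokens_per_expert = torch.tensor([3, 0, 5, 2])
expert_cumsum = torch.cumsum(tokens_per_expert, dim=0)

moe_input = torch.randn(int(tokens_per_expert.sum()), model_dim)
mlp_1_w = torch.randn(n_experts, 2 * intermediate_dim, model_dim)  # gated activation: 2x columns
mlp_2_w = torch.randn(n_experts, model_dim, intermediate_dim)
output = torch.empty_like(moe_input)

for expert_idx in range(n_experts):
    lo = 0 if expert_idx == 0 else int(expert_cumsum[expert_idx - 1])
    hi = int(expert_cumsum[expert_idx])
    x = moe_input[lo:hi]
    h = torch.nn.functional.linear(x, mlp_1_w[expert_idx])
    # SiGLU as in _activation_reference(): even columns are activated, odd columns gate them.
    h = torch.nn.functional.silu(h[..., ::2]) * h[..., 1::2]
    output[lo:hi] = torch.nn.functional.linear(h, mlp_2_w[expert_idx])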
We need an extension to the accelerator interface to avoid cuda references.
Also, what is the expected behavior when running on non-cuda devices, where torch.cuda is unavailable?
Good point. How about we extend the accelerator interface and return -1 for non-cuda devices?
I think we need two accelerator API changes:
1. accelerator.name(), which can be used in this case to restrict to cuda.
2. accelerator.compute_capability(), which returns a tuple of int (major, minor) versions of the current accelerator (similar to the cuda approach) that the client can use for control flow (sketched below).
While the first API is very straightforward, the second seems a bit tricky if we want the accelerator to freely manage versioning.
@delock, @nelyahu, @hipudding, I would appreciate your thoughts on whether the proposed capability API provides sufficient freedom for your cases. Thanks!
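
A rough sketch of what the two proposed methods might look like on the CUDA side (hypothetical, not part of this PR's diff; the (-1, -1) fallback for non-cuda devices follows the suggestion above):

import torch

def accelerator_name() -> str:
    # Proposed accelerator.name(): lets callers restrict a code path to cuda.
    return "cuda"

def compute_capability() -> tuple:
    # Proposed accelerator.compute_capability(): (major, minor) of the current device,
    # mirroring torch.cuda.get_device_capability(); a non-cuda accelerator could return (-1, -1).
    if not torch.cuda.is_available():
        return (-1, -1)
    return torch.cuda.get_device_capability()

# Example control flow a caller might use to pick the H100 MoE path:
use_pytorch_multi_gemm = accelerator_name() == "cuda" and compute_capability()[0] >= 9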
@tjruwase I am not familiar with the implementation of the ConfigBundle flow, so I can't help with finding an alternative. Also, CUDAGatedActivation is CUDA-specific, so it seems like the existing structure of the code is problematic in terms of accelerator generalization.
@tjruwase Yes, when device_index is set to None in the device_name function, it is possible to obtain the device name. For Ascend and CANN (npu_accelerator), compute_capability is not yet needed, but I am not sure whether there will be such a requirement in future versions. However, if this proposal is ready to be implemented, I would be happy to cooperate in modifying the npu_accelerator part.
accelerator.compute_capability() seems accelerator-specific, so even if the caller gets this tuple, it still needs to know which accelerator it is to make a proper decision. From the context, the caller needs to decide whether cutlass_multi_gemm_moe or pytorch_multi_gemm_moe should be used. Thus, the following interface in the accelerator might help this situation and be extendable in the future: accelerator.get_property("multi_gemm_moe"). This returns either None (the accelerator didn't define this property), "cutlass_multi_gemm_moe" (the accelerator is CUDA with compute capability < 9), or anything else an accelerator defines and prefers to use. Then the caller could use this property in context, as in the sketch below.
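
A hedged sketch of that caller-side control flow (get_property() is only the interface proposed here, not an existing accelerator API; the module names are the registered name() strings discussed above):

from deepspeed.accelerator import get_accelerator

accel = get_accelerator()
# Hypothetical: get_property() does not exist yet, so guard the lookup.
preferred = accel.get_property("multi_gemm_moe") if hasattr(accel, "get_property") else None

# None (property undefined) falls back to the portable PyTorch implementation added in this PR;
# an accelerator that prefers the CUTLASS path returns "cutlass_multi_gemm_moe" instead.
moe_impl_name = preferred or "pytorch_multi_gemm_moe"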
And in the CUDA accelerator we could have something like the sketch that follows. Other accelerators only need to return None for unrecognized properties, so the future maintenance cost could be small.
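
A matching sketch of the CUDA accelerator side (again hypothetical; the compute-capability threshold follows the comment above):

import torch

class CudaAcceleratorGetPropertySketch:
    # Hypothetical: shows only the proposed get_property() addition, not the real accelerator class.

    def get_property(self, name: str):
        if name == "multi_gemm_moe":
            major, _minor = torch.cuda.get_device_capability()
            # Pre-Hopper devices (compute capability < 9) prefer the CUTLASS multi-GEMM kernels;
            # on H100 (SM90) this returns None and the caller falls back to the PyTorch path.
            if major < 9:
                return "cutlass_multi_gemm_moe"
        # Unrecognized properties return None, so other accelerators need no changes.
        return None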