
FastGen H100 MoE support: Add PyTorch multi-gemm MOE implementation #5586

Open · wants to merge 10 commits into base: master
15 changes: 11 additions & 4 deletions deepspeed/inference/v2/modules/heuristics.py
@@ -32,6 +32,8 @@
     DSUnembedRegistry,
 )
 
+import torch
+
 
 def instantiate_attention(attention_config: DSSelfAttentionConfig,
                           engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase:
@@ -129,10 +131,15 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin
"weight_dtype": moe_config.input_dtype,
}

# Currently, we only have one implementation, so we just return it.
config = ConfigBundle(name="cutlass_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
# check if we are on H100 or above
if torch.cuda.get_device_capability(0)[0] >= 9: #ignore-cuda
Contributor

We need an extension to the accelerator interface to avoid CUDA references.

Contributor

Also, what is the expected behavior when running on non-CUDA devices, where torch.cuda is unavailable?

Contributor Author

Good point. How about we extend the accelerator interface and return -1 for non-CUDA devices?

Contributor

I think we need two accelerator API changes:

  1. accelerator.name(), which can be used in this case to restrict the check to CUDA.
  2. accelerator.compute_capability(), which returns a tuple of int (major, minor) versions of the current accelerator (similar to the CUDA approach) that the client can use for control flow.

While the first API is very straightforward, the second seems a bit tricky if we want each accelerator to freely manage its own versioning.

@delock, @nelyahu, @hipudding, I would appreciate your thoughts on whether the proposed capability API provides sufficient freedom for your cases. Thanks!
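
To make the proposal concrete, here is a minimal sketch of what the two methods might look like (hypothetical class and method names; only torch.cuda.get_device_capability is an existing API call, everything else is illustrative):

import torch


class ExampleCudaAccelerator:

    def name(self) -> str:
        # Proposal 1: identify the accelerator family so callers can restrict
        # device-specific heuristics to CUDA.
        return 'cuda'

    def compute_capability(self, device_index: int = 0) -> tuple:
        # Proposal 2: expose (major, minor) of the current device, mirroring
        # torch.cuda.get_device_capability for CUDA devices.
        return torch.cuda.get_device_capability(device_index)


class ExampleNpuAccelerator:

    def name(self) -> str:
        return 'npu'

    def compute_capability(self, device_index: int = 0) -> tuple:
        # A non-CUDA backend could return a sentinel such as (-1, -1), as
        # suggested above, when versioning does not map to major/minor.
        return (-1, -1)

With such an interface, the heuristic could test name() == 'cuda' and compute_capability()[0] >= 9 instead of referencing torch.cuda directly.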

Contributor

@tjruwase

  1. The existing device_name API, with the "device_index" input set to None, already provides this functionality.
  2. This API is very specific to devices whose compute capabilities are encoded as major.minor. IMO such switching between optimizations should be configured by the user and not set automatically according to device type; this is inconsistent even across GPUs. Switching from H100 to A100, for example, to debug an accuracy issue on a different setup would give unexpectedly different behavior.
    I am not familiar with the implementation of the ConfigBundle flow, so I can't help with finding an alternative.
    Also, CUDAGatedActivation is CUDA-specific, so it seems the existing structure of the code is problematic in terms of accelerator generalization.
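
As a rough illustration of point 1, the heuristic could guard the capability probe with the existing device_name accessor so that non-CUDA backends never touch torch.cuda (a sketch only; the surrounding selection logic is assumed, not taken from this PR):

import torch

from deepspeed.accelerator import get_accelerator

# Default to the pure-PyTorch implementation; only consult torch.cuda when the
# active accelerator reports itself as CUDA.
moe_impl_name = "pytorch_multi_gemm_moe"
if get_accelerator().device_name() == "cuda":
    if torch.cuda.get_device_capability(0)[0] < 9:
        moe_impl_name = "cutlass_multi_gemm_moe"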

Contributor

@tjruwase

Yes, when device_index is set to None in the device_name function, it is possible to obtain the device name. For Ascend and CANN (npu_accelerator), compute_capability is not yet needed, but I am not sure whether there will be such a requirement in future versions. However, if this proposal is ready to be implemented, I would be happy to help modify the npu_accelerator part.

Collaborator
@delock delock Jun 4, 2024

accelerator.compute_capability() seems accelerator-specific, so even if the caller gets this tuple, it still needs to know which accelerator it is in order to make a proper decision.

From the context, the caller needs to decide whether cutlass_multi_gemm_moe or pytorch_multi_gemm_moe should be used. Thus the following accelerator interface might help this situation and be extendable in the future:

accelerator.get_property("multi_gemm_moe"): this returns either None (the accelerator didn't define this property), "cutlass_multi_gemm_moe" (the accelerator is CUDA with compute capability < 9), or anything else an accelerator defines and prefers to use. Then the caller could use this property as follows:

name = get_accelerator().get_property("multi_gemm_moe")
if name is None:
    name = "pytorch_multi_gemm_moe"  # default behavior if accelerator didn't define this property
config = ConfigBundle(name=name,
                      config=moe_config,
                      implementation_config=implementation_config)

And in the CUDA accelerator we could have something like this:

def get_property(query):
    ...
    if query == "multi_gemm_moe":
        if torch.cuda.get_device_capability(0)[0] >= 9:
            return None  # or "pytorch_multi_gemm_moe"
        else:
            return "cutlass_multi_gemm_moe"
    ...
    return None

Other accelerators only need to return None for unrecognized properties, so the future maintenance cost should be small.

def get_property(query):
    ...
    return None   # query "multi_gemm_moe" returns from here

config = ConfigBundle(name="pytorch_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
else:
config = ConfigBundle(name="cutlass_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
return DSMoERegistry.instantiate_config(config)


@@ -4,3 +4,4 @@
# DeepSpeed Team

from .cutlass_multi_gemm import DSMultiGemmMoE
from .pytorch_multi_gemm import DSPytorchMultiGemmMoE
@@ -0,0 +1,287 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional, Tuple

import torch

from deepspeed.accelerator import get_accelerator
from ....allocator import empty_from
from ....inference_utils import ActivationType, is_gated
from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
from ....kernels.ragged_ops import (
    MoEGather,
    MoEScatter,
    RaggedTopKGating,
)
from ....ragged import RaggedBatchWrapper

from ...interfaces import DSMoEBase, DSMoERegistry
from ...configs import DSMoEConfig
# from ....kernels.cutlass_ops import MoEGEMM
from ....inference_parameter import InferenceParameter


def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
    if is_gated(act_type):
        act_func_map = {
            ActivationType.ReGLU: torch.nn.functional.relu,
            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
            ActivationType.SiGLU: torch.nn.functional.silu,
        }

        # Gated activations interleave the two halves along the last dimension:
        # even columns pass through the activation, odd columns scale the result.
        act_act = out_states[..., ::2]
        act_linear = out_states[..., 1::2]

        act_act = act_func_map[act_type](act_act)
        return act_act * act_linear
    else:
        act_func_map = {
            ActivationType.RELU: torch.nn.functional.relu,
            ActivationType.GELU: torch.nn.functional.gelu,
            ActivationType.SILU: torch.nn.functional.silu,
            ActivationType.IDENTITY: lambda x: x,
        }

        return act_func_map[act_type](out_states)


@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
    """
    Pure-PyTorch MoE implementation that replaces the CUTLASS multi-GEMM kernels
    with a per-expert loop of torch.nn.functional.linear calls.
    """

    @staticmethod
    def name():
        return 'pytorch_multi_gemm_moe'

    @staticmethod
    def supports_config(config: DSMoEConfig) -> bool:
        if config.input_dtype != config.output_dtype:
            return False

        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
            return False

        if config.top_k != 1 and config.top_k != 2:
            return False

        return True

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

        # Convenience variables for frequently accessed items.
        self.max_tokens = self._config.max_tokens
        self.n_experts = self._config.n_experts
        self.n_top_k = self._config.top_k
        self.intermediate_dim = self._config.intermediate_features

        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation

        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)

        if is_gated(self._config.activation):
            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
                                                   self._config.activation)
        else:
            self._activation = None

        self._gate_proj = BlasLibLinear(self._config.input_dtype)
        self._top_1_gate = RaggedTopKGating(config.input_dtype)
        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)

        self._create_buffers()

    def _create_buffers(self):

        # Gating buffers
        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
                                   dtype=self._config.input_dtype,
                                   device=get_accelerator().current_device())
        self._expert_counts = torch.empty((self.n_experts, ),
                                          dtype=torch.int32,
                                          device=get_accelerator().current_device())
        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
                                   dtype=torch.float32,
                                   device=get_accelerator().current_device())
        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
                                        dtype=torch.int32,
                                        device=get_accelerator().current_device())
        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())

        # Scatter buffers
        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                      dtype=self._config.input_dtype,
                                      device=get_accelerator().current_device())
        self._expert_cumsum = torch.empty((self._config.n_experts, ),
                                          dtype=torch.int64,
                                          device=get_accelerator().current_device())
        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
                                         dtype=torch.int32,
                                         device=get_accelerator().current_device())

        # GEMM buffers
        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
                                         dtype=self._config.output_dtype,
                                         device=get_accelerator().current_device())
        if self._activation is not None:
            self._gated_intermediate = torch.empty(
                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
                dtype=self._config.output_dtype,
                device=get_accelerator().current_device())

        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                             dtype=self._config.output_dtype,
                                             device=get_accelerator().current_device())

        # Gather buffer
        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
                                   dtype=self._config.output_dtype,
                                   device=get_accelerator().current_device())

    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Ensures the gate param matches the activation data type.
        """
        param = param.to(self._config.input_dtype)
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    @property
    def output(self) -> torch.Tensor:
        return self._output

    def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper,
              gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to isolate the logic for gating. This will take the hidden states
        and produce the metadata + tensors for the ragged expert GEMMs. If the input has
        been padded for CG, this will strip the padding for MoE.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
            batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states.
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of
            the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover
            the original order of the tokens).
        """

        # Get views on the buffers for gating
        logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1]))
        scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k))
        assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k))
        offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k))
        mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k))
        moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1]))

        self._gate_proj(logits, hidden_states, gate_w)
        self._expert_counts.zero_()
        self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata)
        self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts,
                          assignments, offsets)

        return moe_input, self._expert_cumsum, scores, mapped_slots

    def forward(self,
                hidden_states: torch.Tensor,
                batch_metadata: RaggedBatchWrapper,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        MoE forward pass built on a per-expert loop of PyTorch GEMMs rather than the
        CUTLASS multi-GEMM kernels.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim].
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
        """

        moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)

        # Get views on the buffers for GEMM
        intermediate = empty_from(self._intermediate,
                                  (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1]))
        output_unordered = empty_from(self._output_unordered,
                                      (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1]))
        output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1]))

        # The scattered tokens are grouped by expert, so each expert's tokens occupy the
        # contiguous slice [expert_cumsum[i - 1], expert_cumsum[i]) of moe_input.
        for expert_idx in range(mlp_1_w.shape[0]):
            min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
            max_bound = expert_cumsum[expert_idx]

            input_slice = moe_input[min_bound:max_bound]
            intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx],
                                                      mlp_1_b[expert_idx] if mlp_1_b is not None else None)

            intermediate = _activation_reference(intermediate, self._config.activation)
            output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx],
                                                      mlp_2_b[expert_idx] if mlp_2_b is not None else None)

            output_unordered[min_bound:max_bound] = output_slice

        # if self._activation is not None:
        #     gated_intermediate = empty_from(
        #         self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1]))
        #     self._mlp_1(
        #         gated_intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )
        #     self._activation(intermediate, gated_intermediate)
        # else:
        #     self._mlp_1(
        #         intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )

        # self._mlp_2(
        #     output_unordered,
        #     intermediate,
        #     mlp_2_w,
        #     expert_cumsum,
        #     mlp_2_b,
        # )

        self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
        return output