Add PyTorch multi-gemm MOE implementation
HeyangQin committed May 29, 2024
1 parent 988372b commit 7b8ba2b
Showing 3 changed files with 318 additions and 4 deletions.
14 changes: 10 additions & 4 deletions deepspeed/inference/v2/modules/heuristics.py
@@ -32,6 +32,7 @@
    DSUnembedRegistry,
)

import torch

def instantiate_attention(attention_config: DSSelfAttentionConfig,
                          engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase:
@@ -129,10 +130,15 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin
"weight_dtype": moe_config.input_dtype,
}

# Currently, we only have one implementation, so we just return it.
config = ConfigBundle(name="cutlass_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
# check if we are on H100 or above
if torch.cuda.get_device_capability(0)[0] >= 9:
config = ConfigBundle(name="pytorch_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
else:
config = ConfigBundle(name="cutlass_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
return DSMoERegistry.instantiate_config(config)


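As an aside, the new branch selects the MoE implementation purely from the CUDA compute capability of device 0. A minimal standalone sketch of the same selection logic (the function name and the is_available() guard are illustrative additions, not part of the diff):

import torch

def select_moe_implementation() -> str:
    # SM 9.0+ (H100 and newer) takes the PyTorch multi-GEMM path;
    # older GPUs keep the CUTLASS multi-GEMM path.
    if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 9:
        return "pytorch_multi_gemm_moe"
    return "cutlass_multi_gemm_moe"
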
1 change: 1 addition & 0 deletions deepspeed/inference/v2/modules/implementations/moe/__init__.py
@@ -4,3 +4,4 @@
# DeepSpeed Team

from .cutlass_multi_gemm import DSMultiGemmMoE
from .pytorch_multi_gemm import DSPytorchMultiGemmMoE
307 changes: 307 additions & 0 deletions deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py
@@ -0,0 +1,307 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional, Tuple

import torch

from deepspeed.accelerator import get_accelerator
from ....allocator import empty_from
from ....inference_utils import ActivationType, is_gated
from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
from ....kernels.ragged_ops import (
    MoEGather,
    MoEScatter,
    RaggedTopKGating,
)
from ....ragged import RaggedBatchWrapper

from ...interfaces import DSMoEBase, DSMoERegistry
from ...configs import DSMoEConfig
# from ....kernels.cutlass_ops import MoEGEMM
from ....inference_parameter import InferenceParameter

def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
    if is_gated(act_type):
        act_func_map = {
            ActivationType.ReGLU: torch.nn.functional.relu,
            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
            ActivationType.SiGLU: torch.nn.functional.silu,
        }

        act_act = out_states[..., ::2]
        act_linear = out_states[..., 1::2]

        act_act = act_func_map[act_type](act_act)
        return act_act * act_linear
    else:
        act_func_map = {
            ActivationType.RELU: torch.nn.functional.relu,
            ActivationType.GELU: torch.nn.functional.gelu,
            ActivationType.SILU: torch.nn.functional.silu,
            ActivationType.IDENTITY: lambda x: x,
        }

        return act_func_map[act_type](out_states)

def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reference gating code.
    """
    logits = logits.float()
    probs = torch.nn.functional.softmax(logits, dim=1)

    indices1_s = torch.argmax(probs, dim=-1)
    mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1])
    indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1
    indices1_s = torch.min(indices1_s, indices_mask)

    gates1_s = (probs * mask1).sum(dim=1)

    sorted_indices = indices1_s.sort()[1]
    original_indices = sorted_indices.sort()[1]

    exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long()
    exp_count_cumsum = exp_count.cumsum(dim=0)

    return sorted_indices, original_indices, exp_count_cumsum, gates1_s

@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
"""
MoE implementation based on the CUTLASS multi-GEMM.
"""

@staticmethod
def name():
return 'pytorch_multi_gemm_moe'

@staticmethod
def supports_config(config: DSMoEConfig) -> bool:
if config.input_dtype != config.output_dtype:
return False

if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
return False

if config.top_k != 1 and config.top_k != 2:
return False

return True

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

        # Convenience variables for frequently accessed items.
        self.max_tokens = self._config.max_tokens
        self.n_experts = self._config.n_experts
        self.n_top_k = self._config.top_k
        self.intermediate_dim = self._config.intermediate_features

        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation

        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)

        if is_gated(self._config.activation):
            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
                                                   self._config.activation)
        else:
            self._activation = None

        self._gate_proj = BlasLibLinear(self._config.input_dtype)
        self._top_1_gate = RaggedTopKGating(config.input_dtype)
        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)

        self._create_buffers()

    def _create_buffers(self):

        # Gating buffers
        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
                                   dtype=self._config.input_dtype,
                                   device=get_accelerator().current_device())
        self._expert_counts = torch.empty((self.n_experts, ),
                                          dtype=torch.int32,
                                          device=get_accelerator().current_device())
        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
                                   dtype=torch.float32,
                                   device=get_accelerator().current_device())
        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
                                        dtype=torch.int32,
                                        device=get_accelerator().current_device())
        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())

        # Scatter buffers
        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                      dtype=self._config.input_dtype,
                                      device=get_accelerator().current_device())
        self._expert_cumsum = torch.empty((self._config.n_experts, ),
                                          dtype=torch.int64,
                                          device=get_accelerator().current_device())
        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
                                         dtype=torch.int32,
                                         device=get_accelerator().current_device())

        # GEMM Buffers
        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
                                         dtype=self._config.output_dtype,
                                         device=get_accelerator().current_device())
        if self._activation is not None:
            self._gated_intermediate = torch.empty(
                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
                dtype=self._config.output_dtype,
                device=get_accelerator().current_device())

        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
                                             dtype=self._config.output_dtype,
                                             device=get_accelerator().current_device())

        # Gather buffer
        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
                                   dtype=self._config.output_dtype,
                                   device=get_accelerator().current_device())

    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Ensures the gate param matches the activation data type.
        """
        param = param.to(self._config.input_dtype)
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Converts the param to the same data type as the input and output.

        Parameters:
            param (torch.Tensor): Weight or bias tensor.
        """
        param = param.to(self._config.input_dtype)

        # if len(param.shape) == 3:
        #     param = param.permute(0, 2, 1).contiguous()
        return InferenceParameter.initialize(param)

    @property
    def output(self) -> torch.Tensor:
        return self._output

    def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper,
              gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to isolate the logic for gating. This will take the hidden states
        and produce the metadata + tensors for the ragged expert GEMMs. If the input has
        been padded for CG, this will strip the padding for MoE.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim].
            batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states.
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens).
        """

        # Get views on the buffers for gating
        logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1]))
        scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k))
        assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k))
        offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k))
        mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k))
        moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1]))

        self._gate_proj(logits, hidden_states, gate_w)
        self._expert_counts.zero_()
        self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata)
        self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts,
                          assignments, offsets)

        return moe_input, self._expert_cumsum, scores, mapped_slots

    def forward(self,
                hidden_states: torch.Tensor,
                batch_metadata: RaggedBatchWrapper,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        MoE forward pass built on per-expert PyTorch GEMMs.

        Parameters:
            hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim].
            gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim].
        """
        print("Using DSPytorchMultiGemmMoE forward pass")

        moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)

        # Get views on the buffers for GEMM
        intermediate = empty_from(self._intermediate,
                                  (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1]))
        output_unordered = empty_from(self._output_unordered,
                                      (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1]))
        output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1]))

        # Run the two expert MLP GEMMs one expert at a time. moe_input is sorted by expert,
        # so each expert owns the contiguous slice of rows between the previous and current
        # entries of expert_cumsum.
        for expert_idx in range(mlp_1_w.shape[0]):
            min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
            max_bound = expert_cumsum[expert_idx]

            input_slice = moe_input[min_bound:max_bound]
            intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx] if mlp_1_b is not None else None)

            intermediate = _activation_reference(intermediate, self._config.activation)
            output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx] if mlp_2_b is not None else None)

            output_unordered[min_bound:max_bound] = output_slice

        # if self._activation is not None:
        #     gated_intermediate = empty_from(
        #         self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1]))
        #     self._mlp_1(
        #         gated_intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )
        #     self._activation(intermediate, gated_intermediate)
        # else:
        #     self._mlp_1(
        #         intermediate,
        #         moe_input,
        #         mlp_1_w,
        #         expert_cumsum,
        #         mlp_1_b,
        #     )

        # self._mlp_2(
        #     output_unordered,
        #     intermediate,
        #     mlp_2_w,
        #     expert_cumsum,
        #     mlp_2_b,
        # )

        self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
        return output
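
For readers unfamiliar with the ragged layout, the forward loop above amounts to a grouped GEMM expressed as one torch.nn.functional.linear call per expert over a token buffer sorted by expert assignment. A self-contained sketch of that pattern with synthetic data (all names, shapes, and the ReLU activation here are illustrative, not taken from the diff):

import torch

n_experts, model_dim, hidden_dim, n_tokens = 4, 8, 16, 10
mlp_1_w = torch.randn(n_experts, hidden_dim, model_dim)  # per-expert [out, in] weights for F.linear
mlp_2_w = torch.randn(n_experts, model_dim, hidden_dim)
tokens = torch.randn(n_tokens, model_dim)

# Sort tokens by expert and build per-expert cumulative counts, mirroring
# what RaggedTopKGating + MoEScatter produce for the real module.
assignments = torch.randint(0, n_experts, (n_tokens, ))
order = assignments.argsort()
sorted_tokens = tokens[order]
expert_cumsum = torch.bincount(assignments, minlength=n_experts).cumsum(dim=0)

output = torch.empty_like(sorted_tokens)
for e in range(n_experts):
    lo = 0 if e == 0 else expert_cumsum[e - 1]
    hi = expert_cumsum[e]
    h = torch.nn.functional.relu(torch.nn.functional.linear(sorted_tokens[lo:hi], mlp_1_w[e]))
    output[lo:hi] = torch.nn.functional.linear(h, mlp_2_w[e])

# Un-sort to recover the original token order (the MoEGather kernel also applies the gate scores).
output = output[order.argsort()]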
