From 7b8ba2b26f4b7ae57c28c00b39328762bcfad59d Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 18:43:18 +0000 Subject: [PATCH 1/7] Add PyTorch multi-gemm MOE implementation --- deepspeed/inference/v2/modules/heuristics.py | 14 +- .../modules/implementations/moe/__init__.py | 1 + .../implementations/moe/pytorch_multi_gemm.py | 307 ++++++++++++++++++ 3 files changed, 318 insertions(+), 4 deletions(-) create mode 100644 deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index f719e299a4b2..38618a4b6018 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -32,6 +32,7 @@ DSUnembedRegistry, ) +import torch def instantiate_attention(attention_config: DSSelfAttentionConfig, engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: @@ -129,10 +130,15 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin "weight_dtype": moe_config.input_dtype, } - # Currently, we only have one implementation, so we just return it. - config = ConfigBundle(name="cutlass_multi_gemm_moe", - config=moe_config, - implementation_config=implementation_config) + # check if we are on H100 or above + if torch.cuda.get_device_capability(0)[0] >= 9: + config = ConfigBundle(name="pytorch_multi_gemm_moe", + config=moe_config, + implementation_config=implementation_config) + else: + config = ConfigBundle(name="cutlass_multi_gemm_moe", + config=moe_config, + implementation_config=implementation_config) return DSMoERegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/moe/__init__.py b/deepspeed/inference/v2/modules/implementations/moe/__init__.py index 053ad5da7746..57b08713be5a 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/__init__.py +++ b/deepspeed/inference/v2/modules/implementations/moe/__init__.py @@ -4,3 +4,4 @@ # DeepSpeed Team from .cutlass_multi_gemm import DSMultiGemmMoE +from .pytorch_multi_gemm import DSPytorchMultiGemmMoE diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py new file mode 100644 index 000000000000..40058d3ffb92 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from deepspeed.accelerator import get_accelerator
+from ....allocator import empty_from
+from ....inference_utils import ActivationType, is_gated
+from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation
+from ....kernels.ragged_ops import (
+    MoEGather,
+    MoEScatter,
+    RaggedTopKGating,
+)
+from ....ragged import RaggedBatchWrapper
+
+from ...interfaces import DSMoEBase, DSMoERegistry
+from ...configs import DSMoEConfig
+# from ....kernels.cutlass_ops import MoEGEMM
+from ....inference_parameter import InferenceParameter
+
+def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
+    if is_gated(act_type):
+        act_func_map = {
+            ActivationType.ReGLU: torch.nn.functional.relu,
+            ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+            ActivationType.SiGLU: torch.nn.functional.silu,
+        }
+
+        act_act = out_states[..., ::2]
+        act_linear = out_states[..., 1::2]
+
+        act_act = act_func_map[act_type](act_act)
+        return act_act * act_linear
+    else:
+        act_func_map = {
+            ActivationType.RELU: torch.nn.functional.relu,
+            ActivationType.GELU: torch.nn.functional.gelu,
+            ActivationType.SILU: torch.nn.functional.silu,
+            ActivationType.IDENTITY: lambda x: x,
+        }
+
+        return act_func_map[act_type](out_states)
+
+def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Reference gating code.
+    """
+    logits = logits.float()
+    probs = torch.nn.functional.softmax(logits, dim=1)
+
+    indices1_s = torch.argmax(probs, dim=-1)
+    mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1])
+    indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1
+    indices1_s = torch.min(indices1_s, indices_mask)
+
+    gates1_s = (probs * mask1).sum(dim=1)
+
+    sorted_indices = indices1_s.sort()[1]
+    original_indices = sorted_indices.sort()[1]
+
+    exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long()
+    exp_count_cumsum = exp_count.cumsum(dim=0)
+
+    return sorted_indices, original_indices, exp_count_cumsum, gates1_s
+
+@DSMoERegistry.register_module
+class DSPytorchMultiGemmMoE(DSMoEBase):
+    """
+    Pure PyTorch MoE implementation built on a per-expert multi-GEMM loop.
+    """
+
+    @staticmethod
+    def name():
+        return 'pytorch_multi_gemm_moe'
+
+    @staticmethod
+    def supports_config(config: DSMoEConfig) -> bool:
+        if config.input_dtype != config.output_dtype:
+            return False
+
+        if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
+            return False
+
+        if config.top_k != 1 and config.top_k != 2:
+            return False
+
+        return True
+
+    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
+        super().__init__(config, implementation_config)
+
+        # Convenience variables for frequently accessed items.
+        self.max_tokens = self._config.max_tokens
+        self.n_experts = self._config.n_experts
+        self.n_top_k = self._config.top_k
+        self.intermediate_dim = self._config.intermediate_features
+
+        moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation
+
+        # self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn)
+        # self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY)
+
+        if is_gated(self._config.activation):
+            self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype,
+                                                   self._config.activation)
+        else:
+            self._activation = None
+
+        self._gate_proj = BlasLibLinear(self._config.input_dtype)
+        self._top_1_gate = RaggedTopKGating(config.input_dtype)
+        self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim)
+        self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores)
+
+        self._create_buffers()
+
+    def _create_buffers(self):
+
+        # Gating buffers
+        self._logits = torch.empty((self._config.max_tokens, self.n_experts),
+                                   dtype=self._config.input_dtype,
+                                   device=get_accelerator().current_device())
+        self._expert_counts = torch.empty((self.n_experts, ),
+                                          dtype=torch.int32,
+                                          device=get_accelerator().current_device())
+        self._scores = torch.empty((self._config.max_tokens, self.n_top_k),
+                                   dtype=torch.float32,
+                                   device=get_accelerator().current_device())
+        self._assignments = torch.empty((self._config.max_tokens, self.n_top_k),
+                                        dtype=torch.int32,
+                                        device=get_accelerator().current_device())
+        self._offsets = torch.empty((self._config.max_tokens, self.n_top_k),
+                                    dtype=torch.int32,
+                                    device=get_accelerator().current_device())
+
+        # Scatter buffers
+        self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
+                                      dtype=self._config.input_dtype,
+                                      device=get_accelerator().current_device())
+        self._expert_cumsum = torch.empty((self._config.n_experts, ),
+                                          dtype=torch.int64,
+                                          device=get_accelerator().current_device())
+        self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k),
+                                         dtype=torch.int32,
+                                         device=get_accelerator().current_device())
+
+        # GEMM Buffers
+        self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features),
+                                         dtype=self._config.output_dtype,
+                                         device=get_accelerator().current_device())
+        if self._activation is not None:
+            self._gated_intermediate = torch.empty(
+                (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2),
+                dtype=self._config.output_dtype,
+                device=get_accelerator().current_device())
+
+        self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim),
+                                             dtype=self._config.output_dtype,
+                                             device=get_accelerator().current_device())
+
+        # Gather buffer
+        self._output = torch.empty((self._config.max_tokens, self._config.model_dim),
+                                   dtype=self._config.output_dtype,
+                                   device=get_accelerator().current_device())
+
+    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
+        """
+        Ensures the gate param matches the activation data type.
+        """
+        param = param.to(self._config.input_dtype)
+        return InferenceParameter.initialize(param)
+
+    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
+        """
+        Converts the param to the same data type as the input and output.
+
+        Parameters:
+            param (torch.Tensor): Weight or bias tensor.
+ """ + param = param.to(self._config.input_dtype) + + # if len(param.shape) == 3: + # param = param.permute(0, 2, 1).contiguous() + return InferenceParameter.initialize(param) + + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.input_dtype) + + # if len(param.shape) == 3: + # param = param.permute(0, 2, 1).contiguous() + return InferenceParameter.initialize(param) + + @property + def output(self) -> torch.Tensor: + return self._output + + def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helper function to isolate the logit for gating. This will take the hidden states + and produce the metadata + tensors for the CUTLASS ragged GEMMs. If the input has + been padded for CG, this will strip the padding for MoE. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim]. + batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens) + """ + + # Get views on the buffers for gating + logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k)) + assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k)) + offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k)) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k)) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1])) + + self._gate_proj(logits, hidden_states, gate_w) + self._expert_counts.zero_() + self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata) + self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts, + assignments, offsets) + + return moe_input, self._expert_cumsum, scores, mapped_slots + + def forward(self, + hidden_states: torch.Tensor, + batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor, + mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, + mlp_1_b: Optional[torch.Tensor] = None, + mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + MoE forward pass built on top of CUTLASS multi-GEMM. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim]. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. 
+ """ + print("Using DSPytorchMultiGemmMoE forward pass") + + moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) + + # Get views on the buffers for GEMM + intermediate = empty_from(self._intermediate, + (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1])) + output_unordered = empty_from(self._output_unordered, + (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1])) + output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) + + for expert_idx in range(mlp_1_w.shape[0]): + min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1] + max_bound = expert_cumsum[expert_idx] + + input_slice = moe_input[min_bound:max_bound] + intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx] if mlp_1_b is not None else None) + + intermediate = _activation_reference(intermediate, self._config.activation) + output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx] if mlp_2_b is not None else None) + + output_unordered[min_bound:max_bound] = output_slice + + # if self._activation is not None: + # gated_intermediate = empty_from( + # self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1])) + # self._mlp_1( + # gated_intermediate, + # moe_input, + # mlp_1_w, + # expert_cumsum, + # mlp_1_b, + # ) + # self._activation(intermediate, gated_intermediate) + # else: + # self._mlp_1( + # intermediate, + # moe_input, + # mlp_1_w, + # expert_cumsum, + # mlp_1_b, + # ) + + # self._mlp_2( + # output_unordered, + # intermediate, + # mlp_2_w, + # expert_cumsum, + # mlp_2_b, + # ) + + self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts) + return output + From 88d758e3a95f87c0c08d7987e070c286c20729af Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 18:49:20 +0000 Subject: [PATCH 2/7] fix format --- deepspeed/inference/v2/modules/heuristics.py | 9 +++++---- .../modules/implementations/moe/pytorch_multi_gemm.py | 10 +++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index 38618a4b6018..090739f8103c 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -34,6 +34,7 @@ import torch + def instantiate_attention(attention_config: DSSelfAttentionConfig, engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: """ @@ -133,12 +134,12 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin # check if we are on H100 or above if torch.cuda.get_device_capability(0)[0] >= 9: config = ConfigBundle(name="pytorch_multi_gemm_moe", - config=moe_config, - implementation_config=implementation_config) + config=moe_config, + implementation_config=implementation_config) else: config = ConfigBundle(name="cutlass_multi_gemm_moe", - config=moe_config, - implementation_config=implementation_config) + config=moe_config, + implementation_config=implementation_config) return DSMoERegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py index 40058d3ffb92..5d85d03dba0a 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py +++ 
b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py @@ -23,6 +23,7 @@ # from ....kernels.cutlass_ops import MoEGEMM from ....inference_parameter import InferenceParameter + def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor: if is_gated(act_type): act_func_map = { @@ -46,6 +47,7 @@ def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> return act_func_map[act_type](out_states) + def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Reference gating code. @@ -68,6 +70,7 @@ def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] return sorted_indices, original_indices, exp_count_cumsum, gates1_s + @DSMoERegistry.register_module class DSPytorchMultiGemmMoE(DSMoEBase): """ @@ -267,10 +270,12 @@ def forward(self, max_bound = expert_cumsum[expert_idx] input_slice = moe_input[min_bound:max_bound] - intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx] if mlp_1_b is not None else None) + intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], + mlp_1_b[expert_idx] if mlp_1_b is not None else None) intermediate = _activation_reference(intermediate, self._config.activation) - output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx] if mlp_2_b is not None else None) + output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], + mlp_2_b[expert_idx] if mlp_2_b is not None else None) output_unordered[min_bound:max_bound] = output_slice @@ -304,4 +309,3 @@ def forward(self, self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts) return output - From 022a7c6d7336e001bccb2c8932025ce740511a48 Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 18:51:46 +0000 Subject: [PATCH 3/7] remove unused code --- .../implementations/moe/pytorch_multi_gemm.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py index 5d85d03dba0a..d1138739733f 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py @@ -48,29 +48,6 @@ def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> return act_func_map[act_type](out_states) -def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Reference gating code. 
- """ - logits = logits.float() - probs = torch.nn.functional.softmax(logits, dim=1) - - indices1_s = torch.argmax(probs, dim=-1) - mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1]) - indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1 - indices1_s = torch.min(indices1_s, indices_mask) - - gates1_s = (probs * mask1).sum(dim=1) - - sorted_indices = indices1_s.sort()[1] - original_indices = sorted_indices.sort()[1] - - exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long() - exp_count_cumsum = exp_count.cumsum(dim=0) - - return sorted_indices, original_indices, exp_count_cumsum, gates1_s - - @DSMoERegistry.register_module class DSPytorchMultiGemmMoE(DSMoEBase): """ From c7db1805ce220b6b56b53cf72505eba56b134dd5 Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 18:57:14 +0000 Subject: [PATCH 4/7] use accelerator abstract --- deepspeed/inference/v2/modules/heuristics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index 090739f8103c..4dc0c9e24fe8 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -32,8 +32,7 @@ DSUnembedRegistry, ) -import torch - +from deepspeed.accelerator import get_accelerator def instantiate_attention(attention_config: DSSelfAttentionConfig, engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: @@ -132,7 +131,7 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin } # check if we are on H100 or above - if torch.cuda.get_device_capability(0)[0] >= 9: + if get_accelerator().get_device_capability(0)[0] >= 9: config = ConfigBundle(name="pytorch_multi_gemm_moe", config=moe_config, implementation_config=implementation_config) From 642e105f6f5528399f76e4f2b310621d52f697b9 Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 19:04:02 +0000 Subject: [PATCH 5/7] fix format --- deepspeed/inference/v2/modules/heuristics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index 4dc0c9e24fe8..548154120150 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -34,6 +34,7 @@ from deepspeed.accelerator import get_accelerator + def instantiate_attention(attention_config: DSSelfAttentionConfig, engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: """ From ce6b694468ad12e7cfab945510bf2e015fc34aa5 Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Wed, 29 May 2024 21:32:58 +0000 Subject: [PATCH 6/7] revert the accelerator commit --- deepspeed/inference/v2/modules/heuristics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index 548154120150..fe1a72ff030d 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -32,7 +32,7 @@ DSUnembedRegistry, ) -from deepspeed.accelerator import get_accelerator +import torch def instantiate_attention(attention_config: DSSelfAttentionConfig, @@ -132,7 +132,7 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin } # check if we are on H100 or above - if get_accelerator().get_device_capability(0)[0] >= 9: + if torch.cuda.get_device_capability(0)[0] >= 9: #ignore-cuda config = 
ConfigBundle(name="pytorch_multi_gemm_moe", config=moe_config, implementation_config=implementation_config) From 416fae47f209c08902a55686f3b35de8924f013d Mon Sep 17 00:00:00 2001 From: HeyangQin Date: Thu, 30 May 2024 07:45:00 +0000 Subject: [PATCH 7/7] fix format --- deepspeed/inference/v2/modules/heuristics.py | 2 +- .../v2/modules/implementations/moe/pytorch_multi_gemm.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index fe1a72ff030d..844fbc08bb30 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -132,7 +132,7 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin } # check if we are on H100 or above - if torch.cuda.get_device_capability(0)[0] >= 9: #ignore-cuda + if torch.cuda.get_device_capability(0)[0] >= 9: #ignore-cuda config = ConfigBundle(name="pytorch_multi_gemm_moe", config=moe_config, implementation_config=implementation_config) diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py index d1138739733f..9cf680b494e8 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py @@ -231,7 +231,6 @@ def forward(self, hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim]. gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. """ - print("Using DSPytorchMultiGemmMoE forward pass") moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w)
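
Note: the heart of this series is the per-expert GEMM loop that patch 1 adds to DSPytorchMultiGemmMoE.forward. The standalone sketch below (not part of the patches; all names are hypothetical, and a ReLU stands in for the configured activation) shows the same grouped-GEMM pattern on plain tensors, assuming the scatter step has already sorted tokens by expert so each expert owns a contiguous slice bounded by a cumulative-sum array:

import torch


def multi_gemm_moe_reference(moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mlp_1_w: torch.Tensor,
                             mlp_2_w: torch.Tensor) -> torch.Tensor:
    """
    moe_input: [n_tokens, model_dim], rows sorted by expert assignment.
    expert_cumsum: [n_experts], cumulative token counts per expert.
    mlp_1_w: [n_experts, intermediate_dim, model_dim].
    mlp_2_w: [n_experts, model_dim, intermediate_dim].
    """
    output = torch.empty_like(moe_input)
    for expert_idx in range(mlp_1_w.shape[0]):
        # Each expert's tokens occupy a contiguous slice of the sorted input.
        min_bound = 0 if expert_idx == 0 else expert_cumsum[expert_idx - 1]
        max_bound = expert_cumsum[expert_idx]
        tokens = moe_input[min_bound:max_bound]
        # Two dense GEMMs per expert; ReLU stands in for the configured activation
        # (gated activations would split even/odd columns first, as in _activation_reference).
        intermediate = torch.nn.functional.relu(torch.nn.functional.linear(tokens, mlp_1_w[expert_idx]))
        output[min_bound:max_bound] = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx])
    return output


# Example: 8 tokens routed top-1 across 4 experts with model_dim=16, intermediate_dim=32.
if __name__ == "__main__":
    counts = torch.tensor([3, 1, 0, 4])
    x = torch.randn(int(counts.sum()), 16)
    out = multi_gemm_moe_reference(x, counts.cumsum(0), torch.randn(4, 32, 16), torch.randn(4, 16, 32))
    assert out.shape == x.shape

Unlike the commented-out CUTLASS MoEGEMM path, which fuses all experts into one grouped kernel launch, this loop issues one GEMM per expert; it trades some launch overhead for portability to hardware (such as H100, per the heuristic in heuristics.py) where the prebuilt CUTLASS kernels are unavailable.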