Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
HeyangQin committed May 29, 2024
1 parent 8274a01 commit 88d758e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
9 changes: 5 additions & 4 deletions deepspeed/inference/v2/modules/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import torch


def instantiate_attention(attention_config: DSSelfAttentionConfig,
engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase:
"""
Expand Down Expand Up @@ -133,12 +134,12 @@ def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngin
# check if we are on H100 or above
if torch.cuda.get_device_capability(0)[0] >= 9:
config = ConfigBundle(name="pytorch_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
config=moe_config,
implementation_config=implementation_config)
else:
config = ConfigBundle(name="cutlass_multi_gemm_moe",
config=moe_config,
implementation_config=implementation_config)
config=moe_config,
implementation_config=implementation_config)
return DSMoERegistry.instantiate_config(config)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# from ....kernels.cutlass_ops import MoEGEMM
from ....inference_parameter import InferenceParameter


def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> torch.Tensor:
if is_gated(act_type):
act_func_map = {
Expand All @@ -46,6 +47,7 @@ def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) ->

return act_func_map[act_type](out_states)


def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reference gating code.
Expand All @@ -68,6 +70,7 @@ def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

return sorted_indices, original_indices, exp_count_cumsum, gates1_s


@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
"""
Expand Down Expand Up @@ -267,10 +270,12 @@ def forward(self,
max_bound = expert_cumsum[expert_idx]

input_slice = moe_input[min_bound:max_bound]
intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx] if mlp_1_b is not None else None)
intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx],
mlp_1_b[expert_idx] if mlp_1_b is not None else None)

intermediate = _activation_reference(intermediate, self._config.activation)
output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx] if mlp_2_b is not None else None)
output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx],
mlp_2_b[expert_idx] if mlp_2_b is not None else None)

output_unordered[min_bound:max_bound] = output_slice

Expand Down Expand Up @@ -304,4 +309,3 @@ def forward(self,

self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts)
return output

0 comments on commit 88d758e

Please sign in to comment.