diff --git a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py index 5d85d03dba0a..d1138739733f 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/pytorch_multi_gemm.py @@ -48,29 +48,6 @@ def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) -> return act_func_map[act_type](out_states) -def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Reference gating code. - """ - logits = logits.float() - probs = torch.nn.functional.softmax(logits, dim=1) - - indices1_s = torch.argmax(probs, dim=-1) - mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1]) - indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1 - indices1_s = torch.min(indices1_s, indices_mask) - - gates1_s = (probs * mask1).sum(dim=1) - - sorted_indices = indices1_s.sort()[1] - original_indices = sorted_indices.sort()[1] - - exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long() - exp_count_cumsum = exp_count.cumsum(dim=0) - - return sorted_indices, original_indices, exp_count_cumsum, gates1_s - - @DSMoERegistry.register_module class DSPytorchMultiGemmMoE(DSMoEBase): """