Skip to content

Commit

Permalink
remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
HeyangQin committed May 29, 2024
1 parent 88d758e commit 022a7c6
Showing 1 changed file with 0 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,6 @@ def _activation_reference(out_states: torch.Tensor, act_type: ActivationType) ->
return act_func_map[act_type](out_states)


def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reference gating code.
"""
logits = logits.float()
probs = torch.nn.functional.softmax(logits, dim=1)

indices1_s = torch.argmax(probs, dim=-1)
mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1])
indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1
indices1_s = torch.min(indices1_s, indices_mask)

gates1_s = (probs * mask1).sum(dim=1)

sorted_indices = indices1_s.sort()[1]
original_indices = sorted_indices.sort()[1]

exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long()
exp_count_cumsum = exp_count.cumsum(dim=0)

return sorted_indices, original_indices, exp_count_cumsum, gates1_s


@DSMoERegistry.register_module
class DSPytorchMultiGemmMoE(DSMoEBase):
"""
Expand Down

0 comments on commit 022a7c6

Please sign in to comment.