Skip to content

Commit

Permalink
Fix MoE tensor parallelism tests (#20147)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Previously the expert weights were stored in row-major order. With the updated
CUTLASS extension introduced by
#20108, weights are now stored
in col-major order, which aligns with the PyTorch implementation. This change fixes
the way the tensors are sliced across shards.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
wangyems authored Mar 29, 2024
1 parent 2f31560 commit f3a8642
Showing 1 changed file with 21 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,27 @@ def test_moe_with_tensor_parallelism(
inter_size,
)

fc1_experts_weights = fc1_experts_weights_all[
:, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
]
fc2_experts_weights = fc2_experts_weights_all[
:, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size(), :
]
fc3_experts_weights = fc3_experts_weights_all[
:, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
]
def get_fc1_tensor_shards(expert_weights):
    """Return this rank's shard of the FC1 expert weights.

    Reshapes the flat (col-major-stored) weights to
    (num_experts, inter_size, hidden_size) and slices this rank's
    span of the inter_size axis (axis 1).

    NOTE(review): the previous transpose(0, 2, 1) -> slice axis 2 ->
    transpose(0, 2, 1) round-trip is exactly a slice of axis 1 of the
    un-transposed array, so the two transposes cancel and are dropped.
    Shard bounds are hoisted and keep the original operator order
    (`local_rank * inter_size // get_size()`), which matters only if
    inter_size is not divisible by the world size.
    """
    start = local_rank * inter_size // get_size()
    end = (local_rank + 1) * inter_size // get_size()
    return expert_weights.reshape(-1, inter_size, hidden_size)[:, start:end, :]

def get_fc2_tensor_shards(expert_weights):
    """Return this rank's shard of the FC2 expert weights.

    Reshapes the flat (col-major-stored) weights to
    (num_experts, hidden_size, inter_size) and slices this rank's
    span of the inter_size axis (axis 2).

    NOTE(review): the previous transpose(0, 2, 1) -> slice axis 1 ->
    transpose(0, 2, 1) round-trip is exactly a slice of axis 2 of the
    un-transposed array, so the two transposes cancel and are dropped.
    Shard bounds keep the original left-to-right operator order.
    """
    start = local_rank * inter_size // get_size()
    end = (local_rank + 1) * inter_size // get_size()
    return expert_weights.reshape(-1, hidden_size, inter_size)[:, :, start:end]

fc1_experts_weights = get_fc1_tensor_shards(fc1_experts_weights_all)
fc2_experts_weights = get_fc2_tensor_shards(fc2_experts_weights_all)
fc3_experts_weights = get_fc1_tensor_shards(fc3_experts_weights_all)
fc1_experts_bias = fc1_experts_bias_all[
:, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
]
Expand Down

0 comments on commit f3a8642

Please sign in to comment.