Skip to content

Commit

Permalink
reshaped tesors for parity to match
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Aug 2, 2024
1 parent 842001b commit 8d41aae
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions onnxruntime/test/python/transformers/test_parity_dbrx_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ def create_moe_onnx_graph(
),
]

fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)

fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
Expand Down Expand Up @@ -310,9 +306,9 @@ def __init__(self, hidden_size: int,
self.moe_num_experts,
self.hidden_size,
self.ffn_hidden_size,
self.mlp.w1,
self.mlp.v1,
self.mlp.w2,
self.mlp.w1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size),
self.mlp.w2.view(moe_num_experts, ffn_hidden_size, hidden_size).transpose(1, 2).reshape(moe_num_experts, ffn_hidden_size, hidden_size),
self.mlp.v1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size),
self.moe_top_k
)

Expand Down

0 comments on commit 8d41aae

Please sign in to comment.