diff --git a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py index 9c73f1e7ccbd2..a49ad5f122d4c 100644 --- a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py @@ -104,10 +104,6 @@ def create_moe_onnx_graph( ), ] - fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752) - fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2) - fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752) - fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] @@ -310,9 +306,9 @@ def __init__(self, hidden_size: int, self.moe_num_experts, self.hidden_size, self.ffn_hidden_size, - self.mlp.w1, - self.mlp.v1, - self.mlp.w2, + self.mlp.w1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size), + self.mlp.w2.view(moe_num_experts, ffn_hidden_size, hidden_size).transpose(1, 2).reshape(moe_num_experts, ffn_hidden_size, hidden_size), + self.mlp.v1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size), self.moe_top_k )