Implementation of IOBinding in Mixtral MoE Parity Script #21153
base: main
Conversation
Deleted the MoE ONNX model once it is done being used.
import unittest
from collections import OrderedDict

import numpy
import onnx

Check notice: Code scanning / CodeQL
Module is imported with 'import' and 'import from'
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
LGTM
@@ -38,6 +42,18 @@ def print_tensor(name, numpy_array):
    print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def save_model_to_disk(model, model_path):
    external_data_path = "mixtral_moe.onnx" + ".data"
nit: external_data_path = model_path + ".data"
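A minimal sketch of what that suggested change could look like, assuming onnx.save_model is used to externalize the weights next to the model file (the size_threshold value is an illustrative default):

import os

import onnx


def save_model_to_disk(model, model_path):
    # Derive the external-data file name from the model path instead of hardcoding it.
    external_data_path = model_path + ".data"
    onnx.save_model(
        model,
        model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        # 'location' is interpreted relative to the model file, so pass only the basename.
        location=os.path.basename(external_data_path),
        size_threshold=1024,
    )
    return external_data_path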
w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
for expert_idx in range(0, self.moe_num_experts):

Check warning: Code scanning / lintrunner
RUFF/PIE808
See https://docs.astral.sh/ruff/rules/unnecessary-range-start
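The warning only asks for the redundant start argument to be dropped, since range() starts at 0 by default:

for expert_idx in range(self.moe_num_experts):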
@@ -0,0 +1,461 @@
# --------------------------------------------------------------------------

Check warning: Code scanning / lintrunner
RUFF/format

@@ -0,0 +1,461 @@
# --------------------------------------------------------------------------

Check warning: Code scanning / lintrunner
BLACK-ISORT/format
import numpy
import torch
import torch.nn.functional as F

Check warning: Code scanning / lintrunner
RUFF/F401
See https://docs.astral.sh/ruff/rules/unused-import
onnx_model_local = create_moe_onnx_graph(
    num_rows,
    num_experts,
    num_experts,
    hidden_size,
    inter_size // get_size(),
    fc1_experts_weights,
    fc2_experts_weights,
    fc3_experts_weights,
    tensor_shards=get_size(),
)

Check failure: Code scanning / CodeQL
Wrong name for an argument in a call
function create_moe_onnx_graph
import numpy
import torch
import torch.nn.functional as F

Check notice: Code scanning / CodeQL
Unused import
from typing import Tuple

import onnxruntime

Check notice: Code scanning / CodeQL
Module is imported with 'import' and 'import from'
import onnxruntime
import onnx

Check notice: Code scanning / CodeQL
Module is imported with 'import' and 'import from'
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
self.ort_sess = self.create_ort_session()

def test_moe_with_tensor_parallelism(
The ORT MoE op's tensor parallelism is already tested, so we do not need to test it again here. Let's just keep this script for testing on a single GPU.
self.moe_num_experts = config.num_local_experts
ffn_act_fn = {"name": config.hidden_act}

self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size))
The Hugging Face implementation (https://github.com/huggingface/transformers/blob/c54af4c77ed5d72ddcb79d0cc4804d97f21deabc/src/transformers/models/dbrx/modeling_dbrx.py#L738) defines these as:
self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
Let's not change that implementation.
w1_list = []
v1_list = []
w2_list = []
for i in range(self.moe_num_experts):
    w1_list.append(self.mlp.w1[i])
    v1_list.append(self.mlp.v1[i])
    w2_list.append(self.mlp.w2[i])
self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(v1_list, dim=0)
self.moe_experts_weight3 = torch.stack(w2_list, dim=0)
these are not needed
self.moe_num_experts,
self.hidden_size,
self.ffn_hidden_size,
self.moe_experts_weight1,
Pass self.mlp.w1/w2/v1 directly, since they are defined with shape [num_experts, ...]. This is the part that differs from Mixtral; you probably need to transpose one of them to align with the ORT format.
fc1_shape = [num_experts, num_experts * inter_size, hidden_size]
fc2_shape = [num_experts, num_experts * inter_size, hidden_size]
fc3_shape = [num_experts, num_experts * inter_size, hidden_size]
Let's keep these the same as Mixtral's.
class DbrxRouter(nn.Module):
Move this class to just after DBRXConfig.
batch_size,
sequence_length,
config)
dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
only test single GPU here
return out


def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
let's implement ort_forward() in class DbrxFFN since ORT MoE contains topk&softmax (part of DbrxRouter)
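For reference, a rough sketch of what an IOBinding-based ort_forward() on DbrxFFN could look like. The input/output names ("input", "output") and the CUDA device placement are assumptions for illustration, not taken from this diff:

import numpy
import torch
from onnxruntime import OrtValue


def ort_forward(self, hidden_states: torch.Tensor, iobinding: bool = False) -> torch.Tensor:
    # The ORT MoE op already contains top-k and softmax, so only the raw hidden
    # states are fed to the graph here (input name is assumed to be "input").
    ort_inputs = {"input": numpy.ascontiguousarray(hidden_states.detach().cpu().numpy())}

    if not iobinding:
        return torch.from_numpy(self.ort_sess.run(None, ort_inputs)[0])

    # IOBinding path: stage the input on the GPU and let ORT allocate the output
    # on the same device, so the measured latency reflects the kernel only.
    io_binding = self.ort_sess.io_binding()
    for name, arr in ort_inputs.items():
        io_binding.bind_ortvalue_input(name, OrtValue.ortvalue_from_numpy(arr, "cuda", 0))
    io_binding.bind_output("output", "cuda", 0)

    io_binding.synchronize_inputs()
    self.ort_sess.run_with_iobinding(io_binding)
    io_binding.synchronize_outputs()

    return torch.from_numpy(io_binding.copy_outputs_to_cpu()[0])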
from collections import OrderedDict

import numpy
import os

Check warning: Code scanning / lintrunner
RUFF/F401
See https://docs.astral.sh/ruff/rules/unused-import
def parity_check(self):
    config = DBRXConfig()
    ffn = DbrxFFN(config, self.batch_size, self.sequence_length)
    router = DbrxRouter(hidden_size=config.hidden_size,

Check warning: Code scanning / lintrunner
RUFF/F841
See https://docs.astral.sh/ruff/rules/unused-variable
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
torch_output = ffn.forward(hidden_state)
print("forward: ", torch_output)
ort_output = ffn.ort_forward(hidden_state, iobinding=False)

Check warning: Code scanning / lintrunner
RUFF/F841
See https://docs.astral.sh/ruff/rules/unused-variable
from collections import OrderedDict

import numpy
import os

Check notice: Code scanning / CodeQL
Unused import
#def delete_model_data(external_data):
#os.remove("dbrx_moe.onnx")
#os.remove(external_data)

Check notice: Code scanning / CodeQL
Commented-out code
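If the intent (per the PR description) is to delete the MoE ONNX model once it is done being used, a small helper along these lines could replace the commented-out block; the "dbrx_moe.onnx" name is taken from the snippet above:

import os


def delete_model_data(external_data):
    # Remove the serialized MoE model and its external weight file once the
    # parity check is finished with them.
    for path in ("dbrx_moe.onnx", external_data):
        if os.path.exists(path):
            os.remove(path)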
def parity_check(self):
    config = DBRXConfig()
    ffn = DbrxFFN(config, self.batch_size, self.sequence_length)
    router = DbrxRouter(hidden_size=config.hidden_size,

Check notice: Code scanning / CodeQL
Unused local variable
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
torch_output = ffn.forward(hidden_state)
print("forward: ", torch_output)
ort_output = ffn.ort_forward(hidden_state, iobinding=False)

Check notice: Code scanning / CodeQL
Unused local variable
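Both CodeQL and lintrunner flag ort_output (and router) as unused, which suggests the actual comparison is missing. A minimal sketch of how parity_check could finish, assuming ort_forward returns a torch.Tensor and with placeholder tolerances:

def parity_check(self):
    config = DBRXConfig()
    ffn = DbrxFFN(config, self.batch_size, self.sequence_length)

    hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
    torch_output = ffn.forward(hidden_state)
    ort_output = ffn.ort_forward(hidden_state, iobinding=False)

    # Actually compare the two paths so neither result is left unused;
    # rtol/atol are illustrative and should be tuned for the dtype in use.
    numpy.testing.assert_allclose(
        torch_output.detach().numpy(),
        ort_output.detach().numpy(),
        rtol=1e-3,
        atol=1e-3,
    )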
self.mlp.w1,
self.mlp.v1,
self.mlp.w2,
The order should be w1, w2, v1, with certain transpose operations.
["output"], | ||
"MoE_0", | ||
k=topk, | ||
normalize_routing_weights=1, |
I think this should be 0
fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)
It's recommended to do the view() and transpose() outside of this function.
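A sketch of what that could look like at the call site, mirroring the view()/transpose() pattern above and the w1, w2, v1 ordering suggested earlier; whether these reshapes give exactly the layout the ORT MoE kernel expects still needs to be checked against the Mixtral script:

# Reshape the flat DBRX parameters into [num_experts, ...] and apply the
# transpose before handing them to the graph builder, so create_moe_onnx_graph
# only serializes the tensors it receives.
fc1_experts_weights = self.mlp.w1.view(16, 6144, 10752)
fc2_experts_weights = self.mlp.w2.view(16, 6144, 10752).transpose(1, 2)
fc3_experts_weights = self.mlp.v1.view(16, 6144, 10752)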
Motivation and Context
These changes use IOBinding in the Mixtral MoE parity script so that kernel latencies can be measured more faithfully. Benchmarking of the Mixtral model is now available through this parity script.