Implementation of IOBinding in Mixtral MoE Parity Script #21153

Open
wants to merge 15 commits into main

Conversation

t-khaspear

Motivation and Context

These changes use IOBinding so that the measured latencies of the Mixtral MoE model reflect the kernels themselves. Benchmarking the Mixtral model is now available through this parity script.
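For context, this is roughly the IOBinding pattern such a parity script relies on; a minimal sketch, assuming CUDA, float16 tensors, and graph input/output names "input" and "output" (all illustrative, not the script's actual values). Binding buffers that already live on the GPU keeps host-device copies out of the timed region, so the measurement reflects kernel latency.

    import numpy
    import torch
    import onnxruntime

    # Illustrative session over the exported MoE graph.
    ort_session = onnxruntime.InferenceSession("mixtral_moe.onnx", providers=["CUDAExecutionProvider"])
    hidden_states = torch.randn(2, 32, 4096, dtype=torch.float16, device="cuda")

    io_binding = ort_session.io_binding()
    # Bind the input to the GPU buffer PyTorch already allocated.
    io_binding.bind_input(
        name="input",
        device_type="cuda",
        device_id=0,
        element_type=numpy.float16,
        shape=tuple(hidden_states.shape),
        buffer_ptr=hidden_states.data_ptr(),
    )
    # Let ORT allocate the output on the same device.
    io_binding.bind_output("output", device_type="cuda", device_id=0)

    ort_session.run_with_iobinding(io_binding)  # timed region: kernels only, no host copies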

deleted the moe onnx model once it is done being used
import unittest
from collections import OrderedDict

import numpy
import onnx

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'
Module 'onnx' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
wangyems previously approved these changes Jun 27, 2024

Contributor @wangyems left a comment:

LGTM

@@ -38,6 +42,18 @@ def print_tensor(name, numpy_array):
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def save_model_to_disk(model, model_path):
    external_data_path = "mixtral_moe.onnx" + ".data"
Contributor:

nit: external_data_path = model_path + ".data"
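A sketch of the function with that nit applied; how the function actually writes the external data is an assumption here, using the standard onnx.save_model flags:

    import onnx

    def save_model_to_disk(model, model_path):
        # Derive the external-data filename from model_path instead of hard-coding it.
        external_data_path = model_path + ".data"
        onnx.save_model(
            model,
            model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_path,
        )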

wangyems previously approved these changes Jun 28, 2024
w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
for expert_idx in range(0, self.moe_num_experts):

Check warning (Code scanning / lintrunner): RUFF/PIE808
Unnecessary start argument in range.
See https://docs.astral.sh/ruff/rules/unnecessary-range-start
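The suggested fix is simply to drop the redundant start argument:

    for expert_idx in range(self.moe_num_experts):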
@@ -0,0 +1,461 @@
# --------------------------------------------------------------------------

Check warning (Code scanning / lintrunner): RUFF/format
Run lintrunner -a to apply this patch.
@@ -0,0 +1,461 @@
# --------------------------------------------------------------------------

Check warning (Code scanning / lintrunner): BLACK-ISORT/format
Run lintrunner -a to apply this patch.

import numpy
import torch
import torch.nn.functional as F

Check warning (Code scanning / lintrunner): RUFF/F401
torch.nn.functional imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
Comment on lines +330 to +340
onnx_model_local = create_moe_onnx_graph(
    num_rows,
    num_experts,
    num_experts,
    hidden_size,
    inter_size // get_size(),
    fc1_experts_weights,
    fc2_experts_weights,
    fc3_experts_weights,
    tensor_shards=get_size(),
)

Check failure (Code scanning / CodeQL): Wrong name for an argument in a call
Keyword argument 'tensor_shards' is not a supported parameter name of function create_moe_onnx_graph.

import numpy
import torch
import torch.nn.functional as F

Check notice (Code scanning / CodeQL): Unused import
Import of 'F' is not used.
from typing import Tuple


import onnxruntime

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'
Module 'onnxruntime' is imported with both 'import' and 'import from'.


import onnxruntime
import onnx

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'
Module 'onnx' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
self.ort_sess = self.create_ort_session()


def test_moe_with_tensor_parallelism(
Contributor @wangyems, Jul 10, 2024:

The ORT MoE op's tensor parallelism is already tested, so we do not need to test it again here. Let's just keep this script for testing a single GPU.

self.moe_num_experts = config.num_local_experts
ffn_act_fn = {"name": config.hidden_act}

self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size))
Contributor @wangyems, Jul 10, 2024:

The Hugging Face implementation (https://github.com/huggingface/transformers/blob/c54af4c77ed5d72ddcb79d0cc4804d97f21deabc/src/transformers/models/dbrx/modeling_dbrx.py#L738) defines:

        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))

Let's not change the implementation.

Comment on lines 260 to 269
w1_list = []
v1_list = []
w2_list = []
for i in range(self.moe_num_experts):
    w1_list.append(self.mlp.w1[i])
    v1_list.append(self.mlp.v1[i])
    w2_list.append(self.mlp.w2[i])
self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(v1_list, dim=0)
self.moe_experts_weight3 = torch.stack(w2_list, dim=0)
Contributor:

These are not needed.

self.moe_num_experts,
self.hidden_size,
self.ffn_hidden_size,
self.moe_experts_weight1,
Contributor:

Pass self.mlp.w1/w2/v1 directly, since they are defined with shape [num_experts, ...]. This is the part that differs from Mixtral; you probably need to transpose one of them to align with the ORT format.
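A sketch of that suggestion; the per-expert view shape and which factor gets transposed are assumptions, mirroring the fc2 transpose quoted later in this review:

    # Pass the mlp parameters directly, viewed per expert; transpose fc2 so it
    # matches the ORT MoE weight layout (assumed [num_experts, ffn, hidden] slices).
    w1 = self.mlp.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)
    w2 = self.mlp.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size).transpose(1, 2)
    v1 = self.mlp.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)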

Comment on lines 99 to 101
fc1_shape = [num_experts, num_experts * inter_size, hidden_size]
fc2_shape = [num_experts, num_experts * inter_size, hidden_size]
fc3_shape = [num_experts, num_experts * inter_size, hidden_size]
Contributor:

Let's keep this the same as Mixtral's.




class DbrxRouter(nn.Module):
Contributor:

Move this class to just after DBRXConfig.

batch_size,
sequence_length,
config)
dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
Contributor:

Only test single GPU here.

return out


def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
Contributor @wangyems, Jul 10, 2024:

Let's implement ort_forward() in class DbrxFFN, since the ORT MoE op contains top-k and softmax (part of DbrxRouter).
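A hedged sketch of that arrangement; the attribute names (self.router.layer, self.ort_sess) and the input names "input" and "router_probs" are assumptions, while "output" matches the graph output quoted below. The point is that ort_forward feeds raw router logits to the session, because softmax and top-k run inside the ORT MoE op:

    def ort_forward(self, hidden_states: torch.Tensor, iobinding: bool = False) -> torch.Tensor:
        # The router only produces logits here; expert selection happens inside the MoE op.
        router_logits = self.router.layer(hidden_states)
        ort_inputs = {
            "input": hidden_states.detach().cpu().numpy(),
            "router_probs": router_logits.detach().cpu().numpy(),
        }
        if not iobinding:
            return torch.from_numpy(self.ort_sess.run(None, ort_inputs)[0])
        # IOBinding path: bind host inputs, let ORT place the output, then copy back.
        io_binding = self.ort_sess.io_binding()
        for name, arr in ort_inputs.items():
            io_binding.bind_cpu_input(name, arr)
        io_binding.bind_output("output")
        self.ort_sess.run_with_iobinding(io_binding)
        return torch.from_numpy(io_binding.copy_outputs_to_cpu()[0])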

from collections import OrderedDict

import numpy
import os

Check warning (Code scanning / lintrunner): RUFF/F401
os imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

def parity_check(self):
    config = DBRXConfig()
    ffn = DbrxFFN(config, self.batch_size, self.sequence_length)
    router = DbrxRouter(hidden_size=config.hidden_size,

Check warning (Code scanning / lintrunner): RUFF/F841
Local variable router is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
torch_output = ffn.forward(hidden_state)
print("forward: ", torch_output)
ort_output = ffn.ort_forward(hidden_state, iobinding=False)

Check warning (Code scanning / lintrunner): RUFF/F841
Local variable ort_output is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
from collections import OrderedDict

import numpy
import os

Check notice (Code scanning / CodeQL): Unused import
Import of 'os' is not used.
Comment on lines +64 to +66
#def delete_model_data(external_data):
#os.remove("dbrx_moe.onnx")
#os.remove(external_data)

Check notice (Code scanning / CodeQL): Commented-out code
This comment appears to contain commented-out code.
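If the cleanup is meant to stay (the PR description says the MoE ONNX model is deleted once it is done being used), a sketch of the helper uncommented; the exists-guard is an addition, and the hard-coded model filename mirrors the commented-out lines:

    import os

    def delete_model_data(external_data):
        # Remove the exported model and its external-data file after the parity run.
        for path in ("dbrx_moe.onnx", external_data):
            if os.path.exists(path):
                os.remove(path)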
def parity_check(self):
    config = DBRXConfig()
    ffn = DbrxFFN(config, self.batch_size, self.sequence_length)
    router = DbrxRouter(hidden_size=config.hidden_size,

Check notice (Code scanning / CodeQL): Unused local variable
Variable router is not used.
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
torch_output = ffn.forward(hidden_state)
print("forward: ", torch_output)
ort_output = ffn.ort_forward(hidden_state, iobinding=False)

Check notice (Code scanning / CodeQL): Unused local variable
Variable ort_output is not used.
Comment on lines 313 to 315
self.mlp.w1,
self.mlp.v1,
self.mlp.w2,
Contributor:

The order should be w1, w2, v1, with certain transpose operations.

["output"],
"MoE_0",
k=topk,
normalize_routing_weights=1,
Contributor:

I think this should be 0.

Comment on lines 107 to 109
fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)
Contributor:

It's recommended to do view() and transpose() outside of this function.
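A sketch of that relocation, reusing the exact lines quoted above at the (assumed) call site, so create_moe_onnx_graph receives weights already in the expected layout:

    # Done by the caller, before building the graph:
    fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
    fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
    fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)
    # ...then pass them to create_moe_onnx_graph unchanged.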
