
Implementation of IOBinding in Mixtral MoE Parity Script #21153

Open
Wants to merge 15 commits into base: main
126 changes: 105 additions & 21 deletions onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -9,10 +9,14 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import time
import unittest
from collections import OrderedDict

import numpy
import onnx

Check notice (Code scanning / CodeQL, test): Module is imported with both 'import' and 'import from'. Module 'onnx' is imported with both 'import' and 'import from'; module 'onnxruntime.test.onnx' is likewise imported with both 'import' and 'import from'. A consolidated-import sketch follows the import block below.
import pytest
import torch
import torch.nn.functional as F
from onnx import TensorProto, helper
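One way to address the CodeQL note above is to settle on a single import style for the onnx module. A minimal sketch, assuming call sites are updated to use the qualified names:

# Sketch: keep only the plain "import onnx" and qualify symbols at their use
# sites, so the module is no longer imported with both styles.
import onnx

ORT_DTYPE = onnx.TensorProto.FLOAT16
# call sites would then use onnx.helper.make_tensor(...), onnx.helper.make_graph(...), etc.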
@@ -23,7 +27,7 @@
torch.manual_seed(42)
numpy.random.seed(42)

-ORT_DTYPE = TensorProto.FLOAT
+ORT_DTYPE = TensorProto.FLOAT16
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
THRESHOLD = 3e-2

@@ -38,6 +42,18 @@
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def save_model_to_disk(model, model_path):
external_data_path = "mixtral_moe.onnx" + ".data"
Contributor (review comment): nit: external_data_path = model_path + ".data" (a sketch applying this suggestion follows the helper below).

onnx.save_model(
model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path
)

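A minimal sketch of the helper with the nit above applied, deriving the external-data file name from model_path instead of hardcoding it (the function name here is illustrative):

def save_model_to_disk_suggested(model, model_path):
    # Suggested variant: the external-data file sits next to the model and
    # reuses its name, e.g. "mixtral_moe.onnx.data" for "mixtral_moe.onnx".
    external_data_path = model_path + ".data"
    onnx.save_model(
        model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path
    )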

def delete_model_data(external_data):
os.remove("mixtral_moe.onnx")
os.remove(external_data)


def create_moe_onnx_graph(
num_rows,
num_experts,
@@ -80,22 +96,22 @@
"fc1_experts_weights",
ORT_DTYPE,
fc1_shape,
-fc1_experts_weights.to(torch_type).flatten().tolist(),
-raw=False,
+fc1_experts_weights.to(torch_type).detach().numpy().tobytes(),
+raw=True,
),
helper.make_tensor(
"fc2_experts_weights",
ORT_DTYPE,
fc2_shape,
-fc2_experts_weights.to(torch_type).flatten().tolist(),
-raw=False,
+fc2_experts_weights.to(torch_type).detach().numpy().tobytes(),
+raw=True,
),
helper.make_tensor(
"fc3_experts_weights",
ORT_DTYPE,
fc3_shape,
-fc3_experts_weights.to(torch_type).flatten().tolist(),
-raw=False,
+fc3_experts_weights.to(torch_type).detach().numpy().tobytes(),
+raw=True,
),
]

@@ -124,7 +140,11 @@
)

model = helper.make_model(graph)
-return model.SerializeToString()
+model_path = "mixtral_moe.onnx"

save_model_to_disk(model, model_path)

return model_path


class ClassInstantier(OrderedDict):
@@ -216,10 +236,8 @@
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

w1_list = []
@@ -229,11 +247,9 @@
w1_list.append(self.experts[i].w1.weight)
w2_list.append(self.experts[i].w2.weight)
w3_list.append(self.experts[i].w3.weight)

self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
self.moe_experts_weight3 = torch.stack(w3_list, dim=0)

self.batch_size = batch_size
self.sequence_length = sequence_length
self.moe_onnx_graph = create_moe_onnx_graph(
@@ -257,12 +273,54 @@
cuda_providers = ["CUDAExecutionProvider"]
if cuda_providers[0] not in onnxruntime.get_available_providers():
return None

sess_options.log_severity_level = 2
ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])

return ort_session

def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
iobinding = self.ort_sess.io_binding()
device_id = torch.cuda.current_device()

iobinding.bind_input(
name="input",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["input"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(),
)

iobinding.bind_input(
name="router_probs",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["router_probs"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
ort_inputs["router_probs"], "cuda", device_id
).data_ptr(),
)

iobinding.bind_output(
name="output",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["input"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
numpy.zeros(ort_inputs["input"].shape, dtype=NP_TYPE), "cuda", device_id
).data_ptr(),
)

s = time.time()
for _ in range(repeat):
iobinding.synchronize_inputs()
self.ort_sess.run_with_iobinding(iobinding)
iobinding.synchronize_outputs()
e = time.time()
print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -305,21 +363,23 @@
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states # , router_logits

def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

ort_inputs = {
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
}

ort_output = None
if self.ort_sess is not None:
-ort_output = self.ort_sess.run(None, ort_inputs)
-return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits
+if not iobinding:
+ort_output = self.ort_sess.run(None, ort_inputs)
+else:
+self.ort_run_with_iobinding(ort_inputs)
+return None

# print_tensor("input", ort_inputs["input"])
# print_tensor("router_probs", ort_inputs["router_probs"])
@@ -328,12 +388,12 @@
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
# print_tensor("output", ort_output[0])

-return None
+return ort_output

def parity_check(self):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
torch_output = self.forward(hidden_state)
-ort_output = self.ort_forward(hidden_state)
+ort_output = self.ort_forward(hidden_state, iobinding=True)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
print(
@@ -346,16 +406,40 @@
" parity: OK",
)

def benchmark(self):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
self.ort_forward(hidden_state, iobinding=True)


class TestMixtralMoE(unittest.TestCase):

def test_mixtral_moe_parity(self):
-for batch_size in [1, 16]:
-for sequence_length in [128, 1024]:
+for batch_size in [1, 2]:
+for sequence_length in [8, 16]:
# use small sizes to speed up the test
config = MixtralConfig(hidden_size=256, intermediate_size=1024)
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.parity_check()

@pytest.mark.slow
def test_mixtral_moe_large(self):
for batch_size in [1, 8]:
for sequence_length in [16, 64]:
config = MixtralConfig()
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.parity_check()

@pytest.mark.slow
def test_mixtral_moe_benchmark(self):
for batch_size in [32, 64]:
for sequence_length in [128, 1024]:
config = MixtralConfig()
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.benchmark()

external_data_path = "mixtral_moe.onnx" + ".data"
delete_model_data(external_data_path)


if __name__ == "__main__":
unittest.main()
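For reference, the pytest.mark.slow decorators only take effect when the file is run through pytest; the plain unittest entry point ignores the markers and runs every test. Hypothetical invocations, assuming pytest is available in the environment:

# python -m pytest test_parity_mixtral_moe.py -m "not slow"   # fast parity check only
# python -m pytest test_parity_mixtral_moe.py -m slow         # large parity and benchmark cases
# python test_parity_mixtral_moe.py                           # unittest entry point, runs everything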