Added IO binding to the Mixtral MoE benchmarking kernel
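In outline, the change pins the MoE graph's inputs and outputs to GPU buffers through onnxruntime's IO binding API, so the timed loop measures kernel execution rather than host/device copies. A minimal sketch of the pattern, assuming a CUDA build of onnxruntime (the model path, shapes, and the single-input model here are illustrative, not taken from this diff):

import numpy
import onnxruntime

sess = onnxruntime.InferenceSession("moe.onnx", providers=["CUDAExecutionProvider"])
binding = sess.io_binding()

# Copy the input to the GPU once; keep a reference so the device buffer stays alive.
x = onnxruntime.OrtValue.ortvalue_from_numpy(
    numpy.zeros((4, 8), dtype=numpy.float16), "cuda", 0
)
binding.bind_input(
    name="input", device_type="cuda", device_id=0,
    element_type=numpy.float16, shape=x.shape(), buffer_ptr=x.data_ptr(),
)
# Let onnxruntime allocate the output on the same device.
binding.bind_output("output", device_type="cuda", device_id=0)

binding.synchronize_inputs()
sess.run_with_iobinding(binding)
binding.synchronize_outputs()
out = binding.copy_outputs_to_cpu()[0]

The diff below applies this same pattern to the test's "input", "router_probs", and "output" tensors.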
Your Name committed Jun 19, 2024
1 parent dadd0c4 commit 41d86e9
Showing 1 changed file with 94 additions and 12 deletions.
106 changes: 94 additions & 12 deletions onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -10,6 +10,7 @@
# license information.
# --------------------------------------------------------------------------
import unittest
import time
from collections import OrderedDict

import numpy
@@ -23,10 +24,14 @@
torch.manual_seed(42)
numpy.random.seed(42)

- ORT_DTYPE = TensorProto.FLOAT
+ ORT_DTYPE = TensorProto.FLOAT16
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
THRESHOLD = 3e-2

print(f"NP_TYPE_type: {NP_TYPE}")


def value_string_of(numpy_array):
arr = numpy_array.flatten()
@@ -69,12 +74,18 @@ def create_moe_onnx_graph(
),
]

print(f"Test 11")
fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
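# fc1/fc3 are the SwiGLU gate/up projections and fc2 the down projection,
# matching the stacked w1/w3/w2 expert weights.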

torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
print(f"torch_type: {torch_type}")

print(f"fc1_experts_weights dtype before conversion: {fc1_experts_weights.dtype}")
print(f"fc2_experts_weights dtype before conversion: {fc2_experts_weights.dtype}")
print(f"fc3_experts_weights dtype before conversion: {fc3_experts_weights.dtype}")

initializers = [
helper.make_tensor(
"fc1_experts_weights",
@@ -98,11 +109,16 @@
raw=False,
),
]
print(f"fc1_experts_weights dtype after conversion: {fc1_experts_weights.dtype}")
print(f"fc2_experts_weights dtype after conversion: {fc2_experts_weights.dtype}")
print(f"fc3_experts_weights dtype after conversion: {fc3_experts_weights.dtype}")

graph_inputs = [
helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
]

print(f"Type: {type(graph_inputs[0])}")

graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
@@ -111,6 +127,8 @@
)
)

print(f"Type: {type(graph_inputs[0])}")

graph_outputs = [
helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
]
@@ -123,6 +141,8 @@
initializers,
)

print(f"Test 12")

model = helper.make_model(graph)
return model.SerializeToString()

@@ -212,14 +232,15 @@ class MixtralSparseMoeBlock(nn.Module):

def __init__(self, config, batch_size, sequence_length):
super().__init__()
print(f"Test 6")
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

print(f"Test 9")
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

print(f"Test 10")
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

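# Collect per-expert weights so they can be stacked into single
# [num_experts, ...] tensors for the fused ORT MoE op.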
w1_list = []
@@ -229,11 +250,11 @@ def __init__(self, config, batch_size, sequence_length):
w1_list.append(self.experts[i].w1.weight)
w2_list.append(self.experts[i].w2.weight)
w3_list.append(self.experts[i].w3.weight)

print(f"Test 10")
self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
self.moe_experts_weight3 = torch.stack(w3_list, dim=0)

print(f"Test 10")
self.batch_size = batch_size
self.sequence_length = sequence_length
self.moe_onnx_graph = create_moe_onnx_graph(
@@ -246,23 +267,77 @@
self.moe_experts_weight3,
self.top_k,
)

self.ort_sess = self.create_ort_session()

def create_ort_session(self):
print(f"Test 8")
from onnxruntime import InferenceSession, SessionOptions
print(f"Test 13")

sess_options = SessionOptions()

cuda_providers = ["CUDAExecutionProvider"]
print(f"Available providers: {onnxruntime.get_available_providers()}")
print(f"Test 14")
if cuda_providers[0] not in onnxruntime.get_available_providers():
return None

print(f"Test 15")
sess_options.log_severity_level = 2
ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])
print(f"Test 16")

return ort_session

def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
print(f"Test 7")
iobinding = self.ort_sess.io_binding()
device_id = torch.cuda.current_device()
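# Bind every graph input/output to a preallocated GPU buffer so the timed
# loop below measures kernel execution rather than host/device copies.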

# Keep a reference to each OrtValue: binding only the data_ptr of a
# temporary would leave a dangling device pointer once it is collected.
input_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id)
iobinding.bind_input(
name="input",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["input"].shape,
buffer_ptr=input_ortvalue.data_ptr(),
)

print(f"fc1_experts_weights dtype after conversion: {NP_TYPE.dtype}")
router_probs_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["router_probs"], "cuda", device_id)
iobinding.bind_input(
name="router_probs",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["router_probs"].shape,
buffer_ptr=router_probs_ortvalue.data_ptr(),
)

# The output buffer must match ORT_DTYPE; numpy.zeros defaults to float64.
output_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(
numpy.zeros(ort_inputs["input"].shape, dtype=NP_TYPE), "cuda", device_id
)
iobinding.bind_output(
name="output",
device_type="cuda",
device_id=device_id,
element_type=NP_TYPE,
shape=ort_inputs["input"].shape,
buffer_ptr=output_ortvalue.data_ptr(),
)

# Wall-clock timing; synchronize_inputs/outputs ensure pending device work
# has finished so the average reflects the kernel itself.
s = time.time()
for _ in range(repeat):
iobinding.synchronize_inputs()
self.ort_sess.run_with_iobinding(iobinding)
iobinding.synchronize_outputs()
e = time.time()
print(f"MoE CUDA kernel time: {(e - s) / repeat * 1000:.3f} ms")

# Copy the bound output back to host so callers can check parity.
return output_ortvalue.numpy()

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -305,21 +380,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states # , router_logits

- def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

print(f"Test: 17")
ort_inputs = {
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
}

ort_output = None
if self.ort_sess is not None:
- ort_output = self.ort_sess.run(None, ort_inputs)
+ if not iobinding:
+ ort_output = self.ort_sess.run(None, ort_inputs)
+ else:
+ # Benchmark path; it returns the bound output so parity can still be checked.
+ ort_output = self.ort_run_with_iobinding(ort_inputs)
return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits

# print_tensor("input", ort_inputs["input"])
# print_tensor("router_probs", ort_inputs["router_probs"])
@@ -333,7 +411,7 @@ def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def parity_check(self):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
torch_output = self.forward(hidden_state)
- ort_output = self.ort_forward(hidden_state)
+ ort_output = self.ort_forward(hidden_state, iobinding=True)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output.to(torch_output.dtype), rtol=THRESHOLD, atol=THRESHOLD)
print(
@@ -352,9 +430,13 @@ def test_mixtral_moe_parity(self):
for batch_size in [1, 16]:
for sequence_length in [128, 1024]:
# benchmark with the full-size default Mixtral config
- config = MixtralConfig(hidden_size=256, intermediate_size=1024)
+ config = MixtralConfig()
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.parity_check()


if __name__ == "__main__":
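For reference, the parity test (and with it the IO-binding benchmark) can be run directly, assuming a CUDA-enabled onnxruntime build and the usual unittest entry point:

python onnxruntime/test/python/transformers/test_parity_mixtral_moe.py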
