From 41d86e9dca0bb36f6236aa764004d57fc347da1d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 19 Jun 2024 22:13:32 +0000
Subject: [PATCH] Add IO binding to the Mixtral MoE benchmarking kernel test

---
 .../transformers/test_parity_mixtral_moe.py | 106 ++++++++++++++++--
 1 file changed, 94 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
index 00704626028a0..2fd3b9042bf1f 100644
--- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -10,6 +10,7 @@
 # license information.
 # --------------------------------------------------------------------------
+import time
 import unittest
 from collections import OrderedDict
 
 import numpy
@@ -23,10 +24,14 @@
 torch.manual_seed(42)
 numpy.random.seed(42)
 
-ORT_DTYPE = TensorProto.FLOAT
+ORT_DTYPE = TensorProto.FLOAT16
 NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
 THRESHOLD = 3e-2
 
 
 def value_string_of(numpy_array):
     arr = numpy_array.flatten()
@@ -246,23 +267,77 @@ def __init__(self, config, batch_size, sequence_length):
             self.moe_experts_weight3,
             self.top_k,
         )
 
         self.ort_sess = self.create_ort_session()
 
     def create_ort_session(self):
         from onnxruntime import InferenceSession, SessionOptions
 
         sess_options = SessionOptions()
 
         cuda_providers = ["CUDAExecutionProvider"]
         if cuda_providers[0] not in onnxruntime.get_available_providers():
             return None
 
         sess_options.log_severity_level = 2
         ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])
 
         return ort_session
 
+    def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
+        iobinding = self.ort_sess.io_binding()
+        device_id = torch.cuda.current_device()
+
+        # Copy the inputs to the GPU once and keep the OrtValue objects referenced for
+        # the lifetime of the binding; binding only a data_ptr() taken from a temporary
+        # OrtValue would leave the bound pointer dangling once it is garbage collected.
+        input_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id)
+        router_probs_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(
+            ort_inputs["router_probs"], "cuda", device_id
+        )
+        output_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(
+            numpy.zeros(ort_inputs["input"].shape, dtype=NP_TYPE), "cuda", device_id
+        )
+
+        iobinding.bind_input(
+            name="input",
+            device_type="cuda",
+            device_id=device_id,
+            element_type=NP_TYPE,
+            shape=ort_inputs["input"].shape,
+            buffer_ptr=input_ortvalue.data_ptr(),
+        )
+
+        iobinding.bind_input(
+            name="router_probs",
+            device_type="cuda",
+            device_id=device_id,
+            element_type=NP_TYPE,
+            shape=ort_inputs["router_probs"].shape,
+            buffer_ptr=router_probs_ortvalue.data_ptr(),
+        )
+
+        iobinding.bind_output(
+            name="output",
+            device_type="cuda",
+            device_id=device_id,
+            element_type=NP_TYPE,
+            shape=ort_inputs["input"].shape,
+            buffer_ptr=output_ortvalue.data_ptr(),
+        )
+
+        # Report the average latency over `repeat` synchronized runs.
+        s = time.time()
+        for _ in range(repeat):
+            iobinding.synchronize_inputs()
+            self.ort_sess.run_with_iobinding(iobinding)
+            iobinding.synchronize_outputs()
+        e = time.time()
+        print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -305,12 +380,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states  # , router_logits
 
-    def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         ort_inputs = {
             "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
             "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
@@ -318,8 +393,11 @@
 
         ort_output = None
         if self.ort_sess is not None:
-            ort_output = self.ort_sess.run(None, ort_inputs)
-            return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1)  # , router_logits
+            if not iobinding:
+                ort_output = self.ort_sess.run(None, ort_inputs)
+                return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1)  # , router_logits
+            # The IO binding path only benchmarks the kernel; there is no output to return.
+            self.ort_run_with_iobinding(ort_inputs)
+            return None
 
         # print_tensor("input", ort_inputs["input"])
         # print_tensor("router_probs", ort_inputs["router_probs"])
@@ -333,7 +411,7 @@ def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     def parity_check(self):
         hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
         torch_output = self.forward(hidden_state)
-        ort_output = self.ort_forward(hidden_state)
+        ort_output = self.ort_forward(hidden_state, iobinding=True)
         if ort_output is not None:
             assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
             print(
@@ -352,9 +430,13 @@ def test_mixtral_moe_parity(self):
         for batch_size in [1, 16]:
             for sequence_length in [128, 1024]:
-                # use a small sizes to speed up the test
-                config = MixtralConfig(hidden_size=256, intermediate_size=1024)
+                # use the default Mixtral config so the benchmark reflects real model sizes
+                config = MixtralConfig()
                 mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
                 mixtral_moe.parity_check()
 
 
 if __name__ == "__main__":
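
For reference, the IO binding flow added above can be exercised in isolation with the short sketch below. This is a minimal illustration, not part of the patch: the model file name ("model.onnx"), the tensor names ("x", "y"), the shapes, and the dtype are placeholder assumptions. It uses bind_ortvalue_input / bind_ortvalue_output, a slightly higher-level alternative to the bind_input(..., buffer_ptr=...) form used in ort_run_with_iobinding.

    import time

    import numpy
    import onnxruntime

    # Assumptions (not from the patch): a model "model.onnx" with one input "x" and
    # one output "y", both float16 of shape (16, 1024), running on CUDA device 0.
    device_id = 0
    sess = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

    x = numpy.random.rand(16, 1024).astype(numpy.float16)
    y = numpy.zeros((16, 1024), dtype=numpy.float16)

    # Copy host buffers to the GPU once; the OrtValue objects own the device memory
    # and must stay referenced while the binding is in use.
    x_ort = onnxruntime.OrtValue.ortvalue_from_numpy(x, "cuda", device_id)
    y_ort = onnxruntime.OrtValue.ortvalue_from_numpy(y, "cuda", device_id)

    binding = sess.io_binding()
    binding.bind_ortvalue_input("x", x_ort)
    binding.bind_ortvalue_output("y", y_ort)

    repeat = 1000
    start = time.time()
    for _ in range(repeat):
        binding.synchronize_inputs()
        sess.run_with_iobinding(binding)
        binding.synchronize_outputs()
    print(f"average latency: {(time.time() - start) / repeat * 1000:.3f} ms")

    result = y_ort.numpy()  # copy the output back to the host only when needed

Because the host-to-device copies happen once before the loop and the bound buffers live on the GPU, the reported latency reflects session execution plus the explicit synchronize calls rather than per-run input/output transfers.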