From 41d86e9dca0bb36f6236aa764004d57fc347da1d Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 19 Jun 2024 22:13:32 +0000 Subject: [PATCH 01/15] added iobinding to moe mixtral benchmarking kernel --- .../transformers/test_parity_mixtral_moe.py | 106 ++++++++++++++++-- 1 file changed, 94 insertions(+), 12 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 00704626028a0..2fd3b9042bf1f 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -10,6 +10,7 @@ # license information. # -------------------------------------------------------------------------- import unittest +import time from collections import OrderedDict import numpy @@ -23,10 +24,14 @@ torch.manual_seed(42) numpy.random.seed(42) -ORT_DTYPE = TensorProto.FLOAT +ORT_DTYPE = TensorProto.FLOAT16 NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 THRESHOLD = 3e-2 +print(f"NP_TYPE_type: {NP_TYPE}") + + + def value_string_of(numpy_array): arr = numpy_array.flatten() @@ -69,12 +74,18 @@ def create_moe_onnx_graph( ), ] + print(f"Test 11") fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + print(f"torch_type: {torch_type}") + print(f"fc1_experts_weights dtype before conversion: {fc1_experts_weights.dtype}") + print(f"fc2_experts_weights dtype before conversion: {fc2_experts_weights.dtype}") + print(f"fc3_experts_weights dtype before conversion: {fc3_experts_weights.dtype}") + initializers = [ helper.make_tensor( "fc1_experts_weights", @@ -98,11 +109,16 @@ def create_moe_onnx_graph( raw=False, ), ] + print(f"fc1_experts_weights dtype after conversion: {fc1_experts_weights.dtype}") + print(f"fc2_experts_weights dtype after conversion: {fc2_experts_weights.dtype}") + print(f"fc3_experts_weights dtype after conversion: {fc3_experts_weights.dtype}") graph_inputs = [ helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), ] + print(f"Type: {type(graph_inputs[0])}") + graph_inputs.append( helper.make_tensor_value_info( "router_probs", @@ -111,6 +127,8 @@ def create_moe_onnx_graph( ) ) + print(f"Type: {type(graph_inputs[0])}") + graph_outputs = [ helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), ] @@ -123,6 +141,8 @@ def create_moe_onnx_graph( initializers, ) + print(f"Test 12") + model = helper.make_model(graph) return model.SerializeToString() @@ -212,14 +232,15 @@ class MixtralSparseMoeBlock(nn.Module): def __init__(self, config, batch_size, sequence_length): super().__init__() + print(f"Test 6") self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok - + print(f"Test 9") # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - + print(f"Test 10") self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) w1_list = [] @@ -229,11 +250,11 @@ def __init__(self, config, batch_size, sequence_length): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) - + print(f"Test 10") self.moe_experts_weight1 = torch.stack(w1_list, dim=0) self.moe_experts_weight2 
= torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - + print(f"Test 10") self.batch_size = batch_size self.sequence_length = sequence_length self.moe_onnx_graph = create_moe_onnx_graph( @@ -246,23 +267,77 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight3, self.top_k, ) + + + print(f"Test 10") self.ort_sess = self.create_ort_session() + def create_ort_session(self): + print(f"Test 8") from onnxruntime import InferenceSession, SessionOptions + print(f"Test 13") sess_options = SessionOptions() cuda_providers = ["CUDAExecutionProvider"] + print(f"Available providers: {onnxruntime.get_available_providers()}") + print(f"Test 14") if cuda_providers[0] not in onnxruntime.get_available_providers(): return None - + print(f"Test 15") sess_options.log_severity_level = 2 ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + print(f"Test 16") return ort_session + def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + print(f"Test 7") + iobinding = self.ort_sess.io_binding() + device_id = torch.cuda.current_device() + + iobinding.bind_input( + name="input", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), + ) + + print(f"fc1_experts_weights dtype after conversion: {NP_TYPE.dtype}") + iobinding.bind_input( + name="router_probs", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["router_probs"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + ort_inputs["router_probs"], "cuda", device_id + ).data_ptr(), + ) + + iobinding.bind_output( + name="output", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + numpy.zeros(ort_inputs["input"].shape), "cuda", device_id + ).data_ptr(), + ) + + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -305,12 +380,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states # , router_logits - def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - + print(f"Test: 17") ort_inputs = { "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), @@ -318,8 +393,11 @@ def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ort_output = None if self.ort_sess is not None: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits + if not iobinding: + ort_output = self.ort_sess.run(None, 
ort_inputs) + else: + self.ort_run_with_iobinding(ort_inputs) + return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits # print_tensor("input", ort_inputs["input"]) # print_tensor("router_probs", ort_inputs["router_probs"]) @@ -333,7 +411,7 @@ def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def parity_check(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) torch_output = self.forward(hidden_state) - ort_output = self.ort_forward(hidden_state) + ort_output = self.ort_forward(hidden_state, iobinding=True) if ort_output is not None: assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) print( @@ -352,9 +430,13 @@ def test_mixtral_moe_parity(self): for batch_size in [1, 16]: for sequence_length in [128, 1024]: # use a small sizes to speed up the test - config = MixtralConfig(hidden_size=256, intermediate_size=1024) + print(f"Test 1") + config = MixtralConfig() + print(f"Test 2") mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + print(f"Test 3") mixtral_moe.parity_check() + print(f"Test 4") if __name__ == "__main__": From a7b5bf8ae5e6f3fd4458c32d698f328516a1cb7b Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 19 Jun 2024 19:41:49 -0700 Subject: [PATCH 02/15] Update test_parity_mixtral_moe.py --- .../transformers/test_parity_mixtral_moe.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 2fd3b9042bf1f..d3f19e84c0616 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -91,22 +91,22 @@ def create_moe_onnx_graph( "fc1_experts_weights", ORT_DTYPE, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), - raw=False, + fc1_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, ), helper.make_tensor( "fc2_experts_weights", ORT_DTYPE, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), - raw=False, + fc2_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, ), helper.make_tensor( "fc3_experts_weights", ORT_DTYPE, fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), - raw=False, + fc3_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, ), ] print(f"fc1_experts_weights dtype after conversion: {fc1_experts_weights.dtype}") @@ -144,7 +144,12 @@ def create_moe_onnx_graph( print(f"Test 12") model = helper.make_model(graph) - return model.SerializeToString() + + import onnx + model_path = "mixtral_moe.onnx" + onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) + + return model_path class ClassInstantier(OrderedDict): From 1e68394f6f40720dd2fdde053465c1e78a11d769 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 24 Jun 2024 19:48:22 +0000 Subject: [PATCH 03/15] Implemented iobinding for mixtral benchmarking --- .../transformers/test_parity_mixtral_moe.py | 53 ++++--------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index d3f19e84c0616..ae512225681e6 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ 
b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -20,6 +20,7 @@ from torch import nn import onnxruntime +import onnx torch.manual_seed(42) numpy.random.seed(42) @@ -28,8 +29,6 @@ NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 THRESHOLD = 3e-2 -print(f"NP_TYPE_type: {NP_TYPE}") - @@ -74,17 +73,12 @@ def create_moe_onnx_graph( ), ] - print(f"Test 11") + fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - print(f"torch_type: {torch_type}") - - print(f"fc1_experts_weights dtype before conversion: {fc1_experts_weights.dtype}") - print(f"fc2_experts_weights dtype before conversion: {fc2_experts_weights.dtype}") - print(f"fc3_experts_weights dtype before conversion: {fc3_experts_weights.dtype}") initializers = [ helper.make_tensor( @@ -109,15 +103,11 @@ def create_moe_onnx_graph( raw=True, ), ] - print(f"fc1_experts_weights dtype after conversion: {fc1_experts_weights.dtype}") - print(f"fc2_experts_weights dtype after conversion: {fc2_experts_weights.dtype}") - print(f"fc3_experts_weights dtype after conversion: {fc3_experts_weights.dtype}") graph_inputs = [ helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), ] - print(f"Type: {type(graph_inputs[0])}") graph_inputs.append( helper.make_tensor_value_info( @@ -127,7 +117,6 @@ def create_moe_onnx_graph( ) ) - print(f"Type: {type(graph_inputs[0])}") graph_outputs = [ helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), @@ -141,11 +130,8 @@ def create_moe_onnx_graph( initializers, ) - print(f"Test 12") - model = helper.make_model(graph) - import onnx model_path = "mixtral_moe.onnx" onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) @@ -237,15 +223,12 @@ class MixtralSparseMoeBlock(nn.Module): def __init__(self, config, batch_size, sequence_length): super().__init__() - print(f"Test 6") self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok - print(f"Test 9") # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - print(f"Test 10") self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) w1_list = [] @@ -255,11 +238,9 @@ def __init__(self, config, batch_size, sequence_length): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) - print(f"Test 10") self.moe_experts_weight1 = torch.stack(w1_list, dim=0) self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - print(f"Test 10") self.batch_size = batch_size self.sequence_length = sequence_length self.moe_onnx_graph = create_moe_onnx_graph( @@ -274,32 +255,24 @@ def __init__(self, config, batch_size, sequence_length): ) - print(f"Test 10") self.ort_sess = self.create_ort_session() def create_ort_session(self): - print(f"Test 8") from onnxruntime import InferenceSession, SessionOptions - print(f"Test 13") sess_options = SessionOptions() cuda_providers = ["CUDAExecutionProvider"] - print(f"Available providers: {onnxruntime.get_available_providers()}") - print(f"Test 14") if cuda_providers[0] not in onnxruntime.get_available_providers(): return None - print(f"Test 15") 
sess_options.log_severity_level = 2 ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - print(f"Test 16") return ort_session def ort_run_with_iobinding(self, ort_inputs, repeat=1000): - print(f"Test 7") iobinding = self.ort_sess.io_binding() device_id = torch.cuda.current_device() @@ -312,7 +285,6 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), ) - print(f"fc1_experts_weights dtype after conversion: {NP_TYPE.dtype}") iobinding.bind_input( name="router_probs", device_type="cuda", @@ -349,6 +321,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) + num_tokens = hidden_states.shape[0] routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) @@ -390,7 +363,6 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - print(f"Test: 17") ort_inputs = { "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), @@ -401,8 +373,9 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten if not iobinding: ort_output = self.ort_sess.run(None, ort_inputs) else: - self.ort_run_with_iobinding(ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits + for x in range(5): + self.ort_run_with_iobinding(ort_inputs) + return None # print_tensor("input", ort_inputs["input"]) # print_tensor("router_probs", ort_inputs["router_probs"]) @@ -411,21 +384,21 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) # print_tensor("output", ort_output[0]) - return None + return ort_output def parity_check(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) torch_output = self.forward(hidden_state) - ort_output = self.ort_forward(hidden_state, iobinding=True) - if ort_output is not None: - assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) + ort_out = self.ort_forward(hidden_state, iobinding=True) + if ort_out is not None: + assert torch.allclose(torch_output, ort_out, rtol=1e-04, atol=1e-04) print( "batch_size:", self.batch_size, " sequence_length:", self.sequence_length, " max_diff:", - (torch_output - ort_output).abs().max(), + (torch_output - ort_out).abs().max(), " parity: OK", ) @@ -435,13 +408,9 @@ def test_mixtral_moe_parity(self): for batch_size in [1, 16]: for sequence_length in [128, 1024]: # use a small sizes to speed up the test - print(f"Test 1") config = MixtralConfig() - print(f"Test 2") mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) - print(f"Test 3") mixtral_moe.parity_check() - print(f"Test 4") if __name__ == "__main__": From 78b96169a0e970379f68a3dbf238873b55149218 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 24 Jun 2024 20:06:24 +0000 Subject: [PATCH 04/15] deleted an unnecessary line of code --- 
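For reference, a minimal way to exercise the IO-binding path introduced in this
series (a sketch only; it assumes a CUDA-enabled onnxruntime build and runs
inside this test module so that MixtralConfig and MixtralSparseMoeBlock are
available):

    import torch

    config = MixtralConfig(hidden_size=256, intermediate_size=1024)
    block = MixtralSparseMoeBlock(config, batch_size=1, sequence_length=128)
    hidden = torch.randn(1, 128, config.hidden_size)
    # With iobinding=True the inputs and output are bound to pre-allocated CUDA
    # OrtValues, the session is run repeatedly, and the average latency is
    # printed as "MoE cuda kernel time: ... ms" instead of returning a tensor.
    block.ort_forward(hidden, iobinding=True)
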
onnxruntime/test/python/transformers/test_parity_mixtral_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index ae512225681e6..61eb6f4e69a98 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -321,7 +321,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - num_tokens = hidden_states.shape[0] routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) From c0b78c2b625065590dde9453ae399a968917a786 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 24 Jun 2024 20:10:19 +0000 Subject: [PATCH 05/15] deleted unnecessary code --- .../test/python/transformers/test_parity_mixtral_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 61eb6f4e69a98..5bbdfeb81338c 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -372,8 +372,7 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten if not iobinding: ort_output = self.ort_sess.run(None, ort_inputs) else: - for x in range(5): - self.ort_run_with_iobinding(ort_inputs) + self.ort_run_with_iobinding(ort_inputs) return None # print_tensor("input", ort_inputs["input"]) From fcf07cfe5060644a7910031a1136e1ed24886011 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 24 Jun 2024 20:19:31 +0000 Subject: [PATCH 06/15] Deleted unnecessary code --- .../python/transformers/test_parity_mixtral_moe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 5bbdfeb81338c..a22ad91c54ae5 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -382,21 +382,21 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) # print_tensor("output", ort_output[0]) - return ort_output + return None def parity_check(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) torch_output = self.forward(hidden_state) - ort_out = self.ort_forward(hidden_state, iobinding=True) - if ort_out is not None: - assert torch.allclose(torch_output, ort_out, rtol=1e-04, atol=1e-04) + ort_output = self.ort_forward(hidden_state, iobinding=True) + if ort_output is not None: + assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) print( "batch_size:", self.batch_size, " sequence_length:", self.sequence_length, " max_diff:", - (torch_output - ort_out).abs().max(), + (torch_output - ort_output).abs().max(), " parity: OK", ) From aab9242d95b8c5dee2834ad43067c4a647b8a56a Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 25 Jun 2024 17:14:51 +0000 Subject: [PATCH 07/15] deleted white spaces in blank lines --- 
.../transformers/test_parity_mixtral_moe.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index a22ad91c54ae5..b68cc79c8185d 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -9,18 +9,18 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -import unittest import time +import unittest from collections import OrderedDict import numpy +import onnx import torch import torch.nn.functional as F from onnx import TensorProto, helper from torch import nn import onnxruntime -import onnx torch.manual_seed(42) numpy.random.seed(42) @@ -30,8 +30,6 @@ THRESHOLD = 3e-2 - - def value_string_of(numpy_array): arr = numpy_array.flatten() lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] @@ -73,13 +71,12 @@ def create_moe_onnx_graph( ), ] - fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - + initializers = [ helper.make_tensor( "fc1_experts_weights", @@ -108,7 +105,6 @@ def create_moe_onnx_graph( helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), ] - graph_inputs.append( helper.make_tensor_value_info( "router_probs", @@ -117,7 +113,6 @@ def create_moe_onnx_graph( ) ) - graph_outputs = [ helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), ] @@ -131,10 +126,10 @@ def create_moe_onnx_graph( ) model = helper.make_model(graph) - + model_path = "mixtral_moe.onnx" onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) - + return model_path @@ -253,12 +248,8 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight3, self.top_k, ) - - - self.ort_sess = self.create_ort_session() - def create_ort_session(self): from onnxruntime import InferenceSession, SessionOptions @@ -382,7 +373,7 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) # print_tensor("output", ort_output[0]) - return None + return ort_output def parity_check(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) From 7f7472bbc1b28fcf273fd125a1f5578bbff4529e Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 25 Jun 2024 21:04:25 +0000 Subject: [PATCH 08/15] added two tests: one for large cases and one for benchmarking cases deleted the moe onnx model once it is done being used --- .../transformers/test_parity_mixtral_moe.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index b68cc79c8185d..bde9bf5219a8a 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -11,6 +11,7 @@ # -------------------------------------------------------------------------- import time import unittest +import pytest from collections import 
OrderedDict import numpy @@ -126,11 +127,12 @@ def create_moe_onnx_graph( ) model = helper.make_model(graph) + return model.SerializeToString() - model_path = "mixtral_moe.onnx" - onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) + #model_path = "mixtral_moe.onnx" + #onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) - return model_path + #return model_path class ClassInstantier(OrderedDict): @@ -391,16 +393,36 @@ def parity_check(self): " parity: OK", ) + def benchmark(self): + self.ort_forward(iobinding=True) + class TestMixtralMoE(unittest.TestCase): def test_mixtral_moe_parity(self): - for batch_size in [1, 16]: - for sequence_length in [128, 1024]: + for batch_size in [1, 2]: + for sequence_length in [8, 16]: # use a small sizes to speed up the test + config = MixtralConfig(hidden_size=256, intermediate_size=1024) + mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.parity_check() + + @pytest.mark.slow + def test_mixtral_moe_large(self): + for batch_size in [1, 8]: + for sequence_length in [16, 64]: config = MixtralConfig() mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) mixtral_moe.parity_check() + @pytest.mark.slow + def test_mixtral_moe_benchmark(self): + for batch_size in [32, 64]: + for sequence_length in [128, 1024]: + config = MixtralConfig() + mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.benchmark() + + if __name__ == "__main__": unittest.main() From ad45401f270dc79a581e5f3622b108145b2f76b9 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 26 Jun 2024 22:34:17 +0000 Subject: [PATCH 09/15] Created a new function to save the model and to to recognize the external data --- .../transformers/test_parity_mixtral_moe.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index bde9bf5219a8a..00a928f4faf85 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -9,13 +9,14 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +import os import time import unittest -import pytest from collections import OrderedDict import numpy import onnx +import pytest import torch import torch.nn.functional as F from onnx import TensorProto, helper @@ -41,6 +42,15 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") +def create_onnx_graph(model, model_path): + external_data_path = "mixtral_moe.onnx" + ".data" + onnx.save_model( + model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path + ) + + return model_path + + def create_moe_onnx_graph( num_rows, num_experts, @@ -127,12 +137,11 @@ def create_moe_onnx_graph( ) model = helper.make_model(graph) - return model.SerializeToString() + model_path = "mixtral_moe.onnx" - #model_path = "mixtral_moe.onnx" - #onnx.save_model(model, model_path, save_as_external_data=True, all_tensors_to_one_file=True) + save_model = create_onnx_graph(model, model_path) - #return model_path + return save_model class ClassInstantier(OrderedDict): @@ -250,6 +259,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight3, self.top_k, ) + self.ort_sess = self.create_ort_session() def create_ort_session(self): @@ -394,10 +404,12 @@ def parity_check(self): ) def benchmark(self): - self.ort_forward(iobinding=True) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + self.ort_forward(hidden_state, iobinding=True) class TestMixtralMoE(unittest.TestCase): + def test_mixtral_moe_parity(self): for batch_size in [1, 2]: for sequence_length in [8, 16]: @@ -422,6 +434,9 @@ def test_mixtral_moe_benchmark(self): mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) mixtral_moe.benchmark() + os.remove("mixtral_moe.onnx") + external_data_path = "mixtral_moe.onnx" + ".data" + os.remove(external_data_path) if __name__ == "__main__": From 2de42ed179774f4e161438dd7b64165caabf007d Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 27 Jun 2024 21:30:36 +0000 Subject: [PATCH 10/15] Renamed the new function and added a new one to delete the model and the model data --- .../python/transformers/test_parity_mixtral_moe.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 00a928f4faf85..848607ae45771 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -42,13 +42,16 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") -def create_onnx_graph(model, model_path): +def save_model_to_disk(model, model_path): external_data_path = "mixtral_moe.onnx" + ".data" onnx.save_model( model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path ) - return model_path + +def delete_model_data(external_data): + os.remove("mixtral_moe.onnx") + os.remove(external_data) def create_moe_onnx_graph( @@ -139,9 +142,9 @@ def create_moe_onnx_graph( model = helper.make_model(graph) model_path = "mixtral_moe.onnx" - save_model = create_onnx_graph(model, model_path) + save_model_to_disk(model, model_path) - return save_model + return model_path class ClassInstantier(OrderedDict): @@ -434,9 +437,8 @@ def test_mixtral_moe_benchmark(self): mixtral_moe 
= MixtralSparseMoeBlock(config, batch_size, sequence_length) mixtral_moe.benchmark() - os.remove("mixtral_moe.onnx") external_data_path = "mixtral_moe.onnx" + ".data" - os.remove(external_data_path) + delete_model_data(external_data_path) if __name__ == "__main__": From 3183dde3bef9e478998a9ea77df422ae70932743 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 28 Jun 2024 21:53:09 +0000 Subject: [PATCH 11/15] Changed "mixtral_moe.onn" to model_path --- onnxruntime/test/python/transformers/test_parity_mixtral_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py index 848607ae45771..3393e7d3be3c7 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -43,7 +43,7 @@ def print_tensor(name, numpy_array): def save_model_to_disk(model, model_path): - external_data_path = "mixtral_moe.onnx" + ".data" + external_data_path = model_path + ".data" onnx.save_model( model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path ) From 67827a0045168e55281cb2159be49c217264bfcd Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 2 Jul 2024 20:49:30 +0000 Subject: [PATCH 12/15] DBRX parity script --- .../transformers/test_parity_dbrx_moe.py | 461 ++++++++++++++++++ 1 file changed, 461 insertions(+) create mode 100644 onnxruntime/test/python/transformers/test_parity_dbrx_moe.py diff --git a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py new file mode 100644 index 0000000000000..82b801cd4940f --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py @@ -0,0 +1,461 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import unittest +import time +from collections import OrderedDict + +import numpy +import torch +import torch.nn.functional as F +from onnx import TensorProto, helper +from torch import nn +from typing import Tuple + +import onnxruntime +import onnx + +torch.manual_seed(42) +numpy.random.seed(42) + +ORT_DTYPE = TensorProto.BFLOAT16 +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.BFLOAT16 else numpy.float32 +THRESHOLD = 3e-2 + + + + +def value_string_of(numpy_array): + arr = numpy_array.flatten() + lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + + +def print_tensor(name, numpy_array): + print(f"const std::vector {name} = {value_string_of(numpy_array)};") + + +def create_moe_onnx_graph( + num_rows, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + topk, +): + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ], + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="silu", + domain="com.microsoft", + ), + ] + + + fc1_shape = [num_experts, num_experts * inter_size, hidden_size] + fc2_shape = [num_experts, num_experts * inter_size, hidden_size] + fc3_shape = [num_experts, num_experts * inter_size, hidden_size] + + + torch_type = torch.bfloat16 if ORT_DTYPE == TensorProto.BFLOAT16 else torch.float32 + + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + fc1_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, + ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE, + fc3_shape, + fc3_experts_weights.to(torch_type).detach().numpy().tobytes(), + raw=True, + ), + ] + + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +class DBRXConfig: + def __init__( + self, + hidden_size=6144, + intermediate_size=10752, + num_hidden_layers=40, + num_attention_heads=48, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + rope_theta=5e5, + attention_dropout=0.0, + num_experts_per_tok=4, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads 
+ self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + +class DbrxExpertGLU(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig): + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = config.intermediate_size + self.moe_num_experts = config.num_local_experts + ffn_act_fn = {"name": config.hidden_act} + + self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + + act_fn_name = ffn_act_fn.get("name", "silu") + self.activation_fn = ACT2FN[act_fn_name] + + def forward( + self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor + ) -> torch.Tensor: + gate_proj = x.matmul(expert_w1.t()) + up_proj = x.matmul(expert_v1.t()) + gate_proj = self.activation_fn(gate_proj) + intermediate_states = gate_proj * up_proj + down_proj = intermediate_states.matmul(expert_w2) + return down_proj + + +class DbrxExperts(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig): + super().__init__() + self.moe_num_experts = config.num_local_experts + self.mlp = DbrxExpertGLU( + hidden_size=hidden_size, + ffn_hidden_size=config.intermediate_size, + moe_num_experts=moe_num_experts, + ffn_act_fn=config.hidden_act, + ) + + def forward( + self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor + ) -> torch.Tensor: + bsz, q_len, hidden_size = x.shape + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) + + expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + # Chunk experts at once to avoid storing full parameter multiple times in autograd + w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked] + v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked] + w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked] + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + + token_list = token_idx + topk_list = topk_idx + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + expert_out = ( + self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) + * top_weights[token_list, topk_list, None] + ) + + out.index_add_(0, token_idx, expert_out) + + out = out.reshape(bsz, q_len, hidden_size) + return out + + +class DbrxRouter(nn.Module): + def 
__init__( + self, + hidden_size: int, + config: DBRXConfig, + moe_num_experts: int, + moe_top_k: int, + batch_size: int, + sequence_length: int, + ffn_hidden_size: int, + ffn_act_fn: dict + ): + super().__init__() + self.hidden_size = hidden_size + self.moe_num_experts = config.num_local_experts + self.moe_top_k = config.num_experts_per_tok + self.ffn_hidden_size = config.intermediate_size + self.ffn_act_fn = {"name", config.hidden_act} + + self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) + self.experts = nn.ModuleList([DbrxExpertGLU(hidden_size, ffn_hidden_size, moe_num_experts, ffn_act_fn, config) for _ in range(self.moe_num_experts)]) + + + w1_list = [] + v1_list = [] + w2_list = [] + for i in range(self.moe_num_experts): + w1_list.append(self.experts[i].w1) + v1_list.append(self.experts[i].v1) + w2_list.append(self.experts[i].w2) + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(v1_list, dim=0) + self.moe_experts_weight3 = torch.stack(w2_list, dim=0) + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_moe_onnx_graph( + self.batch_size * self.sequence_length, + self.moe_num_experts, + self.hidden_size, + self.ffn_hidden_size, + self.moe_experts_weight1, + self.moe_experts_weight2, + self.moe_experts_weight3, + self.moe_top_k + ) + + self.ort_sess = self.create_ort_session() + + + def create_ort_session(self): + from onnxruntime import InferenceSession, SessionOptions + + sess_options = SessionOptions() + + cuda_providers = ["CUDAExecutionProvider"] + if cuda_providers[0] not in onnxruntime.get_available_providers(): + return None + sess_options.log_severity_level = 2 + ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + + return ort_session + + def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + iobinding = self.ort_sess.io_binding() + device_id = torch.cuda.current_device() + + iobinding.bind_input( + name="input", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), + ) + + iobinding.bind_input( + name="router_probs", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["router_probs"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + ort_inputs["router_probs"], "cuda", device_id + ).data_ptr(), + ) + + iobinding.bind_output( + name="output", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + numpy.zeros(ort_inputs["input"].shape), "cuda", device_id + ).data_ptr(), + ) + + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + + top_weights_scale = ( + torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) + if self.moe_normalize_expert_weights is not None + else 1.0 + 
) + top_weights = top_weights / top_weights_scale + + weights = weights.to(hidden_states.dtype) + top_weights = top_weights.to(hidden_states.dtype) + return weights, top_weights, top_experts + + def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + ort_inputs = { + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + } + + ort_output = None + if self.ort_sess is not None: + if not iobinding: + ort_output = self.ort_sess.run(None, ort_inputs) + else: + self.ort_run_with_iobinding(ort_inputs) + return None + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) + # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) + # print_tensor("output", ort_output[0]) + + return ort_output + + def parity_check(self): + experts = DbrxExperts() + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + torch_output = self.forward(hidden_state) + final_torch_output = experts.forward(torch_output) + ort_output = self.ort_forward(hidden_state, iobinding=True) + if ort_output is not None: + assert torch.allclose(final_torch_output, ort_output, rtol=1e-04, atol=1e-04) + print( + "batch_size:", + self.batch_size, + " sequence_length:", + self.sequence_length, + " max_diff:", + (torch_output - ort_output).abs().max(), + " parity: OK", + ) + + +class TestDBRXMoE(unittest.TestCase): + def test_dbrx_moe_parity(self): + for batch_size in [1, 16]: + for sequence_length in [128, 1024]: + # use a small sizes to speed up the test + config = DBRXConfig() + hidden_size = config.hidden_size + moe_num_experts = config.num_local_experts + moe_top_k = config.num_experts_per_tok + ffn_hidden_size = config.intermediate_size + ffn_act_fn = {"name", config.hidden_act} + dbrx_moe = DbrxRouter(hidden_size, + config, + moe_num_experts, + moe_top_k, + batch_size, + sequence_length, + ffn_hidden_size, + ffn_act_fn,) + dbrx_moe.parity_check() + + +if __name__ == "__main__": + unittest.main() From 216dcd7c2d74d8fb422b7a02e91490e9a9b792a5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 10 Jul 2024 01:08:27 +0000 Subject: [PATCH 13/15] sharding implementation --- .../transformers/test_parity_dbrx_moe.py | 396 +++++++++++++----- 1 file changed, 297 insertions(+), 99 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py index 82b801cd4940f..54c93ac3f24ec 100644 --- a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py @@ -16,22 +16,32 @@ import numpy import torch import torch.nn.functional as F +from mpi4py import MPI from onnx import TensorProto, helper from torch import nn from typing import Tuple + import onnxruntime import onnx +comm = MPI.COMM_WORLD + torch.manual_seed(42) numpy.random.seed(42) -ORT_DTYPE = TensorProto.BFLOAT16 -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.BFLOAT16 else numpy.float32 -THRESHOLD = 3e-2 
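For orientation, the tensor-parallel sharding used later in this patch keeps one
contiguous slice of the intermediate (FFN) dimension on each rank: fc1/fc3 are
sliced along their intermediate (output) dimension and fc2 along its
intermediate (input) dimension. A minimal sketch of that slicing (the helper
name and axis argument are illustrative, not the get_fc*_tensor_shards
functions added below); the test is presumably launched with one process per
shard, e.g. via mpirun -n <world_size>:

    import numpy

    def slice_for_rank(weights, axis, rank, world_size):
        # Keep this rank's contiguous chunk of `weights` along `axis`;
        # assumes the sliced dimension is divisible by world_size.
        step = weights.shape[axis] // world_size
        index = [slice(None)] * weights.ndim
        index[axis] = slice(rank * step, (rank + 1) * step)
        return weights[tuple(index)]
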
+def get_rank(): + return comm.Get_rank() + +def get_size(): + return comm.Get_size() +def print_out(*args): + if get_rank() == 0: + print(*args) + def value_string_of(numpy_array): arr = numpy_array.flatten() @@ -42,6 +52,13 @@ def value_string_of(numpy_array): def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") +ORT_DTYPE = TensorProto.FLOAT16 +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 +THRESHOLD = 3e-2 +THRESHOLD_EP = 1e-6 + +local_rank = get_rank() + def create_moe_onnx_graph( num_rows, @@ -74,13 +91,21 @@ def create_moe_onnx_graph( ), ] + print("fc1_experts_weights shape:", fc1_experts_weights.shape) + print("fc2_experts_weights shape:", fc2_experts_weights.shape) + print("fc3_experts_weights shape:", fc3_experts_weights.shape) + fc1_shape = [num_experts, num_experts * inter_size, hidden_size] fc2_shape = [num_experts, num_experts * inter_size, hidden_size] fc3_shape = [num_experts, num_experts * inter_size, hidden_size] + print("Expected fc1_shape:", fc1_shape) + print("Expected fc2_shape:", fc2_shape) + print("Expected fc3_shape:", fc3_shape) - torch_type = torch.bfloat16 if ORT_DTYPE == TensorProto.BFLOAT16 else torch.float32 + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 initializers = [ @@ -156,7 +181,7 @@ class DBRXConfig: def __init__( self, hidden_size=6144, - intermediate_size=10752, + intermediate_size=1500, num_hidden_layers=40, num_attention_heads=48, num_key_value_heads=8, @@ -198,9 +223,9 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, self.moe_num_experts = config.num_local_experts ffn_act_fn = {"name": config.hidden_act} - self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) - self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) - self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) + self.v1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) + self.w2 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) act_fn_name = ffn_act_fn.get("name", "silu") self.activation_fn = ACT2FN[act_fn_name] @@ -217,87 +242,28 @@ def forward( class DbrxExperts(nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig): + def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, batch_size: int, sequence_length: int, config: DBRXConfig): super().__init__() self.moe_num_experts = config.num_local_experts - self.mlp = DbrxExpertGLU( - hidden_size=hidden_size, - ffn_hidden_size=config.intermediate_size, - moe_num_experts=moe_num_experts, - ffn_act_fn=config.hidden_act, - ) - - def forward( - self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor - ) -> torch.Tensor: - bsz, q_len, hidden_size = x.shape - x = x.view(-1, hidden_size) - out = torch.zeros_like(x) - - expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) - # Chunk experts at once to avoid storing full parameter multiple times in autograd - w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( - self.moe_num_experts, dim=0 - ) - 
v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( - self.moe_num_experts, dim=0 - ) - w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( - self.moe_num_experts, dim=0 - ) - w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked] - v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked] - w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked] - for expert_idx in range(0, self.moe_num_experts): - topk_idx, token_idx = torch.where(expert_mask[expert_idx]) - if token_idx.shape[0] == 0: - continue - - token_list = token_idx - topk_list = topk_idx - - expert_tokens = x[None, token_list].reshape(-1, hidden_size) - expert_out = ( - self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) - * top_weights[token_list, topk_list, None] - ) - - out.index_add_(0, token_idx, expert_out) - - out = out.reshape(bsz, q_len, hidden_size) - return out - - -class DbrxRouter(nn.Module): - def __init__( - self, - hidden_size: int, - config: DBRXConfig, - moe_num_experts: int, - moe_top_k: int, - batch_size: int, - sequence_length: int, - ffn_hidden_size: int, - ffn_act_fn: dict - ): - super().__init__() + self.config = DBRXConfig() self.hidden_size = hidden_size - self.moe_num_experts = config.num_local_experts - self.moe_top_k = config.num_experts_per_tok self.ffn_hidden_size = config.intermediate_size - self.ffn_act_fn = {"name", config.hidden_act} - - self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) - self.experts = nn.ModuleList([DbrxExpertGLU(hidden_size, ffn_hidden_size, moe_num_experts, ffn_act_fn, config) for _ in range(self.moe_num_experts)]) - + self.moe_top_k = config.num_experts_per_tok + self.mlp = DbrxExpertGLU( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + ffn_act_fn=config.hidden_act, + config=config + ) w1_list = [] v1_list = [] w2_list = [] for i in range(self.moe_num_experts): - w1_list.append(self.experts[i].w1) - v1_list.append(self.experts[i].v1) - w2_list.append(self.experts[i].w2) + w1_list.append(self.mlp.w1[i]) + v1_list.append(self.mlp.v1[i]) + w2_list.append(self.mlp.w2[i]) self.moe_experts_weight1 = torch.stack(w1_list, dim=0) self.moe_experts_weight2 = torch.stack(v1_list, dim=0) self.moe_experts_weight3 = torch.stack(w2_list, dim=0) @@ -317,6 +283,146 @@ def __init__( self.ort_sess = self.create_ort_session() + def test_moe_with_tensor_parallelism( + self, + hidden_size, + inter_size, + num_experts, + num_rows, + threshold=THRESHOLD, + ): + assert inter_size % get_size() == 0 + + ( + onnx_model_full, + fc1_experts_weights_all, + fc2_experts_weights_all, + fc3_experts_weights_all, + ) = self.generate_weights_and_initial_model( + num_rows, + num_experts, + hidden_size, + inter_size, + ) + + def get_fc1_tensor_shards(expert_weights): + return ( + expert_weights.reshape(-1, inter_size, hidden_size) + .transpose(0, 2, 1)[ + :, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size() + ] + .transpose(0, 2, 1) + ) + + def get_fc2_tensor_shards(expert_weights): + return ( + expert_weights.reshape(-1, hidden_size, inter_size) + .transpose(0, 2, 1)[ + :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size(), : + ] + .transpose(0, 2, 1) + ) + + fc1_experts_weights = get_fc1_tensor_shards(fc1_experts_weights_all) + fc2_experts_weights = get_fc2_tensor_shards(fc2_experts_weights_all) + fc3_experts_weights = 
get_fc1_tensor_shards(fc3_experts_weights_all) + + onnx_model_local = create_moe_onnx_graph( + num_rows, + num_experts, + num_experts, + hidden_size, + inter_size // get_size(), + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + tensor_shards=get_size(), + ) + + self.run_ort_with_parity_check( + onnx_model_full, + onnx_model_local, + num_rows, + hidden_size, + num_experts, + inter_size, + threshold, + ) + + def run_ort_with_parity_check( + self, + onnx_model_full, + onnx_model_local, + num_rows, + hidden_size, + num_experts, + inter_size, + threshold, + ): + sess_options = onnxruntime.SessionOptions() + cuda_provider_options = {"device_id": local_rank} + execution_providers = [("CUDAExecutionProvider", cuda_provider_options)] + + ort_session = onnxruntime.InferenceSession(onnx_model_full, sess_options, providers=execution_providers) + ort_session_local = onnxruntime.InferenceSession(onnx_model_local, sess_options, providers=execution_providers) + + ort_inputs = { + ort_session.get_inputs()[0].name: numpy.random.rand(num_rows, hidden_size).astype(NP_TYPE), + ort_session.get_inputs()[1].name: numpy.random.rand(num_rows, num_experts).astype(NP_TYPE), + } + + output = ort_session.run(None, ort_inputs) + sharded_output = ort_session_local.run(None, ort_inputs) + + print_out("max diff:", numpy.max(numpy.abs(output[0] - sharded_output[0]))) + assert numpy.allclose(output[0], sharded_output[0], atol=threshold, rtol=threshold) + + print_out( + "hidden_size:", + hidden_size, + " inter_size:", + inter_size, + " num_experts:", + num_experts, + " num_rows:", + num_rows, + " world_size:", + get_size(), + " Parity: OK", + ) + + + def generate_weights_and_initial_model( + self, + num_rows, + num_experts, + hidden_size, + inter_size, + ): + #s = 0.1 + fc1_experts_weights_all = self.moe_experts_weight1 + fc2_experts_weights_all = self.moe_experts_weight2 + fc3_experts_weights_all = self.moe_experts_weight3 + + onnx_model_full = create_moe_onnx_graph( + num_rows, + num_experts, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights_all, + fc2_experts_weights_all, + fc3_experts_weights_all, + ) + + return ( + onnx_model_full, + fc1_experts_weights_all, + fc2_experts_weights_all, + fc3_experts_weights_all, + ) + + def create_ort_session(self): from onnxruntime import InferenceSession, SessionOptions @@ -373,21 +479,46 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): e = time.time() print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) - top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + def forward( + self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor + ) -> torch.Tensor: + bsz, q_len, hidden_size = x.shape + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) - top_weights_scale = ( - torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) - if self.moe_normalize_expert_weights is not None - else 1.0 + expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + # Chunk experts at once to avoid storing full parameter multiple times in autograd + w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + 
self.moe_num_experts, dim=0 ) - top_weights = top_weights / top_weights_scale + v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( + self.moe_num_experts, dim=0 + ) + w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked] + v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked] + w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked] + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + + token_list = token_idx + topk_list = topk_idx + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + expert_out = ( + self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) + * top_weights[token_list, topk_list, None] + ) + + out.index_add_(0, token_idx, expert_out) + + out = out.reshape(bsz, q_len, hidden_size) + return out - weights = weights.to(hidden_states.dtype) - top_weights = top_weights.to(hidden_states.dtype) - return weights, top_weights, top_experts def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -435,6 +566,70 @@ def parity_check(self): ) + +class DbrxFFN(nn.Module): + def __init__(self, config: DBRXConfig): + super().__init__() + + self.router = DbrxRouter( + hidden_size=config.hidden_size, + moe_num_experts=config.num_local_experts, + moe_top_k=config.num_experts_per_tok, + ) + + self.experts = DbrxExperts( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + moe_num_experts=config.num_local_experts, + ffn_act_fn=config.hidden_act, + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weights, top_weights, top_experts = self.router(x) + out = self.experts(x, weights, top_weights, top_experts) + return out, weights + + + +class DbrxRouter(nn.Module): + def __init__( + self, + hidden_size: int, + config: DBRXConfig, + moe_num_experts: int, + moe_top_k: int, + batch_size: int, + sequence_length: int, + ffn_hidden_size: int, + ffn_act_fn: dict + ): + super().__init__() + self.hidden_size = hidden_size + self.moe_num_experts = config.num_local_experts + self.moe_top_k = config.num_experts_per_tok + self.ffn_hidden_size = config.intermediate_size + self.ffn_act_fn = {"name", config.hidden_act} + + self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + + top_weights_scale = ( + torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) + if self.moe_normalize_expert_weights is not None + else 1.0 + ) + top_weights = top_weights / top_weights_scale + + weights = weights.to(hidden_states.dtype) + top_weights = top_weights.to(hidden_states.dtype) + return weights, top_weights, top_experts + + + class TestDBRXMoE(unittest.TestCase): def test_dbrx_moe_parity(self): for batch_size in [1, 16]: @@ -443,18 +638,21 @@ def test_dbrx_moe_parity(self): config = DBRXConfig() hidden_size = config.hidden_size moe_num_experts = 
config.num_local_experts - moe_top_k = config.num_experts_per_tok + #moe_top_k = config.num_experts_per_tok ffn_hidden_size = config.intermediate_size ffn_act_fn = {"name", config.hidden_act} - dbrx_moe = DbrxRouter(hidden_size, - config, + dbrx_moe = DbrxExperts(hidden_size, + ffn_hidden_size, moe_num_experts, - moe_top_k, + ffn_act_fn, batch_size, sequence_length, - ffn_hidden_size, - ffn_act_fn,) - dbrx_moe.parity_check() + config) + dbrx_moe.test_moe_with_tensor_parallelism(hidden_size, + ffn_hidden_size, + moe_num_experts, + num_rows=batch_size * sequence_length, + threshold=THRESHOLD) if __name__ == "__main__": From 842001be756d451461eebdb62e2ebe95c1282b89 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 12 Jul 2024 00:16:48 +0000 Subject: [PATCH 14/15] script changes --- .../transformers/test_parity_dbrx_moe.py | 246 ++++++++++-------- 1 file changed, 137 insertions(+), 109 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py index 54c93ac3f24ec..9c73f1e7ccbd2 100644 --- a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py @@ -14,11 +14,13 @@ from collections import OrderedDict import numpy +import os import torch import torch.nn.functional as F from mpi4py import MPI from onnx import TensorProto, helper from torch import nn +import torch.nn.init as init from typing import Tuple @@ -52,6 +54,17 @@ def value_string_of(numpy_array): def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") +def save_model_to_disk(model, model_path): + external_data_path = model_path + ".data" + onnx.save_model( + model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path + ) + + +#def delete_model_data(external_data): + #os.remove("dbrx_moe.onnx") + #os.remove(external_data) + ORT_DTYPE = TensorProto.FLOAT16 NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 THRESHOLD = 3e-2 @@ -91,19 +104,13 @@ def create_moe_onnx_graph( ), ] - print("fc1_experts_weights shape:", fc1_experts_weights.shape) - print("fc2_experts_weights shape:", fc2_experts_weights.shape) - print("fc3_experts_weights shape:", fc3_experts_weights.shape) - - - fc1_shape = [num_experts, num_experts * inter_size, hidden_size] - fc2_shape = [num_experts, num_experts * inter_size, hidden_size] - fc3_shape = [num_experts, num_experts * inter_size, hidden_size] - - print("Expected fc1_shape:", fc1_shape) - print("Expected fc2_shape:", fc2_shape) - print("Expected fc3_shape:", fc3_shape) + fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752) + fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2) + fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752) + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + fc3_shape = [num_experts, hidden_size, inter_size] torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 @@ -160,7 +167,12 @@ def create_moe_onnx_graph( ) model = helper.make_model(graph) - return model.SerializeToString() + model_path = "dbrx_moe.onnx" + + save_model_to_disk(model, model_path) + + return model_path + #return model.SerializeToString() @@ -181,7 +193,7 @@ class DBRXConfig: def __init__( self, hidden_size=6144, - intermediate_size=1500, + intermediate_size=10752, num_hidden_layers=40, num_attention_heads=48, 
num_key_value_heads=8, @@ -194,7 +206,7 @@ def __init__( num_experts_per_tok=4, num_local_experts=16, output_router_logits=False, - router_aux_loss_coef=0.001, + router_aux_loss_coef=0.001 ): self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -214,6 +226,30 @@ def __init__( self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef +class DbrxRouter(nn.Module): + def __init__( + self, + hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + config: DBRXConfig, + ): + super().__init__() + self.hidden_size = hidden_size + self.moe_num_experts = config.num_local_experts + self.moe_top_k = config.num_experts_per_tok + + self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + + weights = weights.to(hidden_states.dtype) + top_weights = top_weights.to(hidden_states.dtype) + return weights, top_weights, top_experts + class DbrxExpertGLU(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig): @@ -223,9 +259,13 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, self.moe_num_experts = config.num_local_experts ffn_act_fn = {"name": config.hidden_act} - self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) - self.v1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) - self.w2 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) + self.w1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) + self.v1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) + self.w2 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) + + init.xavier_uniform_(self.w1) + init.xavier_uniform_(self.v1) + init.xavier_uniform_(self.w2) act_fn_name = ffn_act_fn.get("name", "silu") self.activation_fn = ACT2FN[act_fn_name] @@ -242,7 +282,13 @@ def forward( class DbrxExperts(nn.Module): - def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, batch_size: int, sequence_length: int, config: DBRXConfig): + def __init__(self, hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + ffn_act_fn: dict, + batch_size: int, + sequence_length: int, + config: DBRXConfig): super().__init__() self.moe_num_experts = config.num_local_experts self.config = DBRXConfig() @@ -257,16 +303,6 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, config=config ) - w1_list = [] - v1_list = [] - w2_list = [] - for i in range(self.moe_num_experts): - w1_list.append(self.mlp.w1[i]) - v1_list.append(self.mlp.v1[i]) - w2_list.append(self.mlp.w2[i]) - self.moe_experts_weight1 = torch.stack(w1_list, dim=0) - self.moe_experts_weight2 = torch.stack(v1_list, dim=0) - self.moe_experts_weight3 = torch.stack(w2_list, dim=0) self.batch_size = batch_size self.sequence_length = sequence_length self.moe_onnx_graph = create_moe_onnx_graph( @@ -274,9 +310,9 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, self.moe_num_experts, self.hidden_size, 
self.ffn_hidden_size, - self.moe_experts_weight1, - self.moe_experts_weight2, - self.moe_experts_weight3, + self.mlp.w1, + self.mlp.v1, + self.mlp.w2, self.moe_top_k ) @@ -399,10 +435,10 @@ def generate_weights_and_initial_model( hidden_size, inter_size, ): - #s = 0.1 - fc1_experts_weights_all = self.moe_experts_weight1 - fc2_experts_weights_all = self.moe_experts_weight2 - fc3_experts_weights_all = self.moe_experts_weight3 + s = 0.1 + fc1_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) + fc2_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, inter_size, hidden_size)).astype(NP_TYPE) + fc3_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) onnx_model_full = create_moe_onnx_graph( num_rows, @@ -520,41 +556,20 @@ def forward( return out - def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), - } - - ort_output = None - if self.ort_sess is not None: - if not iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - else: - self.ort_run_with_iobinding(ort_inputs) - return None - - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - - return ort_output - def parity_check(self): - experts = DbrxExperts() - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - torch_output = self.forward(hidden_state) - final_torch_output = experts.forward(torch_output) - ort_output = self.ort_forward(hidden_state, iobinding=True) + config = DBRXConfig() + ffn = DbrxFFN(config, self.batch_size, self.sequence_length) + router = DbrxRouter(hidden_size=config.hidden_size, + moe_num_experts=config.num_local_experts, + moe_top_k=config.num_local_experts, + config=DBRXConfig()) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size) + torch_output = ffn.forward(hidden_state) + print("forward: ", torch_output) + ort_output = ffn.ort_forward(hidden_state, iobinding=False) + """ if ort_output is not None: - assert torch.allclose(final_torch_output, ort_output, rtol=1e-04, atol=1e-04) + assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) print( "batch_size:", self.batch_size, @@ -564,17 +579,20 @@ def parity_check(self): (torch_output - ort_output).abs().max(), " parity: OK", ) + """ class DbrxFFN(nn.Module): - def __init__(self, config: DBRXConfig): + def __init__(self, config: DBRXConfig, batch_size, sequence_length): super().__init__() - + self.batch_size = batch_size + self.sequence_length = sequence_length self.router = DbrxRouter( hidden_size=config.hidden_size, moe_num_experts=config.num_local_experts, moe_top_k=config.num_experts_per_tok, + config=DBRXConfig() ) self.experts = DbrxExperts( @@ -582,51 +600,58 @@ def 
__init__(self, config: DBRXConfig): ffn_hidden_size=config.intermediate_size, moe_num_experts=config.num_local_experts, ffn_act_fn=config.hidden_act, - ) + batch_size=batch_size, + sequence_length=sequence_length, + config=DBRXConfig() + ) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + print("Input:", x) weights, top_weights, top_experts = self.router(x) + print("After router:", weights, top_weights, top_experts) out = self.experts(x, weights, top_weights, top_experts) + print("After experts:", out) return out, weights + def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) -class DbrxRouter(nn.Module): - def __init__( - self, - hidden_size: int, - config: DBRXConfig, - moe_num_experts: int, - moe_top_k: int, - batch_size: int, - sequence_length: int, - ffn_hidden_size: int, - ffn_act_fn: dict - ): - super().__init__() - self.hidden_size = hidden_size - self.moe_num_experts = config.num_local_experts - self.moe_top_k = config.num_experts_per_tok - self.ffn_hidden_size = config.intermediate_size - self.ffn_act_fn = {"name", config.hidden_act} + assert not torch.isnan(hidden_states).any(), "Input hidden_states contains NaN values" + assert not torch.isinf(hidden_states).any(), "Input hidden_states contains Inf values" - self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.router.layer(hidden_states) - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) - top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + assert not torch.isnan(router_logits).any(), "router_logits contains NaN values" + assert not torch.isinf(router_logits).any(), "router_logits contains Inf values" + ort_inputs = { + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + } - top_weights_scale = ( - torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) - if self.moe_normalize_expert_weights is not None - else 1.0 - ) - top_weights = top_weights / top_weights_scale + #ort_output = None + if self.experts.ort_sess is not None: + if not iobinding: + ort_output = self.experts.ort_sess.run(None, ort_inputs) + else: + ort_output = self.experts.ort_run_with_iobinding(ort_inputs) + #return ort_output - weights = weights.to(hidden_states.dtype) - top_weights = top_weights.to(hidden_states.dtype) - return weights, top_weights, top_experts + + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) + # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) + # print_tensor("output", ort_output[0]) + + print("ORT output:", ort_output) + assert not numpy.isnan(ort_output).any(), "ORT output contains NaN values" + assert not numpy.isinf(ort_output).any(), "ORT output contains Inf values" + return ort_output @@ -648,11 +673,14 @@ def test_dbrx_moe_parity(self): batch_size, 
sequence_length, config)
-        dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
-                                ffn_hidden_size,
-                                moe_num_experts,
-                                num_rows=batch_size * sequence_length,
-                                threshold=THRESHOLD)
+        dbrx_moe.parity_check()
+        #dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
+                                #ffn_hidden_size,
+                                #moe_num_experts,
+                                #num_rows=batch_size * sequence_length,
+                                #threshold=THRESHOLD)
+        #external_data_path = "dbrx_moe.onnx" + ".data"
+        #delete_model_data(external_data_path)


 if __name__ == "__main__":

From 8d41aae19799c2d4a5b5cec855fecf2a091c4dae Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 2 Aug 2024 18:58:42 +0000
Subject: [PATCH 15/15] reshaped tensors so the parity check matches

---
 .../test/python/transformers/test_parity_dbrx_moe.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py
index 9c73f1e7ccbd2..a49ad5f122d4c 100644
--- a/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_dbrx_moe.py
@@ -104,10 +104,6 @@ def create_moe_onnx_graph(
         ),
     ]

-    fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
-    fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
-    fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)
-
     fc1_shape = [num_experts, hidden_size, inter_size]
     fc2_shape = [num_experts, inter_size, hidden_size]
     fc3_shape = [num_experts, hidden_size, inter_size]
@@ -310,9 +306,9 @@ def __init__(self, hidden_size: int,
             self.moe_num_experts,
             self.hidden_size,
             self.ffn_hidden_size,
-            self.mlp.w1,
-            self.mlp.v1,
-            self.mlp.w2,
+            self.mlp.w1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size),
+            self.mlp.w2.view(moe_num_experts, ffn_hidden_size, hidden_size).transpose(1, 2).reshape(moe_num_experts, ffn_hidden_size, hidden_size),
+            self.mlp.v1.view(moe_num_experts, hidden_size, ffn_hidden_size).reshape(moe_num_experts, hidden_size, ffn_hidden_size),
             self.moe_top_k
        )
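
A note on the weight layout these patches juggle: DbrxExpertGLU stores each projection as a single stacked parameter of shape [moe_num_experts * ffn_hidden_size, hidden_size], while create_moe_onnx_graph declares the fc1/fc3 initializers with shape [num_experts, hidden_size, inter_size]. The following is a minimal sketch, with toy dimensions E/F/H rather than the DBRX defaults and not part of the test script itself, of how one could check that a candidate reshape really produces the per-expert transpose instead of merely reinterpreting the buffer:

import torch

E, F, H = 4, 6, 8                           # toy num_experts, ffn_hidden_size, hidden_size
w1 = torch.randn(E * F, H)                  # stacked expert-major parameter, as in DbrxExpertGLU

per_expert = w1.view(E, F, H)               # expert i -> its [F, H] matrix
fc1_candidate = per_expert.transpose(1, 2)  # [E, H, F], the layout fc1_shape declares

for i in range(E):
    # each fc1 slice should be the transpose of the corresponding torch expert matrix
    assert torch.equal(fc1_candidate[i], per_expert[i].t())

By contrast, w1.view(E, H, F) only reinterprets the same buffer in a different order and is in general not equal to per_expert.transpose(1, 2), which is worth keeping in mind when reviewing the view/reshape chains introduced in the last patch.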