From 4838cb6b3e98273fcdd6a3e54da74cd584167780 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 27 Feb 2024 02:27:35 -0800 Subject: [PATCH] [QNN Quantization] Ensure fused nodes have names (#19650) ### Description - Updates the `qnn_preprocess_model()` method to set a name for any new nodes added to the graph (due to fusion). - Updates the `qnn_preprocess_model()` method to set a name for any unnamed nodes that previously existed in the original graph. - Adds unit tests for fusions (previously missing) - Checks that fused node names exist and are unique - Checks that fused graph is equivalent to original graph ### Motivation and Context Nodes are not strictly required to have names. However, a planned/upcoming feature to support mixed-precision (integer) quantized models needs nodes to have names. --- .../execution_providers/qnn/fusion_lpnorm.py | 7 +- .../execution_providers/qnn/preprocess.py | 11 + .../tools/quantization/fusions/fusion.py | 15 + .../tools/quantization/fusions/fusion_gelu.py | 25 +- .../quantization/fusions/fusion_layernorm.py | 1 + .../python/tools/quantization/onnx_model.py | 17 + .../test/python/quantization/test_fusions.py | 401 ++++++++++++++++++ 7 files changed, 465 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_fusions.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py index 9ebf400498e0e..fbf954febdda4 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -122,6 +122,11 @@ def fuse( self.nodes_to_remove.extend(subgraph_nodes) fused_node = onnx.helper.make_node( - self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[subgraph_input], + outputs=[subgraph_output], + p=2, + axis=-1, ) self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index becbaceab184e..b1c114fe1f9fd 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if fusion_layernorm.apply(): modified = True + # Make sure all nodes have a name. + unnamed_node_prefix = "qnn_preproc_node_" + available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 + for node in onnx_model.model.graph.node: + if node.op_type != "Constant" and not node.name: + new_node_name = f"{unnamed_node_prefix}{available_suffix!s}" + available_suffix += 1 + node.name = new_node_name + modified = True + logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.") + if modified: onnx_model.topological_sort() onnx.save_model(model, model_output) diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index b54b421226f1a..4bdc5c26cc946 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): self.nodes_to_remove: list = [] self.nodes_to_add: list = [] + self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_" + self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops. + def fuse( self, node: onnx.NodeProto, @@ -57,6 +60,18 @@ def apply(self) -> bool: return graph_updated + def create_unique_node_name(self): + prefix = self._new_node_name_prefix + + if self._new_node_name_suffix is None: + largest_suffix: int = self.model.get_largest_node_name_suffix(prefix) + self._new_node_name_suffix = largest_suffix + 1 + + new_name = f"{prefix}{self._new_node_name_suffix!s}" + self._new_node_name_suffix += 1 + + return new_name + @staticmethod def is_safe_to_fuse_nodes( nodes_to_remove: list[onnx.NodeProto], diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py index a20d6dbffd7a7..42c4a11833641 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -112,7 +112,9 @@ def fuse_1( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -173,11 +175,9 @@ def fuse_2( if not self.has_constant_input(sqrt_node, 2.0): return False - root_node = self.model.get_parent(div, 0, output_name_to_node) - if root_node is None: - return False + subgraph_input = div.input[0] - if root_node.output[0] not in mul.input: + if subgraph_input not in mul.input: return False subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] @@ -188,7 +188,9 @@ def fuse_2( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -239,9 +241,8 @@ def fuse_3( if i < 0: return False - root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) - if root_node is None: - return False + root_input_index = 1 - i + subgraph_input = first_mul.input[root_input_index] if mul_half.output[0] not in input_name_to_nodes: return False @@ -250,7 +251,7 @@ def fuse_3( return False last_mul = children[0] - if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input): return False subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] @@ -263,7 +264,9 @@ def fuse_3( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py index d7fb89236d3d2..7d58c1c180822 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -127,6 +127,7 @@ def fuse( normalize_node = onnx.helper.make_node( "LayerNormalization", + name=self.create_unique_node_name(), inputs=[reduce_mean_node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]], ) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 4591c9c950e6e..46d245d353a07 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph): node = find_by_name(node_name, graph_nodes_list) return node + def get_largest_node_name_suffix(self, node_name_prefix): + """ + Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`. + Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3. + """ + suffix = -1 + + for node in self.model.graph.node: + if node.name and node.name.startswith(node_name_prefix): + try: + index = int(node.name[len(node_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py new file mode 100644 index 0000000000000..bea110e566fb9 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization +from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization +from onnxruntime.quantization.onnx_model import ONNXModel + + +class TestFusions(unittest.TestCase): + def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0): + """ + Checks that the output of the fused model matches the output of the original model. + """ + orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, inputs) + + fused_session = onnxruntime.InferenceSession( + fused_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + fused_results = fused_session.run([], inputs) + + self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs") + for idx, expected_output in enumerate(orig_results): + actual_output = fused_results[idx] + np.testing.assert_allclose( + expected_output, + actual_output, + rtol=rtol, + atol=atol, + err_msg=f"Fused model output {idx} differs", + ) + + def build_erf_sequence_1_model(self, shape): + """ + Erf sequence that fuses into Gelu: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) (1) + + This method builds 2 of these Erf sequences: + + [root] -> ERF_SEQUENCE1 -> ERF_SEQUENCE2 -> output + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + # First Erf sequence + mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["mul0_out"]) + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "mul0_out"], ["seq1_output"]) + + # Second Erf sequence + mul0_node_dup = onnx.helper.make_node("Mul", ["seq1_output", "half_const"], ["mul0_out_dup"]) + div_node_dup = onnx.helper.make_node("Div", ["seq1_output", "root2_const"], ["div_out_dup"]) + erf_node_dup = onnx.helper.make_node("Erf", ["div_out_dup"], ["erf_out_dup"]) + add_node_dup = onnx.helper.make_node("Add", ["erf_out_dup", "one_const"], ["add_out_dup"]) + mul1_node_dup = onnx.helper.make_node("Mul", ["add_out_dup", "mul0_out_dup"], ["output"]) + + graph = onnx.helper.make_graph( + [ + mul0_node, + div_node, + erf_node, + add_node, + mul1_node, + mul0_node_dup, + div_node_dup, + erf_node_dup, + add_node_dup, + mul1_node_dup, + ], + "two_erf_sequences", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_2_model(self, shape): + """ + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) (1) (0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "root"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "half_const"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_2", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_3_model(self, shape): + """ + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_3", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_4_model(self, shape): + """ + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + frac_const = onnx.numpy_helper.from_array(np.array(0.7071067690849304, dtype=np.float32), "frac_const") + + mul0_node = onnx.helper.make_node("Mul", ["root", "frac_const"], ["mul0_out"]) + erf_node = onnx.helper.make_node("Erf", ["mul0_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul1_out"]) + mul2_node = onnx.helper.make_node("Mul", ["mul1_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [mul0_node, erf_node, add_node, mul1_node, mul2_node], + "erf_sequence_4", + [root_inp], + [output], + initializer=[one_const, half_const, frac_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1): + """ + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^ + | | | | + +-------------------------------------------------+ [Scale] [Bias] + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const") + + rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"]) + sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"]) + pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"]) + rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"]) + add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"]) + sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"]) + div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"]) + mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"]) + add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node], + "reduce_mean_sequence", + [root_inp], + [output], + initializer=[scale_const, bias_const, axes_const, two_const, eps_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1): + """ + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + rl2_node = onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1) + clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"]) + expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"]) + div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"]) + + graph = onnx.helper.make_graph( + [rl2_node, clip_node, expand_node, div_node], + "reducel2_sequence", + [root_inp], + [output], + initializer=[axes_const, eps_const, shape_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def test_fuse_erf_to_gelu_1(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_1_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 2 Gelu nodes. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 2) + + gelu_node_0 = model.model.graph.node[0] + gelu_node_1 = model.model.graph.node[1] + self.assertEqual(gelu_node_0.op_type, "Gelu") + self.assertEqual(gelu_node_1.op_type, "Gelu") + + self.assertTrue(gelu_node_0.name) + self.assertTrue(gelu_node_1.name) + self.assertNotEqual(gelu_node_0.name, gelu_node_1.name) # Generated names should not be equal + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_2(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_2_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_3(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_3_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_4(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_4_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_reduce_l2_to_lpnorm(self): + shape = (1, 2, 3) + model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LpNormalization node. + modified = FusionLpNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + lpnorm_node = model.model.graph.node[0] + self.assertEqual(lpnorm_node.op_type, "LpNormalization") + self.assertTrue(lpnorm_node.name) + + # LpNorm's p attribute should be set to 2 + p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p") + self.assertEqual(p_attr.i, 2) + + def test_fuse_reduce_mean_to_layer_norm(self): + shape = (1, 2, 3) + model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LayerNormalization node. + modified = FusionLayerNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + layer_norm_node = model.model.graph.node[0] + self.assertEqual(layer_norm_node.op_type, "LayerNormalization") + self.assertTrue(layer_norm_node.name) + + # Check that fused model is equivalent to original model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + +if __name__ == "__main__": + unittest.main()