diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py index c6fbb967662b8..c244d74abb6ed 100644 --- a/onnxruntime/test/python/quantization/test_fusions.py +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -12,7 +12,8 @@ import onnx import onnxruntime -from onnxruntime.quantization.fusions import FusionGelu +from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization +from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization from onnxruntime.quantization.onnx_model import ONNXModel @@ -177,6 +178,79 @@ def build_erf_sequence_4_model(self, shape): model = onnx.helper.make_model(graph, opset_imports=opset_imports) return ONNXModel(model) + def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1): + """ + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^ + | | | | + +-------------------------------------------------+ [Scale] [Bias] + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const") + + rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"]) + sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"]) + pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"]) + rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"]) + add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"]) + sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"]) + div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"]) + mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"]) + add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node], + "reduce_mean_sequence", + [root_inp], + [output], + initializer=[scale_const, bias_const, axes_const, two_const, eps_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1): + """ + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + rl2_node = onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1) + clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"]) + expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"]) + div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"]) + + graph = onnx.helper.make_graph( + [rl2_node, clip_node, expand_node, div_node], + "reducel2_sequence", + [root_inp], + [output], + initializer=[axes_const, eps_const, shape_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + def test_fuse_erf_to_gelu_1(self): shape = (1, 2, 3) model = self.build_erf_sequence_1_model(shape) @@ -253,6 +327,44 @@ def test_fuse_erf_to_gelu_4(self): inputs = {"root": np.ones(shape, dtype=np.float32)} self.check_fused_model_correctness(orig_model, model.model, inputs) + def test_fuse_reduce_l2_to_lpnorm(self): + shape = (1, 2, 3) + model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LpNormalization node. + modified = FusionLpNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + lpnorm_node = model.model.graph.node[0] + self.assertEqual(lpnorm_node.op_type, "LpNormalization") + self.assertTrue(lpnorm_node.name) + + # LpNorm's p attribute should be set to 2 + p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p") + self.assertEqual(p_attr.i, 2) + + def test_fuse_reduce_mean_to_layer_norm(self): + shape = (1, 2, 3) + model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LayerNormalization node. + modified = FusionLayerNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + layer_norm_node = model.model.graph.node[0] + self.assertEqual(layer_norm_node.op_type, "LayerNormalization") + self.assertTrue(layer_norm_node.name) + + # Check that fused model is equivalent to original model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + if __name__ == "__main__": unittest.main()