Skip to content

Commit

Permalink
Test other fusions
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Feb 26, 2024
1 parent 2470cf5 commit a3d0c8c
Showing 1 changed file with 113 additions and 1 deletion.
114 changes: 113 additions & 1 deletion onnxruntime/test/python/quantization/test_fusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import onnx

import onnxruntime
from onnxruntime.quantization.fusions import FusionGelu
from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization
from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization
from onnxruntime.quantization.onnx_model import ONNXModel


Expand Down Expand Up @@ -177,6 +178,79 @@ def build_erf_sequence_4_model(self, shape):
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1):
"""
+----------------------+
| |
| v
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^
| | | |
+-------------------------------------------------+ [Scale] [Bias]
"""
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const")
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const")
two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")
eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const")

rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"])
sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"])
pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"])
rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"])
add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"])
sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"])
div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"])
mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"])
add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"])

graph = onnx.helper.make_graph(
[rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node],
"reduce_mean_sequence",
[root_inp],
[output],
initializer=[scale_const, bias_const, axes_const, two_const, eps_const],
)
opset_imports = [
onnx.helper.make_opsetid("", 18),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1):
"""
[root] --> ReduceL2 -----> Clip --> Expand ----> Div -->
| (axis=-1) (min=epsilon) (shape=root) ^
| (keepdims=True) |
| |
+-----------------------------------------------+
"""
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const")
eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const")
shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const")

rl2_node = onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1)
clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"])
expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"])
div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"])

graph = onnx.helper.make_graph(
[rl2_node, clip_node, expand_node, div_node],
"reducel2_sequence",
[root_inp],
[output],
initializer=[axes_const, eps_const, shape_const],
)
opset_imports = [
onnx.helper.make_opsetid("", 18),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def test_fuse_erf_to_gelu_1(self):
shape = (1, 2, 3)
model = self.build_erf_sequence_1_model(shape)
Expand Down Expand Up @@ -253,6 +327,44 @@ def test_fuse_erf_to_gelu_4(self):
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)

def test_fuse_reduce_l2_to_lpnorm(self):
shape = (1, 2, 3)
model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that fusion simplified model to 1 LpNormalization node.
modified = FusionLpNormalization(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

lpnorm_node = model.model.graph.node[0]
self.assertEqual(lpnorm_node.op_type, "LpNormalization")
self.assertTrue(lpnorm_node.name)

# LpNorm's p attribute should be set to 2
p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p")
self.assertEqual(p_attr.i, 2)

def test_fuse_reduce_mean_to_layer_norm(self):
shape = (1, 2, 3)
model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that fusion simplified model to 1 LayerNormalization node.
modified = FusionLayerNormalization(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

layer_norm_node = model.model.graph.node[0]
self.assertEqual(layer_norm_node.op_type, "LayerNormalization")
self.assertTrue(layer_norm_node.name)

# Check that fused model is equivalent to original model.
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)


if __name__ == "__main__":
unittest.main()

0 comments on commit a3d0c8c

Please sign in to comment.