Test fusion correctness
adrianlizarraga committed Feb 26, 2024
1 parent 6ef325f commit 2470cf5
Showing 1 changed file with 85 additions and 20 deletions.
105 changes: 85 additions & 20 deletions onnxruntime/test/python/quantization/test_fusions.py
@@ -11,12 +11,33 @@
import numpy as np
import onnx

import onnxruntime
from onnxruntime.quantization.fusions import FusionGelu
from onnxruntime.quantization.onnx_model import ONNXModel


class TestFusions(unittest.TestCase):
def build_erf_sequence_1_model(self):
def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0):
orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"])
orig_results = orig_session.run(None, inputs)

fused_session = onnxruntime.InferenceSession(
fused_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
fused_results = fused_session.run(None, inputs)

self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs")
for idx, expected_output in enumerate(orig_results):
actual_output = fused_results[idx]
np.testing.assert_allclose(
expected_output,
actual_output,
rtol=rtol,
atol=atol,
err_msg=f"Fused model output {idx} differs",
)

def build_erf_sequence_1_model(self, shape):
"""
+-------Mul(0.5)---------------------+
| |
@@ -25,7 +46,6 @@ def build_erf_sequence_1_model(self):
(B=1.4142...) (1)
"""
shape = (1, 2, 3)
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
@@ -45,10 +65,14 @@ def build_erf_sequence_1_model(self):
[output],
initializer=[one_const, half_const, root2_const],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)])
opset_imports = [
onnx.helper.make_opsetid("", 18),
onnx.helper.make_opsetid("com.microsoft", 1),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def build_erf_sequence_2_model(self):
def build_erf_sequence_2_model(self, shape):
"""
+------------------------------------+
| |
@@ -57,7 +81,6 @@ def build_erf_sequence_2_model(self):
(B=1.4142...) (1) (0.5)
"""
shape = (1, 2, 3)
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
@@ -77,10 +100,14 @@ def build_erf_sequence_2_model(self):
[output],
initializer=[one_const, half_const, root2_const],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)])
opset_imports = [
onnx.helper.make_opsetid("", 18),
onnx.helper.make_opsetid("com.microsoft", 1),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def build_erf_sequence_3_model(self):
def build_erf_sequence_3_model(self, shape):
"""
+------------------------------------------+
| |
@@ -89,7 +116,6 @@ def build_erf_sequence_3_model(self):
(B=1.4142...) (A=1) (A=0.5)
"""
shape = (1, 2, 3)
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
@@ -109,10 +135,14 @@ def build_erf_sequence_3_model(self):
[output],
initializer=[one_const, half_const, root2_const],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)])
opset_imports = [
onnx.helper.make_opsetid("", 18),
onnx.helper.make_opsetid("com.microsoft", 1),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def build_erf_sequence_4_model(self):
def build_erf_sequence_4_model(self, shape):
"""
+----------------------------------------------+
| |
@@ -121,7 +151,6 @@ def build_erf_sequence_4_model(self):
(A=0.7071067690849304) (B=1) (B=0.5)
"""
shape = (1, 2, 3)
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
@@ -141,53 +170,89 @@ def build_erf_sequence_4_model(self):
[output],
initializer=[one_const, half_const, frac_const],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)])
opset_imports = [
onnx.helper.make_opsetid("", 18),
onnx.helper.make_opsetid("com.microsoft", 1),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return ONNXModel(model)

def test_fuse_erf_to_gelu_1(self):
model = self.build_erf_sequence_1_model()
modified = FusionGelu(model).apply()
shape = (1, 2, 3)
model = self.build_erf_sequence_1_model(shape)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that the fusion simplified the model to a single Gelu node.
modified = FusionGelu(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

gelu_node = model.model.graph.node[0]
self.assertEqual(gelu_node.op_type, "Gelu")
self.assertTrue(gelu_node.name)

# Check that the fused model is equivalent to the original Erf model.
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)

def test_fuse_erf_to_gelu_2(self):
model = self.build_erf_sequence_2_model()
modified = FusionGelu(model).apply()
shape = (1, 2, 3)
model = self.build_erf_sequence_2_model(shape)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that the fusion simplified the model to a single Gelu node.
modified = FusionGelu(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

gelu_node = model.model.graph.node[0]
self.assertEqual(gelu_node.op_type, "Gelu")
self.assertTrue(gelu_node.name)

# Check that the fused model is equivalent to the original Erf model.
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)

def test_fuse_erf_to_gelu_3(self):
model = self.build_erf_sequence_3_model()
modified = FusionGelu(model).apply()
shape = (1, 2, 3)
model = self.build_erf_sequence_3_model(shape)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that the fusion simplified the model to a single Gelu node.
modified = FusionGelu(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

gelu_node = model.model.graph.node[0]
self.assertEqual(gelu_node.op_type, "Gelu")
self.assertTrue(gelu_node.name)

# Check that the fused model is equivalent to the original Erf model.
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)

def test_fuse_erf_to_gelu_4(self):
model = self.build_erf_sequence_4_model()
modified = FusionGelu(model).apply()
shape = (1, 2, 3)
model = self.build_erf_sequence_4_model(shape)
orig_model = onnx.ModelProto()
orig_model.CopyFrom(model.model)

# Check that the fusion simplified the model to a single Gelu node.
modified = FusionGelu(model).apply()
self.assertTrue(modified)
self.assertEqual(len(model.model.graph.node), 1)

gelu_node = model.model.graph.node[0]
self.assertEqual(gelu_node.op_type, "Gelu")
self.assertTrue(gelu_node.name)

# Check that the fused model is equivalent to the original Erf model.
inputs = {"root": np.ones(shape, dtype=np.float32)}
self.check_fused_model_correctness(orig_model, model.model, inputs)


if __name__ == "__main__":
unittest.main()
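
Note: all four Erf patterns exercised above are algebraic spellings of the exact GELU definition, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), which is what lets FusionGelu collapse each subgraph into a single Gelu node (registered under the com.microsoft domain in the opset imports above) while check_fused_model_correctness confirms the outputs still match. Below is a minimal sketch of that identity using only the standard-library math module; the helper names gelu_exact and erf_pattern_4 are illustrative and not part of the test file.

import math

def gelu_exact(x: float) -> float:
    # Exact GELU definition: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def erf_pattern_4(x: float) -> float:
    # Pattern 4 above: Mul by 0.7071067690849304 (the float32 rounding of
    # 1/sqrt(2)), then Erf, Add(1), and the remaining multiplications by the
    # root input and 0.5.
    return (math.erf(x * 0.7071067690849304) + 1.0) * x * 0.5

for x in (-2.0, -0.5, 0.0, 0.75, 3.0):
    assert abs(gelu_exact(x) - erf_pattern_4(x)) < 1e-6

The other patterns differ only in where the 0.5 multiplication happens and in whether they divide by sqrt(2) or multiply by its reciprocal, so the same identity covers them.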
