From 6ef325fd9f6aa3f3c0d0abeaf5d70577e6ba8b1f Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 26 Feb 2024 10:00:52 -0800 Subject: [PATCH] Run lintrunner --- .../test/python/quantization/test_fusions.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py index d265dbea3d87f..8a4e81180e40c 100644 --- a/onnxruntime/test/python/quantization/test_fusions.py +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -5,23 +5,24 @@ # license information. # -------------------------------------------------------------------------- +import math import unittest -import math import numpy as np import onnx -from onnxruntime.quantization.onnx_model import ONNXModel from onnxruntime.quantization.fusions import FusionGelu +from onnxruntime.quantization.onnx_model import ONNXModel + class TestFusions(unittest.TestCase): def build_erf_sequence_1_model(self): """ - +-------Mul(0.5)---------------------+ - | | - | v - [root] --> Div -----> Erf --> Add --> Mul --> - (B=1.4142...) (1) + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) (1) """ shape = (1, 2, 3) @@ -49,11 +50,11 @@ def build_erf_sequence_1_model(self): def build_erf_sequence_2_model(self): """ - +------------------------------------+ - | | - | v - [root] --> Div -----> Erf --> Add --> Mul -->Mul --> - (B=1.4142...) (1) (0.5) + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) (1) (0.5) """ shape = (1, 2, 3) @@ -81,11 +82,11 @@ def build_erf_sequence_2_model(self): def build_erf_sequence_3_model(self): """ - +------------------------------------------+ - | | - | v - [root] --> Div -----> Erf --> Add --> Mul -->Mul - (B=1.4142...) (A=1) (A=0.5) + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) """ shape = (1, 2, 3) @@ -113,11 +114,11 @@ def build_erf_sequence_3_model(self): def build_erf_sequence_4_model(self): """ - +----------------------------------------------+ - | | - | v - [root] --> Mul -----> Erf --> Add --> Mul -->Mul - (A=0.7071067690849304) (B=1) (B=0.5) + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) """ shape = (1, 2, 3) @@ -187,5 +188,6 @@ def test_fuse_erf_to_gelu_4(self): self.assertEqual(gelu_node.op_type, "Gelu") self.assertTrue(gelu_node.name) + if __name__ == "__main__": unittest.main()