Skip to content

Commit

Permalink
Run lintrunner
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Feb 26, 2024
1 parent 8f7c88f commit 6ef325f
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions onnxruntime/test/python/quantization/test_fusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
# license information.
# --------------------------------------------------------------------------

import math
import unittest

import math
import numpy as np
import onnx

from onnxruntime.quantization.onnx_model import ONNXModel
from onnxruntime.quantization.fusions import FusionGelu
from onnxruntime.quantization.onnx_model import ONNXModel


class TestFusions(unittest.TestCase):
def build_erf_sequence_1_model(self):
"""
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (1)
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (1)
"""
shape = (1, 2, 3)
Expand Down Expand Up @@ -49,11 +50,11 @@ def build_erf_sequence_1_model(self):

def build_erf_sequence_2_model(self):
"""
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)
"""
shape = (1, 2, 3)
Expand Down Expand Up @@ -81,11 +82,11 @@ def build_erf_sequence_2_model(self):

def build_erf_sequence_3_model(self):
"""
+------------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul
(B=1.4142...) (A=1) (A=0.5)
+------------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul
(B=1.4142...) (A=1) (A=0.5)
"""
shape = (1, 2, 3)
Expand Down Expand Up @@ -113,11 +114,11 @@ def build_erf_sequence_3_model(self):

def build_erf_sequence_4_model(self):
"""
+----------------------------------------------+
| |
| v
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
(A=0.7071067690849304) (B=1) (B=0.5)
+----------------------------------------------+
| |
| v
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
(A=0.7071067690849304) (B=1) (B=0.5)
"""
shape = (1, 2, 3)
Expand Down Expand Up @@ -187,5 +188,6 @@ def test_fuse_erf_to_gelu_4(self):
self.assertEqual(gelu_node.op_type, "Gelu")
self.assertTrue(gelu_node.name)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6ef325f

Please sign in to comment.