Skip to content

Commit

Permalink
[Function-Rewriter] Support BiasSplitGelu operator in `stable_diffusi…
Browse files Browse the repository at this point in the history
…on_unet` (#1400)

1. Support BiasSplitGelu in diffuser models
2. Add `function_unittest_producer.py` from onnx-rewriter

After(left) vs Before(right)

![Screenshot 2024-04-18
102455](https://github.com/microsoft/onnxscript/assets/18010845/c483d575-ca06-4d35-aa43-af3d28d0ee0a)
  • Loading branch information
titaiwangms authored Apr 19, 2024
1 parent 9ea1d1c commit a6a7ffc
Show file tree
Hide file tree
Showing 27 changed files with 570 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ code = 'EDITORCONFIG-CHECKER'
include_patterns = ['**']
exclude_patterns = [
'**/*.ipynb',
'**/*.onnx',
'**/*.pb'
]
command = [
'python',
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/rewriter/onnxruntime/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime.transformers import (
biassplitgelu,
fastgelu,
layernorm,
multihead_attention,
Expand All @@ -13,4 +14,5 @@
multihead_attention.AttnPhi15RewriteRule,
layernorm.LNRewriteRule,
fastgelu.GeluRewriteRule,
biassplitgelu.GegluRewriteRule,
]
32 changes: 32 additions & 0 deletions onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import logging

import onnx

import onnxscript
from onnxscript.rewriter import function_rule

logger = logging.getLogger(__name__)


class GegluRewriteRule(function_rule.FunctionRewriteRule):
FUNCTION_KEYWORD = "GEGLU"
PACKAGE_NAME = "diffusers"
_version_controller = function_rule.VersionController()

@_version_controller.register_version() # type: ignore[misc]
def _fusion(
self, function: onnx.FunctionProto
) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]:
del function # Unused
op = self.onnx_opset
msft_opset = onnxscript.values.Opset("com.microsoft", 1)

def ggelu(input, weight, bias):
weight_transpose = op.Transpose(weight, [1, 0])
matmul_input = op.MatMul(input, weight_transpose)
return msft_opset.BiasSplitGelu(matmul_input, bias)

function_proto = onnxscript.script(default_opset=op)(ggelu).to_function_proto() # type: ignore[arg-type]
return function_proto, (onnx.helper.make_operatorsetid("com.microsoft", 1),)
22 changes: 22 additions & 0 deletions onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import unittest

import numpy as np

from tests.common import testutils


class BiasSplitGeluParityTest(unittest.TestCase):
def setUp(self):
np.random.seed(0)

@testutils.skip_if_no_cuda("BiasSplitGelu Kernel unsupported on CPU.")
def test_geglu_stable_diffusion_unet(self):
testutils.test_onnxruntime_rewrite(
"geglu_stable_diffusion_unet", 4, {("com.microsoft", "BiasSplitGelu", "")}
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
self._param_schemas = param_schemas_from_function_ir(self.function_ir)
return self._param_schemas

def to_function_proto(self):
def to_function_proto(self) -> onnx.FunctionProto:
"""Converts the function into :class:`onnx.FunctionProto`."""
return self.function_ir.to_function_proto()

Expand Down
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading

0 comments on commit a6a7ffc

Please sign in to comment.