diff --git a/docs/index.md b/docs/index.md index e80d6b327..ed1bab77b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -100,6 +100,7 @@ tutorial/index api/index intermediate_representation/index auto_examples/index +rewriter/index articles/index ``` diff --git a/docs/rewriter/examples/broadcast_matmul.py b/docs/rewriter/examples/broadcast_matmul.py new file mode 100644 index 000000000..e64984092 --- /dev/null +++ b/docs/rewriter/examples/broadcast_matmul.py @@ -0,0 +1,198 @@ +"""Onnx Pattern Rewriting with match condition parameter. + +This script shows how to define a rewriting rule based on patterns while +utilizing the match condition parameter. + +First we write a dummy model with a several Reshape nodes and a Matmul node +=================== +""" + +import logging + +import numpy as np +import onnx + +import onnxscript +from onnxscript import FLOAT, ir, opset18, script +from onnxscript.rewriter import _ir_utils, pattern + +logger = logging.getLogger(__name__) + + +@script() +def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1, 4, 512, 64]: + # NOTE: Modified from `value_ints` to `value` + shape_a = opset18.Constant(value=[4, 512, 512]) + reshape_a = opset18.Reshape(A, shape_a) + shape_b = opset18.Constant(value=[4, 512, 64]) + reshape_b = opset18.Reshape(B, shape_b) + matmul = opset18.MatMul(reshape_a, reshape_b) + shape_c = opset18.Constant(value=[1, 4, 512, 64]) + result = opset18.Reshape(matmul, shape_c) + return result + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + +_op = pattern.onnxop + + +def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c): + reshape_a = _op.Reshape(input_a, shape_a) + reshape_b = _op.Reshape(input_b, shape_b) + matmul = _op.MatMul(reshape_a, reshape_b) + return _op.Reshape(matmul, shape_c) + + +#################################### +# The replacement pattern +# ===================== + + +def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_): + return op.MatMul(input_a, input_b) + + +#################################### +# Write condition to check if we need to replace the pattern +# ===================== + + +def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: + """If matmul broadcasting is enough, then we don't need the reshapes. + + To validate this, we need to check the following: + 1. Input shapes check: input_a and input_b should be broadcastable + 2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b) + + If the above are true, then we don't need the reshapes. + """ + input_a_shape = input_a.shape + input_b_shape = input_b.shape + # TODO: Get a helper func to get const_value + shape_c_value = _ir_utils.propagate_const_value(shape_c) + shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] + if shape_c is None: + return False + if not isinstance(shape_c, np.ndarray): + logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + return False + if len(shape_c.shape) != 1: + logger.info( + "Unexpected final shape. The shape of 'shape' value is %s", + shape_c.shape, + ) + return False + shape_c = shape_c.tolist() + + # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape + # information. So, we need to check if the shape is None and return False. + if input_a_shape is None or input_b_shape is None or shape_c is None: + logger.info("Shape information is not available for the inputs and outputs.") + return False + input_a_shape = list(input_a_shape) + input_b_shape = list(input_b_shape) + + dim_a = len(input_a_shape) + dim_b = len(input_b_shape) + + # 1. Check if input shapes are broadcastable + # 1.a. If the first input is 1-D, check whether + # the dim matches the last second dim of the second input. + mimic_matmul_broadcast_behavior = False + if dim_a < 2: + if input_a_shape[-1] != input_b_shape[-2]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_a_shape = [1, *input_a_shape] + dim_a = len(input_a_shape) + mimic_matmul_broadcast_behavior = True + # 1.b. If the second input is 1-D, check whether + # the dim matches the last dim of the first input. + if dim_b < 2: + if input_b_shape[-1] != input_a_shape[-1]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_b_shape = [*input_b_shape, 1] + dim_b = len(input_b_shape) + mimic_matmul_broadcast_behavior = True + # 1.c. If both inputs are at least 2-D, check whether + # the last dimension of the first input matches the second + # last dimension of the second input, and shape[:-2] are + # broadcastable. + input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_b_shape_except_last_dim = input_b_shape[:-1] + broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] + for idx, (dim_from_a, dim_from_b) in enumerate( + zip( + reversed(input_a_shape_except_second_last_dim), + reversed(input_b_shape_except_last_dim), + ) + ): + if dim_from_a not in {1, dim_from_b}: + logger.info("Original shape is not broadcastable.") + return False + elif idx > 0: + broadcast_matmul_output_shape = [ + max(dim_from_a, dim_from_b), + *broadcast_matmul_output_shape, + ] + + # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) + # Prepend the broadcast_matmul_output_shape with the longer shape of input + if dim_a > dim_b: + longer_shape = input_a_shape + shorter_shape = input_b_shape + else: + longer_shape = input_b_shape + shorter_shape = input_a_shape + broadcast_matmul_output_shape = ( + longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape + ) + if mimic_matmul_broadcast_behavior and dim_b == 2: + broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] + if mimic_matmul_broadcast_behavior and dim_a == 2: + broadcast_matmul_output_shape.pop(-2) + if shape_c != broadcast_matmul_output_shape: + logger.info( + "Final output shape is not the same. Expected %s vs actual %s", + shape_c, + broadcast_matmul_output_shape, + ) + return False + + return True + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + two_reshapes_matmul_reshape_rule = pattern.RewriteRule( + two_reshapes_matmul_reshape_pattern, # target pattern + matmul_pattern, # replacement pattern + check_if_need_reshape, # match_condition function + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule]) + # Apply rewrite while passing match_condition + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/rewriter/examples/erfgelu.py b/docs/rewriter/examples/erfgelu.py new file mode 100644 index 000000000..f8723da59 --- /dev/null +++ b/docs/rewriter/examples/erfgelu.py @@ -0,0 +1,161 @@ +"""Onnx Pattern Rewriting. + +This script shows how to define a rewriting rule based on patterns. + +First a dummy model with a GELU activation +=================== +""" + +import math + +import onnx + +import onnxscript +from onnxscript import FLOAT, ir, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]: + input_add = opset18.Add(X, Y) + sqrt2 = opset18.Constant(value_float=math.sqrt(2)) + erf = opset18.Erf(input_add / sqrt2) + add_const = opset18.Constant(value_float=1.0) + plus_one = erf + add_const + mul1 = input_add * plus_one + mul_const = opset18.Constant(value_float=0.5) + result = mul_const * mul1 + return result + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# Model demonstrating multiple patterns and variations of GELU activation +# ===================== + + +@script() +def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]: + # Create first GELU variant + sqrt2_v1 = opset18.Constant(value_float=math.sqrt(2)) + erf_v1 = opset18.Erf(X / sqrt2_v1) + add_const_v1 = opset18.Constant(value_float=1.0) + plus_one_v1 = erf_v1 + add_const_v1 + mul1_v1 = X * plus_one_v1 + mul_const_v1 = opset18.Constant(value_float=0.5) + gelu1 = mul_const_v1 * mul1_v1 + + # Create second GELU variant + sqrt2_v2 = opset18.Constant(value_float=math.sqrt(2)) + erf_v2 = opset18.Erf(Y / sqrt2_v2) + add_const_v2 = opset18.Constant(value_float=1.0) + plus_one_v2 = erf_v2 + add_const_v2 + mul1_v2 = Y * plus_one_v2 + mul_const_v2 = opset18.Constant(value_float=0.5) + gelu2 = mul1_v2 * mul_const_v2 + + # Add both GELU functions + result = opset18.Add(gelu1, gelu2) + return result + + +commute_model = commute_model.to_model_proto() +onnx.checker.check_model(commute_model) + + +#################################### +# The target pattern +# ===================== + +_op = pattern.onnxop + + +def erf_gelu_pattern(x): + return 0.5 * (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) + + +def erf_gelu_pattern_2(x): + return (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5 + + +#################################### +# The replacement pattern +# ===================== + + +def gelu(op, x: ir.Value): + return op.Gelu(x, domain="com.microsoft") + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern + ) + model_with_rewrite_applied = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=[rule], + ) + return model_with_rewrite_applied + + +def apply_rewrite_with_ruleset(model): + # Create multiple rules + rule1 = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern + ) + rule2 = pattern.RewriteRule( + erf_gelu_pattern_2, # Target Pattern + gelu, # Replacement Pattern + ) + # Create a Rewrite Rule Set with multiple rules. + rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2]) + # Apply rewrites + model_with_rewrite_applied = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + # pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing multiple rules + ) + return model_with_rewrite_applied + + +def apply_rewrite_with_commute(model): + rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern + ) + # Create a Rewrite Rule Set with commute=True + rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True) + # Apply rewrites + model_with_rewrite_applied = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite_applied + + +# Rewrite-Simple +model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(model_with_rewrite) + +# Rewrite-Single-Patterns +# Incorrect number of rewrites +model_with_single_rewrite_ruleset = apply_rewrite(commute_model) +onnx.checker.check_model(model_with_single_rewrite_ruleset) + +# Rewrite-Multiple-Patterns-RuleSet +model_with_rewrite_ruleset = apply_rewrite_with_ruleset(commute_model) +onnx.checker.check_model(model_with_rewrite_ruleset) + +# Rewrite-Multiple-Patterns-Commute +model_with_rewrite_commute = apply_rewrite_with_commute(commute_model) +onnx.checker.check_model(model_with_rewrite_commute) diff --git a/docs/rewriter/examples/img/broadcast_01.png b/docs/rewriter/examples/img/broadcast_01.png new file mode 100644 index 000000000..58df18ff7 Binary files /dev/null and b/docs/rewriter/examples/img/broadcast_01.png differ diff --git a/docs/rewriter/examples/img/broadcast_02.png b/docs/rewriter/examples/img/broadcast_02.png new file mode 100644 index 000000000..616013974 Binary files /dev/null and b/docs/rewriter/examples/img/broadcast_02.png differ diff --git a/docs/rewriter/examples/img/erfgelu_01.png b/docs/rewriter/examples/img/erfgelu_01.png new file mode 100644 index 000000000..53992ce3d Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_01.png differ diff --git a/docs/rewriter/examples/img/erfgelu_02.png b/docs/rewriter/examples/img/erfgelu_02.png new file mode 100644 index 000000000..ab000c95f Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_02.png differ diff --git a/docs/rewriter/examples/img/erfgelu_03_commute.png b/docs/rewriter/examples/img/erfgelu_03_commute.png new file mode 100644 index 000000000..cf51724e7 Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_03_commute.png differ diff --git a/docs/rewriter/examples/img/erfgelu_04_commute.png b/docs/rewriter/examples/img/erfgelu_04_commute.png new file mode 100644 index 000000000..4d38d3b4b Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_04_commute.png differ diff --git a/docs/rewriter/examples/img/erfgelu_05_commute.png b/docs/rewriter/examples/img/erfgelu_05_commute.png new file mode 100644 index 000000000..c31fb9b79 Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_05_commute.png differ diff --git a/docs/rewriter/examples/img/erfgelu_06_commute.png b/docs/rewriter/examples/img/erfgelu_06_commute.png new file mode 100644 index 000000000..e60849b10 Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_06_commute.png differ diff --git a/docs/rewriter/examples/img/erfgelu_07_commute.png b/docs/rewriter/examples/img/erfgelu_07_commute.png new file mode 100644 index 000000000..34176f695 Binary files /dev/null and b/docs/rewriter/examples/img/erfgelu_07_commute.png differ diff --git a/docs/rewriter/index.md b/docs/rewriter/index.md new file mode 100644 index 000000000..3b4e01e14 --- /dev/null +++ b/docs/rewriter/index.md @@ -0,0 +1,5 @@ +# Rewriter Tutorials + +```{toctree} +rewrite_patterns +``` diff --git a/docs/rewriter/rewrite_patterns.md b/docs/rewriter/rewrite_patterns.md new file mode 100644 index 000000000..87a5e31af --- /dev/null +++ b/docs/rewriter/rewrite_patterns.md @@ -0,0 +1,198 @@ +# Pattern-based Rewrite Using Rules + +## Introduction + +The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on rewrite rules provided by the user. + +## Usage + +There are three main components needed when rewriting patterns in the graph: + +1. `target_pattern` : Original pattern to match against. This pattern is written as a function using ONNXScript-like operators. +2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators. +3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied. + +(heading-target-simple)= +## A Simple Example + +An simple example demonstrating the usage of this functionality using the `GELU` activation function: + +`GELU` activation function can be computed using a Gauss Error Function using the given formula: + +
+ +
+ +We will show how we can find a subgraph matching this computation and replace it by a call to the function. + +Firstly, include all the rewriter relevant imports. + +```python + from onnxscript.rewriter import pattern + from onnxscript import ir + + _op = pattern.onnxop +``` + +Then create a target pattern that needs to be replaced using onnxscript operators. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +After this, create a replacement pattern that consists of the GELU onnxscript operator. + +```{literalinclude} examples/erfgelu.py +:pyobject: gelu +``` +:::{note} +:name: type annotate ir.Value + +The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value ` class. +::: + + +For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. + +```python + rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern + ) +``` + +Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components: + +1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`. +2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]` +3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types: + - `Sequence[PatternRewriteRule]` + - `RewriteRuleSet` + +:::{note} +:name: pattern_rewrite_rules input formatting + +`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset). +::: + +The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above: + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite +``` + +The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended. + +![target_pattern](examples/img/erfgelu_01.png) ![replacement_pattern](examples/img/erfgelu_02.png) + + +(heading-target-commute)= +## Utilizing `commute` parameter for pattern-matching +Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure. + +![commute](examples/img/erfgelu_03_commute.png){align=center width=500px} + +In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched. + +![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center} + + +If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +```{image} examples/img/erfgelu_06_commute.png +:alt: The resulting graph after matching. +:width: 400px +:align: center +``` + +Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. + +(heading-target-commute-ruleset)= +### 1. Creating a rule-set with different patterns. + +This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: + +```python + from onnxscript.rewriter import pattern + rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +``` + +In order to apply this method to the example above, first create the two separate target patterns as follows: + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern_2 +``` + +Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_ruleset +``` + + +### 2. Using the `commute` parameter while creating a rule. + +Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_commute +``` + +For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows: + +![commute](examples/img/erfgelu_07_commute.png){align=center width=300px} + +## Using the `match_condition` parameter for pattern-matching + +This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. + +Let us consider a model which consists of the following pattern. + +![target_pattern](examples/img/broadcast_01.png){align=center} + +Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: + +1. Input shapes check: `input_a` and `input_b` should be broadcastable +2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` + +If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. + +First, write a target pattern and replacement pattern in a similar way to the first example. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: two_reshapes_matmul_reshape_pattern +``` + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: matmul +``` + +:::{note} +:name: omitting inputs in signature + +The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. + +Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. +::: + +In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: check_if_need_reshape +``` + +With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: apply_rewrite +``` + +The final graph with the applied rewrite looks as follows: +![broadcast_rewrite](examples/img/broadcast_02.png){align=center} diff --git a/docs/test/test_documentation_examples.py b/docs/test/test_documentation_examples.py index 6645882c8..368025ab6 100644 --- a/docs/test/test_documentation_examples.py +++ b/docs/test/test_documentation_examples.py @@ -50,6 +50,7 @@ def test(*relpath): test("..", "..", "docs", "examples") test("..", "..", "docs", "tutorial", "examples") + test("..", "..", "docs", "rewriter", "examples") if __name__ == "__main__": diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index 8e5ee638e..20071f3a8 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -1,11 +1,9 @@ from __future__ import annotations import logging -from typing import Any import numpy as np -from onnxscript import ir from onnxscript.rewriter import _ir_utils, pattern op = pattern.onnxop @@ -13,7 +11,7 @@ # condition to check if we need to replace the pattern -def check_if_need_reshape(match_bindings: dict[str, ir.Value | Any]) -> bool: +def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: """If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: @@ -22,17 +20,14 @@ def check_if_need_reshape(match_bindings: dict[str, ir.Value | Any]) -> bool: If the above are true, then we don't need the reshapes. - Args: - match_bindings: The match binding dictionary from a MatchResult. - Returns: bool: True if we need to replace the pattern, False otherwise. """ - input_a_shape = match_bindings["input_a"].shape - input_b_shape = match_bindings["input_b"].shape + input_a_shape = input_a.shape + input_b_shape = input_b.shape # TODO: Get a helper func to get const_value - shape_c_value = _ir_utils.propagate_const_value(match_bindings["shape_c"]) + shape_c_value = _ir_utils.propagate_const_value(shape_c) shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] if shape_c is None: return False diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py index e15954d24..1a53d59d3 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -1,12 +1,10 @@ from __future__ import annotations import logging -from typing import Any import numpy as np import onnx -from onnxscript import ir from onnxscript.rewriter import _ir_utils, pattern op = pattern.onnxop @@ -16,7 +14,7 @@ logger = logging.getLogger(__name__) -def _check_if_simulated_instance_norm_is_used_impl( +def check_if_simulated_instance_norm_is_used( input_x, adjusted_input_shape, original_input_shape, @@ -24,8 +22,25 @@ def _check_if_simulated_instance_norm_is_used_impl( bias_for_norm, weight_full, bias_full, - **kwargs, + **_, ) -> bool: + """Check if the simulated instance normalization is used. + + In torchlib with opset18, onnx.GroupNorm is using wrong definition, so + we use InstanceNormalization to simulate GroupNormalization. We need to check if there are arguments created to simulation. + If there are, then we need to replace the pattern. If they are not used, then we don't need to replace the pattern. + + To validate this, we need to check the following: + 1. weight_for_norm are all 1 and bias_for_norm are all 0, as they are created for the simulation. + 2. weight_full and bias_full are unsqueezed to be easily broadcastable. + 3. input rank should be 4 + 4. weight_full and bias_full should have ones except first dim. + 5. adjusted_input_shape is a constant tensor of form [0, g, -1] + 6. original_input_shape is the same as input_x shape. + + Returns: + bool: True if the simulated instance normalization is used, False otherwise. + """ weight_for_norm = _ir_utils.propagate_const_value(weight_for_norm) weight_for_norm = _ir_utils.get_numpy_from_ir_value(weight_for_norm) @@ -70,32 +85,6 @@ def _check_if_simulated_instance_norm_is_used_impl( return True -def check_if_simulated_instance_norm_is_used( - match_bindings: dict[str, ir.Value | Any], -) -> bool: - """Check if the simulated instance normalization is used. - - In torchlib with opset18, onnx.GroupNorm is using wrong definition, so - we use InstanceNormalization to simulate GroupNormalization. We need to check if there are arguments created to simulation. - If there are, then we need to replace the pattern. If they are not used, then we don't need to replace the pattern. - - To validate this, we need to check the following: - 1. weight_for_norm are all 1 and bias_for_norm are all 0, as they are created for the simulation. - 2. weight_full and bias_full are unsqueezed to be easily broadcastable. - 3. input rank should be 4 - 4. weight_full and bias_full should have ones except first dim. - 5. adjusted_input_shape is a constant tensor of form [0, g, -1] - 6. original_input_shape is the same as input_x shape. - - Args: - match_bindings: The match binding dictionary from a MatchResult. - - Returns: - bool: True if the simulated instance normalization is used, False otherwise. - """ - return _check_if_simulated_instance_norm_is_used_impl(**match_bindings) - - def instance_simulates_group_normalization_pattern( input_x, adjusted_input_shape, diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 682550e18..df868f134 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -from typing import Any import onnx @@ -32,15 +31,14 @@ def softmax_without_axis(op, input): return op.Softmax(input) -def check_if_fp16_input(match_bindings: dict[str, ir.Value | Any]) -> bool: - input_val = match_bindings.get("input") - if input_val is None: +def check_if_fp16_input(input, **_) -> bool: + if input is None: logger.warning( "Cannot perform softmax upcast removal: " "cannot retrieve match_bindings for 'input' for dtype validation." ) return False - return input_val.dtype == ir.DataType.FLOAT16 + return input.dtype == ir.DataType.FLOAT16 # pylint: disable=pointless-string-statement diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d144502ed..7e38651db 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -807,10 +807,15 @@ def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: if len(node.outputs) != self._target_num_outputs: return MatchResult.FAIL() match = self._target_node_pattern.matches_node(node, model) + # NOTE: migrating to a simpler interface for match_condition signature. + # Ideally, the caller should pass in match_bindings as **match_bindings. + # This makes it easier to define this as a function with inputs like + # (input_a, input_b, **_) and omit all references to match_bindings. + # **_ refers to all the unused parameters in the match_condition function. if ( self._condition_function is not None and match - and not self._condition_function(match.bindings) + and not self._condition_function(**match.bindings) ): return MatchResult.FAIL() return match diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index a30a2341c..45bdcd6ad 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -245,7 +245,7 @@ def identity(op, x, newshape): del newshape # Unused return op.Identity(x) - def _check_for_redundant_reshape(x, newshape): + def check_for_redundant_reshape(x, newshape): oldshape = x.shape newshape = _ir_utils.propagate_const_value(newshape) newshape = _ir_utils.get_numpy_from_ir_value(newshape) @@ -257,9 +257,6 @@ def _check_for_redundant_reshape(x, newshape): return False return all(not (d1 != d2 and d2 != -1) for d1, d2 in zip(oldshape, newshape)) # pylint: disable=consider-using-in - def check_for_redundant_reshape(bindings): - return _check_for_redundant_reshape(**bindings) - rule = pattern.RewriteRule(reshape, identity, check_for_redundant_reshape) model_proto = onnx.parser.parse_model(