From 0d98619dee85025f8fb110864607f6f477c3d8ae Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 24 Apr 2024 20:18:26 -0700 Subject: [PATCH] Refactor pattern matcher (#1411) Part 1 of refactoring the pattern matcher (towards a unified API and implementation of existing two pattern matchers). * Eliminate separate "Var" class for patterns. (Every value-pattern will have an optional name, used to bind to corresponding ir-value in a match.) * Eliminate the delay-run (vs. eager-run) distinction. Everything is used in delayed-run mode now. * Switch all replacement-builders to use context-parameter to directly build IR nodes for replacement. * Cleanup some rewrite rules * Get rid of the to_ir methods --------- Signed-off-by: Ganesan Ramalingam Co-authored-by: Justin Chu --- onnxscript/rewriter/broadcast_to_matmul.py | 15 +- onnxscript/rewriter/cast_constant_of_shape.py | 57 +-- onnxscript/rewriter/erfgelu.py | 4 +- onnxscript/rewriter/gemm_to_matmul_add.py | 2 +- onnxscript/rewriter/generic_pattern.py | 33 +- onnxscript/rewriter/generic_pattern_test.py | 6 +- onnxscript/rewriter/no_op.py | 2 +- .../group_normalization_merge_silu.py | 4 +- .../group_normalization_merge_silu_test.py | 4 + .../instance_to_group_normalization.py | 27 +- onnxscript/rewriter/onnxruntime/softmax.py | 4 +- onnxscript/rewriter/pattern.py | 347 +++++++----------- onnxscript/rewriter/pattern_test.py | 42 ++- 13 files changed, 206 insertions(+), 341 deletions(-) diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index 92757c19e..8e5ee638e 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -140,10 +140,7 @@ def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shap return op.Reshape(matmul, shape_c) -def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c): - del shape_a # Unused - del shape_b # Unused - del shape_c # Unused +def matmul(op, input_a, input_b, **_): return op.MatMul(input_a, input_b) @@ -153,21 +150,15 @@ def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): return op.Reshape(matmul, shape_c) -def matmul_with_one_shape_input(input_a, input_b, shape_a, shape_c): - del shape_a # Unused - del shape_c # Unused - return op.MatMul(input_a, input_b) - - # Register the rewrite rules two_reshapes_matmul_reshape_rule = pattern.RewriteRule( two_reshapes_matmul_reshape_pattern, - matmul_with_two_shape_inputs, + matmul, check_if_need_reshape, ) one_reshape_matmul_reshape_rule = pattern.RewriteRule( one_reshape_matmul_reshape_pattern, - matmul_with_one_shape_input, + matmul, # We can use the same check_if_need_reshape function for both the rules, # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. check_if_need_reshape, diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py index 09b5db893..6a9cf855f 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -1,9 +1,8 @@ from __future__ import annotations import logging -from typing import Any, Sequence -import numpy as np +import onnx.helper from onnxscript import ir from onnxscript.rewriter import pattern @@ -12,58 +11,34 @@ logger = logging.getLogger(__name__) -def cast_constant_of_shape( - shape: Sequence[int], - t: Any, - dtype: int, - match_bindings: dict[str, ir.Value | Any] | None = None, -) -> pattern.OpPattern: - constant = op.ConstantOfShape(shape, value=t) +def cast_constant_of_shape(shape, scalar, dtype): + constant = op.ConstantOfShape(shape, value=scalar) return op.Cast(constant, to=dtype) -def fused_cast_constant_of_shape( - shape: Sequence[int], t: Any, dtype: int, match_bindings: dict[str, ir.Value | Any] -) -> pattern.OpPattern: - del dtype # unused - del t # unused - v_dtype = match_bindings["dtype"] - v_t = match_bindings["t"] - v_dtype = ir.DataType(v_dtype.value).numpy() # type: ignore[union-attr] - casted_val = ir.Tensor(v_t.value.numpy().astype(v_dtype)) # type: ignore[union-attr] - return op.ConstantOfShape(shape, value=casted_val) - - -def cast_constant_of_shape_without_value( - shape: Sequence[int], - dtype: int, - match_bindings: dict[str, ir.Value | Any] | None = None, -) -> pattern.OpPattern: - del match_bindings # Unused +def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_): + # Cast scalar (a TensorProto attribute) to the specified dtype + scalar_value = scalar.value.numpy().item() + cast_value = onnx.helper.make_tensor("value", dtype.value, (), [scalar_value]) + return op.ConstantOfShape(shape, value=cast_value) + + +def cast_constant_of_shape_without_value(shape, dtype): constant = op.ConstantOfShape(shape) return op.Cast(constant, to=dtype) -def fused_cast_constant_of_shape_without_value( - shape: Sequence[int], dtype: int, match_bindings: dict[str, ir.Value | Any] -) -> pattern.OpPattern: - del dtype # Unused - v_dtype = match_bindings["dtype"] - v_dtype = ir.DataType(v_dtype.value).numpy() # type: ignore[union-attr] - val = ir.Tensor(np.zeros(1, dtype=v_dtype)) - return op.ConstantOfShape(shape, value=val) +def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_): + zero = onnx.helper.make_tensor("value", dtype.value, (), [0]) + return op.ConstantOfShape(shape, value=zero) cast_constant_of_shape_rule = pattern.RewriteRule( - cast_constant_of_shape, - pattern.ReplacementPatternFunction(fused_cast_constant_of_shape, delay_run=True), + cast_constant_of_shape, fused_cast_constant_of_shape ) cast_constant_of_shape_without_value_rule = pattern.RewriteRule( - cast_constant_of_shape_without_value, - pattern.ReplacementPatternFunction( - fused_cast_constant_of_shape_without_value, delay_run=True - ), + cast_constant_of_shape_without_value, fused_cast_constant_of_shape_without_value ) rules = pattern.RewriteRuleSet( diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py index 67f0d47e1..59d689cee 100644 --- a/onnxscript/rewriter/erfgelu.py +++ b/onnxscript/rewriter/erfgelu.py @@ -23,8 +23,8 @@ def erf_gelu_pattern(x): # Replacement -def gelu(x): - return msft_op.Gelu(x) +def gelu(op, x): + return op.Gelu(x, domain="com.microsoft") rule = pattern.RewriteRule(erf_gelu_pattern, gelu) diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index ae44ffe27..cce9865c9 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -13,7 +13,7 @@ def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c): return op.Reshape(gemm, shape_c) -def matmul_add(input_a, input_b, input_c, shape_a, shape_d): +def matmul_add(op, input_a, input_b, input_c, **_): matmul = op.MatMul(input_a, input_b) return op.Add(matmul, input_c) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 6cc70da64..f53067820 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -11,36 +11,7 @@ import onnxscript import onnxscript.rewriter.pattern as orp from onnxscript import ir -from onnxscript.rewriter import _ir_utils, _tape - - -class _SimpleBuilder: - """temporary adaptor for building 'generic patterns'.""" - - # TODO(justinchuby): Merge with the rest of pattern building methods - def __init__(self): - self.tape = _tape.Tape() - - def __getattr__(self, op_type: str) -> Any: - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) - - def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): - domain = kwargs.pop("domain", "") - output_names = kwargs.pop("output_names", 1) - if isinstance(output_names, Sequence): - num_outputs = len(output_names) - else: - assert isinstance(output_names, int) - num_outputs = output_names - if num_outputs == 1: - return self.tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain) - return self.tape.op_multi_output( - op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs - ) - - @property - def nodes(self) -> Sequence[ir.Node]: - return self.tape.nodes +from onnxscript.rewriter import _ir_utils class PatternMatchResult: @@ -339,7 +310,7 @@ def _build_pattern( assert len(kwargs) == 0, f"Attributes are not supported yet but kwargs={kwargs}" inputs = [ir.Input(name=name) for name in args] - builder = _SimpleBuilder() + builder = orp.RewriterContext() outputs = func(builder, *inputs, **kwargs) if isinstance(outputs, ir.Value): outputs = [outputs] diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index 2e09e7891..bd120e780 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -124,7 +124,7 @@ def match_pattern(cls, op, x, y, w, z): @classmethod def apply_pattern(cls, op, x, y, w, z): """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", output_names=2) + return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2) def validate_mapping( self, @@ -273,7 +273,7 @@ def match_pattern(cls, op, x, pos_ids, axis): transpose, transpose, domain="com.microsoft", - output_names=2, + outputs=2, ) sin = op.Sin(output) @@ -307,7 +307,7 @@ def apply_pattern(cls, op, x, pos_ids, axis): cos_cache, sin_cache, domain="com.microsoft", - output_names=2, + outputs=2, ) model = self.get_rotary_model() diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 0a149ad96..bd9b1c370 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -24,7 +24,7 @@ def div_by_1(x): # Replacement -def identity(x): +def identity(op, x): return op.Identity(x) diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py index a6dfb54eb..d4c60e59e 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py @@ -32,13 +32,14 @@ def group_normalization_and_silu_submodule( def group_normalization_with_silu( + op, input, weight, bias, epsilon, groups, ): - group_norm = msft_op.GroupNorm( + group_norm = op.GroupNorm( input, weight, bias, @@ -46,6 +47,7 @@ def group_normalization_with_silu( channels_last=1, epsilon=epsilon, groups=groups, + domain="com.microsoft", ) return op.Transpose(group_norm, perm=[0, 3, 1, 2]) diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py index b2f824e3f..ced611685 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py @@ -123,3 +123,7 @@ def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self): self.assertEqual(count, 2) # plus 2 in model constants self.assertEqual(len(model.graph), 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py index 4036f0078..e15954d24 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -105,7 +105,6 @@ def instance_simulates_group_normalization_pattern( weight_full, bias_full, epsilon, - match_bindings: dict[str, ir.Value | Any] | None = None, ): adjusted_input = op.Reshape(input_x, adjusted_input_shape) inst_norm = op.InstanceNormalization( @@ -116,17 +115,7 @@ def instance_simulates_group_normalization_pattern( return op.Add(mul, bias_full) -def group_normalization( - input_x, - adjusted_input_shape, - original_input_shape, - weight_for_norm, - bias_for_norm, - weight_full, - bias_full, - epsilon, - match_bindings: dict[str, ir.Value], -): +def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, epsilon, **_): # com.microsoft.GroupNorm only supports NHWC for now nhwc_input = op.Transpose(input_x, perm=[0, 2, 3, 1]) # com.microsoft.GroupNorm only supports gamma and beta as float type @@ -136,12 +125,13 @@ def group_normalization( bias_full = op.Cast(bias_full, to=onnx.TensorProto.FLOAT) bias_full = op.Reshape(bias_full, reshape_to_1d) # re-obtain attribute groups - if "weight_for_norm" not in match_bindings: - raise ValueError("weight_for_norm is not found in match_bindings") - if match_bindings["weight_for_norm"].shape is None: + # TODO(rama): Earlier check implies weight_for_norm is a constant tensor? + # If not, we should add a check that shape[0] is not symbolic. + shape = weight_for_norm.shape + if shape is None: raise ValueError("weight_for_norm shape not known") - groups = match_bindings["weight_for_norm"].shape[0] - output = msft_op.GroupNorm( + groups = shape[0] + output = op.GroupNorm( nhwc_input, weight_full, bias_full, @@ -149,6 +139,7 @@ def group_normalization( channels_last=1, epsilon=epsilon, groups=groups, + domain="com.microsoft", ) return op.Transpose(output, perm=[0, 3, 1, 2]) @@ -156,7 +147,7 @@ def group_normalization( # Register the rewrite rules instance_norm_to_group_norm_rule = pattern.RewriteRule( instance_simulates_group_normalization_pattern, - pattern.ReplacementPatternFunction(group_normalization, delay_run=True), + group_normalization, check_if_simulated_instance_norm_is_used, ) diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 225f27cfb..682550e18 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -18,7 +18,7 @@ def softmax_with_fp32_upcast(input, axis): return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) -def softmax(input, axis): +def softmax(op, input, axis): return op.Softmax(input, axis=axis) @@ -28,7 +28,7 @@ def softmax_with_fp32_upcast_without_axis(input): return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) -def softmax_without_axis(input): +def softmax_without_axis(op, input): return op.Softmax(input) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 050851932..ec2b41d13 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1,9 +1,10 @@ from __future__ import annotations +import dataclasses import inspect import itertools import math -from typing import Any, Callable, Sequence +from typing import Any, Callable, List, Optional, Sequence, Tuple import numpy as np import onnx @@ -12,15 +13,13 @@ from onnxscript import ir from onnxscript.ir import _convenience -from onnxscript.rewriter import _ir_utils +from onnxscript.rewriter import _ir_utils, _tape # Overview of the pattern module: The classes below are used to define both # patterns (that we search for) and replacements for rewrite rules. # The matches() method of a pattern is used to check if an IR component # matches the pattern. -# The to_ir() method of a pattern is used to create a new IR component # TODO: Ensure that all matches() methods have same type signature (where -# appropriate) and that all to_ir() methods have same type signature (where # appropriate). @@ -40,9 +39,6 @@ def name(self) -> str | None: def matches(self, value: int | str | Sequence) -> bool: return value == self.value - def to_ir(self, model, bindings=None) -> int | str | Sequence: - return self.value - class StringConstantPattern: def __init__(self, value: str, name: str) -> None: @@ -60,9 +56,6 @@ def name(self) -> str: def matches(self, attr: ir.AttrString) -> bool: return attr.value == self.value - def to_ir(self, model, bindings=None) -> ir.AttrString: - return ir.AttrString(value=self.value, name=self.name) - class IntConstantPattern: def __init__(self, value: int, name: str) -> None: @@ -80,9 +73,6 @@ def name(self) -> str: def matches(self, attr: ir.AttrInt64) -> bool: return attr.value == self.value - def to_ir(self, model, bindings=None) -> ir.AttrInt64: - return ir.AttrInt64(value=self.value, name=self.name) - class ListConstantPattern: def __init__(self, value: Sequence[int | str | float], name: str) -> None: @@ -101,18 +91,6 @@ def matches(self, attr: ir.AttrFloat32s | ir.AttrInt64s | ir.AttrStrings) -> boo # TODO: Need more data points to determine if this is the right way to compare lists. return attr.value == self.value - def to_ir(self, model, bindings=None) -> ir.AttrFloat32s | ir.AttrInt64s | ir.AttrStrings: - the_first_non_none_item = next((item for item in self.value if item is not None), None) - if isinstance(the_first_non_none_item, int): - return ir.AttrInt64s(value=self.value, name=self.name) # type: ignore - if isinstance(the_first_non_none_item, str): - return ir.AttrStrings(value=self.value, name=self.name) # type: ignore - if isinstance(the_first_non_none_item, float): - return ir.AttrFloat32s(value=self.value, name=self.name) # type: ignore - raise TypeError( - f"Cannot convert list of {type(the_first_non_none_item)} to ConstantPattern" - ) - class PrefixPattern: """This pattern is used to simplify submodule opset pattern matching.""" @@ -127,9 +105,6 @@ def value(self) -> str: def matches(self, value: str) -> bool: return value.startswith(self.value) - def to_ir(self, model, bindings=None) -> str: - raise NotImplementedError("PrefixPattern should not be converted to IR") - class FloatConstantPattern: def __init__( @@ -153,9 +128,6 @@ def matches(self, attr: ir.AttrFloat32): attr.value, self.value, rel_tol=self._rel_tol, abs_tol=self._abs_tol ) - def to_ir(self, model, bindings=None) -> ir.AttrFloat32: - return ir.AttrFloat32(self.name, self.value) - class TensorConstantPattern: def __init__( @@ -186,9 +158,6 @@ def matches(self, attr: ir.AttrTensor): ) ) - def to_ir(self, model, bindings=None) -> ir.AttrTensor: - return ir.AttrTensor(self.name, self.value) - def _make_constant_pattern( value: float | int | Sequence | ir.TensorProtocol, name: str @@ -238,15 +207,6 @@ def matches( return self.value_pattern.matches(attr_val, model) # type: ignore[arg-type] return self.value_pattern.matches(attr_val) - def to_ir(self, model: ir.Model, rewrite_cache: RewriteCache, bindings=None) -> ir.Value: - if isinstance(self.value_pattern, Var): - val, nodes = self.value_pattern.to_ir( - model, bindings, 1, rewrite_cache - ) # TODO: handle multiple outputs - return val - # constant pattern - return self.value_pattern.to_ir(model, bindings) - class OpsetPattern: """Represents an opset pattern. @@ -292,17 +252,6 @@ def matches(self, opset): domain, version = opset return self.domain_pattern.matches(domain) and self.version_pattern.matches(version) - def to_ir(self, model, bindings=None) -> str: - domain = self.domain_pattern.to_ir(model, bindings) - assert isinstance(domain, str), f"Expected str, got {type(domain)}" - # TODO: Should we ban other custom domains? - if domain not in model.opset_imports: - assert isinstance( - self.version_pattern, PythonPattern - ), f"custom domain {domain} needs to have a specific version." - model.opset_imports[self.domain_pattern.value] = self.version_pattern.value - return domain - def __getattr__(self, name: str) -> Any: return OpPattern(self, PythonPattern(name)) @@ -343,6 +292,7 @@ def __init__( self.op_name_pattern = op_name_pattern def __call__(self, *args, **kwargs): + # TODO(rama): Unify with convention used elsewhere. if "_num_outputs" in kwargs: num_outputs = kwargs["_num_outputs"] del kwargs["_num_outputs"] @@ -359,8 +309,8 @@ def __call__(self, *args, **kwargs): def _to_value_pattern( - x: ValuePattern | int | float, -) -> NodeOutputPattern | Constant | Var | ValuePattern: + x: ValuePattern | int | float | None, +) -> NodeOutputPattern | Constant | ValuePattern | None: """Promotes an input-value used to construct a NodePattern to a ValuePattern. Example usage: @@ -376,7 +326,7 @@ def _to_value_pattern( :: z = op.Add(x, op.Constant(0)) """ - if isinstance(x, ValuePattern): + if x is None or isinstance(x, ValuePattern): return x if isinstance(x, (int, float, Sequence)): return Constant(x) @@ -461,8 +411,19 @@ class ValuePattern: operations, so that we can write patterns like `x + 1` and `1 + x`. """ - def __init__(self) -> None: - pass + def __init__(self, name: str | None) -> None: + self.name = name + + def __repr__(self) -> str: + return f"ValuePattern({self.name!r})" + + def matches(self, value: ir.Value, model: ir.Model): + if self.name is None: + return MatchResult([], {}) + return MatchResult([], {self.name: value}) + + def commute(self) -> Sequence[ValuePattern]: + return [self] def __add__(self, other): return onnxop.Add(self, other) @@ -504,7 +465,7 @@ def __init__( self, domain: OpsetPattern, op: PythonPattern | PrefixPattern, - inputs: Sequence[int | float | ValuePattern], + inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], ): self.domain = domain @@ -529,7 +490,7 @@ def matches(self, value: ir.Value, model: ir.Model): def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: """Examine if the IR node matches the self pattern.""" - node_version = model.graph.opset_imports.get(node.domain, 0) + node_version = model.graph.opset_imports.get(node.domain, 1) if not self.domain.matches((node.domain, node_version)): return MatchResult.FAIL() if not self.op.matches(node.op_type): @@ -540,6 +501,10 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: # because at least the starting node op_type is already matched. for arg_value, previous_node_output_pattern in zip(node.inputs, self.inputs): # previous_node_output_pattern could be a Var, if it's the original arg. + if arg_value is None and previous_node_output_pattern is None: + continue + if arg_value is None or previous_node_output_pattern is None: + return MatchResult.FAIL() sub_match = previous_node_output_pattern.matches(arg_value, model) # type: ignore[attr-defined] match.extend(sub_match, model) if not match: # If sub-match failed, @@ -561,50 +526,10 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: match.values.append(node) # type: ignore[attr-defined] return match - def to_ir( - self, - model: ir.Model, - bindings: dict[str, ir.Value | Any], - num_outputs: int, - rewrite_cache: RewriteCache, - ) -> tuple[Sequence[ir.Value], Sequence[ir.Node]]: - domain = self.domain.to_ir(model) - op = self.op.to_ir(model) - assert isinstance(op, str), f"Expected str, got {type(op)}" - inputs = [] - nodes: list[ir.Node] = [] - for val_pattern in self.inputs: - if ( - value_and_node := rewrite_cache.get_node_output_pattern(val_pattern) # type: ignore[arg-type] - ) is not None: - val, n = value_and_node - else: - val, n = val_pattern.to_ir(model, bindings, 1, rewrite_cache) # type: ignore[attr-defined] - rewrite_cache.set_node_output_pattern_with_ir(val_pattern, val, n) # type: ignore[arg-type] - nodes.extend(n) # type: ignore[arg-type] - # If one of the inputs was a the output of a previous node, - # unpack the new output ir value that is created for that node - if isinstance(val, tuple): - # TODO: Move implementation of output_index to NodeOutputPatter.to_ir - inputs.append(val[val_pattern.output_index]) - else: - inputs.append(val) - attributes = ( - attr_pattern.to_ir(model, rewrite_cache, bindings) - for attr_pattern in self.attributes.values() - ) - new_node = ir.Node( - domain=domain, - op_type=op, - inputs=inputs, - attributes=attributes, # type: ignore[arg-type] - num_outputs=num_outputs, - ) - nodes.append(new_node) - return new_node.outputs, nodes - def commute(self) -> Sequence[NodePattern]: - list_of_lists = [pattern.commute() for pattern in self.inputs] # type: ignore[attr-defined] + list_of_lists = [ + [None] if pattern is None else pattern.commute() for pattern in self.inputs + ] # type: ignore[attr-defined] def enumerate_inputs(inputs, index): if index >= len(inputs): @@ -631,7 +556,10 @@ class NodeOutputPattern(ValuePattern): is values computed using a specific op. """ - def __init__(self, node_pattern: NodePattern, output_index: int) -> None: + def __init__( + self, node_pattern: NodePattern, output_index: int, name: str | None = None + ) -> None: + super().__init__(name) self.node_pattern = node_pattern self.output_index = output_index @@ -644,44 +572,8 @@ def matches(self, value: ir.Value, model: ir.Model): return MatchResult.FAIL() return self.node_pattern.matches_node(node, model) - def to_ir( - self, - model: ir.Model, - bindings: dict[str, ir.Value | Any], - num_outputs: int, - rewrite_cache: RewriteCache, - ) -> tuple[Sequence[ir.Value], Sequence[ir.Node]]: - assert self.output_index == 0, "TODO: handle multiple outputs" - return self.node_pattern.to_ir(model, bindings, num_outputs, rewrite_cache) - - -class Var(ValuePattern): - """Represents a pattern variable.""" - - def __init__(self, name: str) -> None: - self.pattern_var_name = name - self.bound_value = None - - def __repr__(self) -> str: - return f"Var({self.pattern_var_name!r})" - - def matches(self, value: ir.Value, model: ir.Model): - return MatchResult([], {self.pattern_var_name: value}) - - def to_ir( - self, - model: ir.Model, - bindings: dict[str, ir.Value | Any], - num_outputs: int, - rewrite_cache: RewriteCache, - ) -> tuple[ir.Value, Sequence]: - del model # Unused - del num_outputs # Unused - del rewrite_cache # Unused - return bindings[self.pattern_var_name], [] - def commute(self) -> Sequence[ValuePattern]: - return [self] +Var = ValuePattern class Constant(ValuePattern): @@ -690,6 +582,7 @@ class Constant(ValuePattern): def __init__( self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 ) -> None: + super().__init__(None) self.value = value self.rel_tol = rel_tol self.abs_tol = abs_tol @@ -763,11 +656,6 @@ def pattern(x, shape1, shape2): return node_pattern, num_outputs -# Currently, the replacement graph function is the same as the pattern function. -# This may change in the future. -_handle_replacement_return_value = _handle_pattern_return_value - - def _valid_to_replace(matched_nodes: Sequence[Any]) -> bool: """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" # * Must check that all values matched by pattern are used only by pattern, @@ -800,63 +688,84 @@ def __init__(self, function: Callable) -> None: def function(self) -> Callable: return self._function - def get_pattern(self, *variables: Sequence[Var]) -> tuple[NodePattern, int]: + def get_pattern(self, variables: Sequence[Var]) -> tuple[NodePattern, int]: node_output_pattern = self._function(*variables) return _handle_pattern_return_value(node_output_pattern) -class ReplacementPatternFunction: - """The replacement pattern that will replace the targeted pattern. +# A type representing the domains/versions used in creating a replacement subgraph +UsedOpsets = List[Tuple[str, Optional[int]]] - Attributes: - function (Callable): The replacement function that will be used to replace the matched pattern. - delay_run (bool): If True, the replacement function will not be run until the matched pattern is found. - This is useful when we want to extract certain metavalue from the matched pattern and use it in the - replacement pattern. - """ - def __init__(self, function, *, delay_run: bool = False): - self._function = function - self._delay_run = delay_run +class RewriterContext: + """Context parameter used to build the replacement pattern.""" + + # TODO(justinchuby): Merge with the rest of pattern building methods + def __init__(self): + self._tape = _tape.Tape() + self._used_opsets: UsedOpsets = [] + + def __getattr__(self, op_type: str) -> Any: + return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + + def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): + domain = kwargs.pop("domain", "") + version = kwargs.pop("version", None) + self._used_opsets.append((domain, version)) + outputs = kwargs.pop("outputs", 1) + if isinstance(outputs, Sequence): + num_outputs = len(outputs) + else: + assert isinstance(outputs, int) + num_outputs = outputs + if num_outputs == 1: + return self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain) + return self._tape.op_multi_output( + op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs + ) @property - def function(self) -> Callable: - return self._function + def nodes(self) -> Sequence[ir.Node]: + # TODO(rama): The current tape-based implementation will not track nodes added + # via overloaded operators, eg., `x + y`. One possible way to fix this is to + # have values/nodes know which tape they belong to (instead of a graph/function). + # However, it is unclear we need this feature for rewriting: we could also + # identify the nodes to be inserted from the replacement values (by tracing back). + return self._tape.nodes @property - def delay_run(self) -> bool: - return self._delay_run + def used_opsets(self) -> UsedOpsets: + return self._used_opsets - # TODO: How do we merge it with to_ir function? - def get_pattern( - self, - *vars: Sequence[Var], - match_bindings: dict[str, ir.Value | Any] | None = None, - ) -> tuple[NodePattern | None, int | None]: - if self._delay_run: - if match_bindings is None: - return None, None - node_output_pattern = self._function(*vars, match_bindings) - else: - node_output_pattern = self._function(*vars) - return _handle_pattern_return_value(node_output_pattern) +@dataclasses.dataclass +class ReplacementSubgraph: + """A subgraph that will replace the matched pattern.""" -class RewriteCache: - def __init__(self): - self._node_output_pattern_to_ir: dict[NodeOutputPattern, tuple[ir.Value, ir.Node]] = ( - dict() - ) + new_values: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + used_opsets: UsedOpsets - def get_node_output_pattern( - self, node_output_pattern: NodeOutputPattern - ) -> tuple[ir.Value, ir.Node] | None: - return self._node_output_pattern_to_ir.get(node_output_pattern, None) - def set_node_output_pattern_with_ir( - self, node_output_pattern: NodeOutputPattern, value: ir.Value, node: ir.Node - ) -> None: - self._node_output_pattern_to_ir[node_output_pattern] = (value, node) +class ReplacementPatternFunction: + """The replacement pattern that will replace the targeted pattern. + + Attributes: + function (Callable): The replacement function that will be used to replace the matched pattern. + """ + + def __init__(self, function) -> None: + self._function = function + + def get_replacement( + self, + match_bindings: dict[str, ir.Value | Any] | None = None, + ) -> ReplacementSubgraph: + context = RewriterContext() + new_values = self._function(context, **match_bindings) + if not isinstance(new_values, Sequence): + new_values = [new_values] + return ReplacementSubgraph(new_values, context.nodes, context.used_opsets) class RewriteRule: @@ -879,11 +788,9 @@ def __init__( """ if target_pattern is None: - # NOTE: commute() generated rules will have target_pattern as None - # ReplacementPatternFunction is still needed in try_rewrite + # NOTE: this is a default-constructor. Caller responsible for filling in the fields. assert replacement_pattern is None assert condition_function is None - self._replacement_pattern = ReplacementPatternFunction(replacement_pattern) return elif replacement_pattern is None: raise ValueError( @@ -900,21 +807,12 @@ def __init__( self._condition_function = condition_function _pattern_vars = inspect.signature(self._target_pattern.function).parameters - _replacement_vars = inspect.signature(self._replacement_pattern.function).parameters - # TODO: accept _replacement_vars being subset of _pattern_vars? - assert len(_pattern_vars) == len(_replacement_vars) self._vars = [Var(v) for v in _pattern_vars] # Get the last node pattern and number of outputs from the pattern function self._target_node_pattern, self._target_num_outputs = self._target_pattern.get_pattern( - *self._vars # type: ignore[arg-type] - ) - # NOTE: Return Nones if the replacement pattern is delayed running - self._replace_node_pattern, _replacement_num_outputs = replacement_pattern.get_pattern( - *self._vars # type: ignore[arg-type] + self._vars # type: ignore[arg-type] ) - if _replacement_num_outputs is not None: - assert self._target_num_outputs == _replacement_num_outputs def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: """Check if the node from IR matches the pattern.""" @@ -931,29 +829,37 @@ def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: def try_rewrite( self, model: ir.Model, node: ir.Node - ) -> tuple[Sequence[Any], Sequence[ir.Node]] | None: + ): # TODO(rama) -> ReplacementSubgraph | None: """If the node matches the pattern, then replace the node with the replacement pattern.""" match = self.matches(node, model) if match: assert match.values is not None, "Matched values should not be None." if _valid_to_replace(match.values): - # NOTE: delayed running as the replacement pattern needs bindings - if self._replacement_pattern.delay_run: - # bindings will be consumed by the replacement function - self._replace_node_pattern, _replacement_num_outputs = ( - self._replacement_pattern.get_pattern( - *self._vars[:-1], # type: ignore[arg-type] - match_bindings=match.bindings, - ) + # bindings will be consumed by the replacement function + delta = self._replacement_pattern.get_replacement(match.bindings) + if len(delta.new_values) != self._target_num_outputs: + raise ValueError( + f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " + f"Expected {self._target_num_outputs}, but got {len(delta.new_values)}." ) - assert self._target_num_outputs == _replacement_num_outputs - rewrite_cache = RewriteCache() - assert self._replace_node_pattern is not None, "Replacement pattern is None." - _, _to_insert = self._replace_node_pattern.to_ir( - model, match.bindings, self._target_num_outputs, rewrite_cache - ) - - return (match.values, _to_insert) # type: ignore[return-value] + # TODO(rama): Check/update opset-imports + # (i) Integrate following with the multi-output matcher and code elsewhere: + # (ii) For functions, we need to do this with function, not model's main graph. + # (iii) Code in the caller (below) checks if match overlaps previous match, which + # appears incorrect for single-pattern matcher. Best to alter iteration to apply + # each rewrite immediately, instead of accumulating them. + # (iv) return delta here + imports = model.graph.opset_imports + for domain, version in delta.used_opsets: + if domain not in imports: + # use 1 as default version if not explicitly specified + imports[domain] = version if version is not None else 1 + elif version is not None and version != imports[domain]: + raise ValueError( + f"Multiple versions of opset {domain} used. " + f"Expected version {imports[domain]}, but got {version}." + ) + return match.values, delta.new_nodes return None def apply_to_model(self, model: ir.Model, *, commute: bool = False): @@ -970,7 +876,8 @@ def replace_pattern(new_pattern): rule._condition_function = self._condition_function rule._target_node_pattern = new_pattern rule._target_num_outputs = self._target_num_outputs - rule._replace_node_pattern = self._replace_node_pattern + rule._replacement_pattern = self._replacement_pattern + rule._vars = self._vars return rule return [replace_pattern(p) for p in self._target_node_pattern.commute()] diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 9b50da2ef..14fcf4be8 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -17,8 +17,8 @@ def rule(self) -> pattern.RewriteRule: def reciprocal_mul_pattern(x, y): return (1 / x) * y - def div(x, y): - return y / x + def div(op, x, y): + return op.Div(y, x) return pattern.RewriteRule(reciprocal_mul_pattern, div) @@ -96,8 +96,8 @@ def fast_gelu_pattern1(x): tanh = op.Tanh(c * (x + (x**3) * b)) return (1.0 + tanh) * (0.5 * x) - def fast_gelu(x): - return msft_op.FastGelu(x) + def fast_gelu(op, x): + return op.FastGelu(x, domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) @@ -117,8 +117,8 @@ def fast_gelu_pattern1_long(x): half_x = op.Mul(half, x) return op.Mul(one_plus_tanh, half_x) - def fast_gelu(x): - return msft_op.FastGelu(x) + def fast_gelu(op, x): + return op.FastGelu(x, domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) @@ -163,7 +163,7 @@ def concat_pattern(x, y, axis): seq = op.SequenceConstruct(x, y) return op.ConcatFromSequence(seq, axis=axis) - def concat(x, y, axis): + def concat(op, x, y, axis): return op.Concat(x, y, axis=axis) return pattern.RewriteRule(concat_pattern, concat) @@ -213,7 +213,7 @@ def test_commute(self): def add_0(x): return x + 0 - def identity(x): + def identity(op, x): return op.Identity(x) add_0_rule = pattern.RewriteRule(add_0, identity) @@ -240,7 +240,7 @@ def test_const_value(self): def reshape(x, newshape): return op.Reshape(x, newshape) - def identity(x, newshape): + def identity(op, x, newshape): del newshape # Unused return op.Identity(x) @@ -299,6 +299,30 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) self.assertEqual(model.graph[1].attributes["value"].value.dtype, 1) + def test_opset_import(self): + def add_same(x): + return x + x + + def double(op, x): + return op.Double(x, domain="custom.domain", version=10) + + rule = pattern.RewriteRule(add_same, double) + + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) + { + y = Add (x, x) + z = Relu (y) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = pattern.RewriteRuleSet([rule], commute=True).apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(model.graph.opset_imports["custom.domain"], 10) + if __name__ == "__main__": unittest.main()