Refactor pattern matcher (#1411)
Part 1 of refactoring the pattern matcher (towards a unified API and
implementation for the two existing pattern matchers).
* Eliminate the separate "Var" class for patterns. (Every value-pattern
now has an optional name, used to bind to the corresponding IR value in a
match.)
* Eliminate the delay-run (vs. eager-run) distinction. Everything now
runs in delayed mode.
* Switch all replacement builders to use a context parameter (op) to
build replacement IR nodes directly (see the sketch below).
* Clean up some rewrite rules.
* Remove the to_ir methods.
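
A minimal sketch of the new replacement-builder style, based on the rule
registrations in the diffs below. The pattern function, rule name, and the
pattern.onnxop alias are illustrative assumptions, not code from this commit:

    from onnxscript.rewriter import pattern

    op = pattern.onnxop  # assumed alias for the pattern-building opset


    # Target pattern (illustrative): a MatMul whose result is reshaped.
    def reshape_matmul_pattern(input_a, input_b, shape_c):
        matmul = op.MatMul(input_a, input_b)
        return op.Reshape(matmul, shape_c)


    # Replacement builder: the leading "op" context parameter builds the
    # replacement IR nodes directly; unused match bindings are absorbed by **_.
    def plain_matmul(op, input_a, input_b, **_):
        return op.MatMul(input_a, input_b)


    example_rule = pattern.RewriteRule(reshape_matmul_pattern, plain_matmul)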

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
gramalingam and justinchuby authored Apr 25, 2024
1 parent 8060e2d commit 0d98619
Showing 13 changed files with 206 additions and 341 deletions.
15 changes: 3 additions & 12 deletions onnxscript/rewriter/broadcast_to_matmul.py
@@ -140,10 +140,7 @@ def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shap
     return op.Reshape(matmul, shape_c)


-def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c):
-    del shape_a  # Unused
-    del shape_b  # Unused
-    del shape_c  # Unused
+def matmul(op, input_a, input_b, **_):
     return op.MatMul(input_a, input_b)


@@ -153,21 +150,15 @@ def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c):
     return op.Reshape(matmul, shape_c)


-def matmul_with_one_shape_input(input_a, input_b, shape_a, shape_c):
-    del shape_a  # Unused
-    del shape_c  # Unused
-    return op.MatMul(input_a, input_b)
-
-
 # Register the rewrite rules
 two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
     two_reshapes_matmul_reshape_pattern,
-    matmul_with_two_shape_inputs,
+    matmul,
     check_if_need_reshape,
 )
 one_reshape_matmul_reshape_rule = pattern.RewriteRule(
     one_reshape_matmul_reshape_pattern,
-    matmul_with_one_shape_input,
+    matmul,
     # We can use the same check_if_need_reshape function for both the rules,
     # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern.
     check_if_need_reshape,
57 changes: 16 additions & 41 deletions onnxscript/rewriter/cast_constant_of_shape.py
@@ -1,9 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import Any, Sequence

-import numpy as np
+import onnx.helper

 from onnxscript import ir
 from onnxscript.rewriter import pattern
@@ -12,58 +11,34 @@
 logger = logging.getLogger(__name__)


-def cast_constant_of_shape(
-    shape: Sequence[int],
-    t: Any,
-    dtype: int,
-    match_bindings: dict[str, ir.Value | Any] | None = None,
-) -> pattern.OpPattern:
-    constant = op.ConstantOfShape(shape, value=t)
+def cast_constant_of_shape(shape, scalar, dtype):
+    constant = op.ConstantOfShape(shape, value=scalar)
     return op.Cast(constant, to=dtype)


-def fused_cast_constant_of_shape(
-    shape: Sequence[int], t: Any, dtype: int, match_bindings: dict[str, ir.Value | Any]
-) -> pattern.OpPattern:
-    del dtype  # unused
-    del t  # unused
-    v_dtype = match_bindings["dtype"]
-    v_t = match_bindings["t"]
-    v_dtype = ir.DataType(v_dtype.value).numpy()  # type: ignore[union-attr]
-    casted_val = ir.Tensor(v_t.value.numpy().astype(v_dtype))  # type: ignore[union-attr]
-    return op.ConstantOfShape(shape, value=casted_val)
-
-
-def cast_constant_of_shape_without_value(
-    shape: Sequence[int],
-    dtype: int,
-    match_bindings: dict[str, ir.Value | Any] | None = None,
-) -> pattern.OpPattern:
-    del match_bindings  # Unused
+def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
+    # Cast scalar (a TensorProto attribute) to the specified dtype
+    scalar_value = scalar.value.numpy().item()
+    cast_value = onnx.helper.make_tensor("value", dtype.value, (), [scalar_value])
+    return op.ConstantOfShape(shape, value=cast_value)
+
+
+def cast_constant_of_shape_without_value(shape, dtype):
     constant = op.ConstantOfShape(shape)
     return op.Cast(constant, to=dtype)


-def fused_cast_constant_of_shape_without_value(
-    shape: Sequence[int], dtype: int, match_bindings: dict[str, ir.Value | Any]
-) -> pattern.OpPattern:
-    del dtype  # Unused
-    v_dtype = match_bindings["dtype"]
-    v_dtype = ir.DataType(v_dtype.value).numpy()  # type: ignore[union-attr]
-    val = ir.Tensor(np.zeros(1, dtype=v_dtype))
-    return op.ConstantOfShape(shape, value=val)
+def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
+    zero = onnx.helper.make_tensor("value", dtype.value, (), [0])
+    return op.ConstantOfShape(shape, value=zero)


 cast_constant_of_shape_rule = pattern.RewriteRule(
-    cast_constant_of_shape,
-    pattern.ReplacementPatternFunction(fused_cast_constant_of_shape, delay_run=True),
+    cast_constant_of_shape, fused_cast_constant_of_shape
 )

 cast_constant_of_shape_without_value_rule = pattern.RewriteRule(
-    cast_constant_of_shape_without_value,
-    pattern.ReplacementPatternFunction(
-        fused_cast_constant_of_shape_without_value, delay_run=True
-    ),
+    cast_constant_of_shape_without_value, fused_cast_constant_of_shape_without_value
 )

 rules = pattern.RewriteRuleSet(
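
For the cast_constant_of_shape change above, a standalone sanity-check sketch
(assumed scalar and target dtype; not code from this commit) of how
onnx.helper.make_tensor folds the Cast into the ConstantOfShape value attribute:

    import onnx
    from onnx import helper, numpy_helper

    # The matched subgraph held a float scalar attribute plus a Cast to FLOAT16;
    # make_tensor builds the already-cast "value" tensor for ConstantOfShape.
    scalar_value = 1.5  # assumed scalar taken from the matched attribute
    fused_value = helper.make_tensor("value", onnx.TensorProto.FLOAT16, (), [scalar_value])
    print(numpy_helper.to_array(fused_value).dtype)  # float16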
4 changes: 2 additions & 2 deletions onnxscript/rewriter/erfgelu.py
@@ -23,8 +23,8 @@ def erf_gelu_pattern(x):


 # Replacement
-def gelu(x):
-    return msft_op.Gelu(x)
+def gelu(op, x):
+    return op.Gelu(x, domain="com.microsoft")


rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/gemm_to_matmul_add.py
@@ -13,7 +13,7 @@ def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c):
     return op.Reshape(gemm, shape_c)


-def matmul_add(input_a, input_b, input_c, shape_a, shape_d):
+def matmul_add(op, input_a, input_b, input_c, **_):
     matmul = op.MatMul(input_a, input_b)
     return op.Add(matmul, input_c)

33 changes: 2 additions & 31 deletions onnxscript/rewriter/generic_pattern.py
@@ -11,36 +11,7 @@
 import onnxscript
 import onnxscript.rewriter.pattern as orp
 from onnxscript import ir
-from onnxscript.rewriter import _ir_utils, _tape
-
-
-class _SimpleBuilder:
-    """temporary adaptor for building 'generic patterns'."""
-
-    # TODO(justinchuby): Merge with the rest of pattern building methods
-    def __init__(self):
-        self.tape = _tape.Tape()
-
-    def __getattr__(self, op_type: str) -> Any:
-        return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
-
-    def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
-        domain = kwargs.pop("domain", "")
-        output_names = kwargs.pop("output_names", 1)
-        if isinstance(output_names, Sequence):
-            num_outputs = len(output_names)
-        else:
-            assert isinstance(output_names, int)
-            num_outputs = output_names
-        if num_outputs == 1:
-            return self.tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
-        return self.tape.op_multi_output(
-            op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
-        )
-
-    @property
-    def nodes(self) -> Sequence[ir.Node]:
-        return self.tape.nodes
+from onnxscript.rewriter import _ir_utils


class PatternMatchResult:
@@ -339,7 +310,7 @@ def _build_pattern(
     assert len(kwargs) == 0, f"Attributes are not supported yet but kwargs={kwargs}"

     inputs = [ir.Input(name=name) for name in args]
-    builder = _SimpleBuilder()
+    builder = orp.RewriterContext()
     outputs = func(builder, *inputs, **kwargs)
     if isinstance(outputs, ir.Value):
         outputs = [outputs]
6 changes: 3 additions & 3 deletions onnxscript/rewriter/generic_pattern_test.py
@@ -124,7 +124,7 @@ def match_pattern(cls, op, x, y, w, z):
     @classmethod
     def apply_pattern(cls, op, x, y, w, z):
         """Builds the pattern to match."""
-        return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", output_names=2)
+        return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2)

     def validate_mapping(
         self,
@@ -273,7 +273,7 @@ def match_pattern(cls, op, x, pos_ids, axis):
             transpose,
             transpose,
             domain="com.microsoft",
-            output_names=2,
+            outputs=2,
         )

         sin = op.Sin(output)
@@ -307,7 +307,7 @@ def apply_pattern(cls, op, x, pos_ids, axis):
             cos_cache,
             sin_cache,
             domain="com.microsoft",
-            output_names=2,
+            outputs=2,
         )

         model = self.get_rotary_model()
2 changes: 1 addition & 1 deletion onnxscript/rewriter/no_op.py
@@ -24,7 +24,7 @@ def div_by_1(x):


 # Replacement
-def identity(x):
+def identity(op, x):
     return op.Identity(x)


@@ -32,20 +32,22 @@ def group_normalization_and_silu_submodule(


 def group_normalization_with_silu(
+    op,
     input,
     weight,
     bias,
     epsilon,
     groups,
 ):
-    group_norm = msft_op.GroupNorm(
+    group_norm = op.GroupNorm(
         input,
         weight,
         bias,
         activation=1,
         channels_last=1,
         epsilon=epsilon,
         groups=groups,
+        domain="com.microsoft",
     )
     return op.Transpose(group_norm, perm=[0, 3, 1, 2])

@@ -123,3 +123,7 @@ def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self):
         self.assertEqual(count, 2)
         # plus 2 in model constants
         self.assertEqual(len(model.graph), 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
27 changes: 9 additions & 18 deletions onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -105,7 +105,6 @@ def instance_simulates_group_normalization_pattern(
     weight_full,
     bias_full,
     epsilon,
-    match_bindings: dict[str, ir.Value | Any] | None = None,
 ):
     adjusted_input = op.Reshape(input_x, adjusted_input_shape)
     inst_norm = op.InstanceNormalization(
@@ -116,17 +115,7 @@
     return op.Add(mul, bias_full)


-def group_normalization(
-    input_x,
-    adjusted_input_shape,
-    original_input_shape,
-    weight_for_norm,
-    bias_for_norm,
-    weight_full,
-    bias_full,
-    epsilon,
-    match_bindings: dict[str, ir.Value],
-):
+def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, epsilon, **_):
     # com.microsoft.GroupNorm only supports NHWC for now
     nhwc_input = op.Transpose(input_x, perm=[0, 2, 3, 1])
     # com.microsoft.GroupNorm only supports gamma and beta as float type
@@ -136,27 +125,29 @@ def group_normalization(
     bias_full = op.Cast(bias_full, to=onnx.TensorProto.FLOAT)
     bias_full = op.Reshape(bias_full, reshape_to_1d)
     # re-obtain attribute groups
-    if "weight_for_norm" not in match_bindings:
-        raise ValueError("weight_for_norm is not found in match_bindings")
-    if match_bindings["weight_for_norm"].shape is None:
+    # TODO(rama): Earlier check implies weight_for_norm is a constant tensor?
+    # If not, we should add a check that shape[0] is not symbolic.
+    shape = weight_for_norm.shape
+    if shape is None:
         raise ValueError("weight_for_norm shape not known")
-    groups = match_bindings["weight_for_norm"].shape[0]
-    output = msft_op.GroupNorm(
+    groups = shape[0]
+    output = op.GroupNorm(
         nhwc_input,
         weight_full,
         bias_full,
         activation=0,
         channels_last=1,
         epsilon=epsilon,
         groups=groups,
+        domain="com.microsoft",
     )
     return op.Transpose(output, perm=[0, 3, 1, 2])


 # Register the rewrite rules
 instance_norm_to_group_norm_rule = pattern.RewriteRule(
     instance_simulates_group_normalization_pattern,
-    pattern.ReplacementPatternFunction(group_normalization, delay_run=True),
+    group_normalization,
     check_if_simulated_instance_norm_is_used,
 )

4 changes: 2 additions & 2 deletions onnxscript/rewriter/onnxruntime/softmax.py
@@ -18,7 +18,7 @@ def softmax_with_fp32_upcast(input, axis):
     return op.Cast(softmax, to=onnx.TensorProto.FLOAT16)


-def softmax(input, axis):
+def softmax(op, input, axis):
     return op.Softmax(input, axis=axis)


@@ -28,7 +28,7 @@ def softmax_with_fp32_upcast_without_axis(input):
     return op.Cast(softmax, to=onnx.TensorProto.FLOAT16)


-def softmax_without_axis(input):
+def softmax_without_axis(op, input):
     return op.Softmax(input)

