From a7835f2baa6884112e67ef8c31ee5aa345c74392 Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Wed, 31 Jul 2024 11:12:11 -0700
Subject: [PATCH] [DRAFT] Change parameter names of builder methods for domain, version, outputs (#1767)

As discussed previously. Use parameter names _domain, _version, and _outputs for special kwargs in onnx op builder method.
---
 docs/tutorial/rewriter/examples/erfgelu.py    |  2 +-
 onnxscript/optimizer/_constant_folding.py     | 10 ++---
 onnxscript/rewriter/erfgelu.py                |  2 +-
 onnxscript/rewriter/generic_pattern_test.py   | 28 +++++++-------
 onnxscript/rewriter/llama_rule_sets.py        |  2 +-
 .../onnxruntime/fused_matmul_rule_sets.py     | 16 ++++----
 .../group_normalization_merge_silu.py         |  4 +-
 .../instance_to_group_normalization.py        |  2 +-
 onnxscript/rewriter/pattern.py                | 38 +++++++++----------
 onnxscript/rewriter/pattern_test.py           | 10 ++---
 10 files changed, 57 insertions(+), 57 deletions(-)

diff --git a/docs/tutorial/rewriter/examples/erfgelu.py b/docs/tutorial/rewriter/examples/erfgelu.py
index a7f16cea0..f32ade37c 100644
--- a/docs/tutorial/rewriter/examples/erfgelu.py
+++ b/docs/tutorial/rewriter/examples/erfgelu.py
@@ -87,7 +87,7 @@ def erf_gelu_pattern_2(op, x):
 
 
 def gelu(op, x: ir.Value):
-    return op.Gelu(x, domain="com.microsoft")
+    return op.Gelu(x, _domain="com.microsoft")
 
 
 ####################################
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 9f4899e0e..a34b9810b 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -362,7 +362,7 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
         unsqueezed_inputs = []
         for node_input in inputs:
             unsqueezed_input = op.Unsqueeze(
-                node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"]
+                node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"]
             )
             unsqueezed_inputs.append(unsqueezed_input)
         # Send unsqueezed outputs to Concat
@@ -427,13 +427,13 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
         num_outputs = math.ceil(split_dimension_size / split_value.item())
         split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
         split_values = op.Split(
-            input, axis=axis, num_outputs=num_outputs, outputs=split_outputs
+            input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
         )
     elif split_value.ndim == 1:
         # split into 'size(split)' chunks
         num_outputs = split_value.size
         split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
-        split_values = op.Split(input, split, axis=axis, outputs=split_outputs)
+        split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
     else:
         return None
 
@@ -442,11 +442,11 @@
         return None
     if keepdims == 0:
         # squeeze the split dimension if keepdims is 0
-        axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"])
+        axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"])
         squeezed_values = []
         for i in range(num_outputs):
             squeezed = op.Squeeze(
-                split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"]
+                split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"]
             )
             squeezed_values.append(squeezed)
         split_values = squeezed_values
diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py
index ea8d27a4e..c821a79b3 100644
--- a/onnxscript/rewriter/erfgelu.py
+++ b/onnxscript/rewriter/erfgelu.py
@@ -21,7 +21,7 @@ def erf_gelu_pattern(op, x):
 
 # Replacement
 def gelu(op, x):
-    return op.Gelu(x, domain="com.microsoft")
+    return op.Gelu(x, _domain="com.microsoft")
 
 
 rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py
index db0e2a638..dadaf5e8b 100644
--- a/onnxscript/rewriter/generic_pattern_test.py
+++ b/onnxscript/rewriter/generic_pattern_test.py
@@ -45,7 +45,7 @@ def match_pattern(op, x, y, z):
 
         def apply_pattern(op, x, y, z, **_):
             """Builds the replacement graph."""
-            return op.AddAdd(x, y, z, domain="ZZZ")
+            return op.AddAdd(x, y, z, _domain="ZZZ")
 
         def validate_mapping(context, x, y, z, **_) -> bool:
             """Validates the mapping."""
@@ -127,7 +127,7 @@ def match_pattern(op, x, y, w, z):
 
         def apply_pattern(op, x, y, w, z, **_):
             """Builds the pattern to match."""
-            return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2)
+            return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2)
 
         def validate_mapping(context, **_) -> bool:
             return True
@@ -262,7 +262,7 @@ def match_pattern(op, x):
             return t1, t2
 
         def apply_pattern(op, x, **_):
-            return op.SinCos(x, domain="com.microsoft", outputs=2)
+            return op.SinCos(x, _domain="com.microsoft", _outputs=2)
 
         rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo)
         model_proto = onnx.parser.parse_model(
@@ -295,7 +295,7 @@ def match_pattern(op, x):
             return t1, t2
 
         def apply_pattern(op, x, **_):
-            return op.SinCos(x, domain="com.microsoft", outputs=2)
+            return op.SinCos(x, _domain="com.microsoft", _outputs=2)
 
         rule = pattern.RewriteRule(
             match_pattern,
@@ -338,8 +338,8 @@ def match_pattern(op, x, pos_ids, axis):
             output, _length = op.ConcatTraining(
                 transpose,
                 transpose,
-                domain="com.microsoft",
-                outputs=2,
+                _domain="com.microsoft",
+                _outputs=2,
             )
 
             sin = op.Sin(output)
@@ -365,8 +365,8 @@ def apply_pattern(op, x, pos_ids, axis, **_):
                 pos_ids,
                 cos_cache,
                 sin_cache,
-                domain="com.microsoft",
-                outputs=2,
+                _domain="com.microsoft",
+                _outputs=2,
             )
 
         rule = pattern.RewriteRule(
@@ -409,7 +409,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
             matmul = op.MatMul(pos_ids, cast)
             transpose = op.Transpose(matmul)
             output, _length = op.ConcatTraining(
-                transpose, transpose, domain="com.microsoft", outputs=2
+                transpose, transpose, _domain="com.microsoft", _outputs=2
             )
 
             sin = op.Sin(output)
@@ -431,7 +431,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
                 value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
             )
             part1, part2 = op.RotaryEmbedding(
-                x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
+                x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
             )
             return part1, part2
 
@@ -475,7 +475,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
             matmul = op.MatMul(pos_ids, cast)
             transpose = op.Transpose(matmul)
             output, _length = op.ConcatTraining(
-                transpose, transpose, domain="com.microsoft", outputs=2
+                transpose, transpose, _domain="com.microsoft", _outputs=2
             )
 
             sin = op.Sin(output)
@@ -497,7 +497,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
                 value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
             )
             part1, part2 = op.RotaryEmbedding(
-                x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
+                x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
             )
             return part1, part2
 
@@ -535,8 +535,8 @@ def test_transpose_transpose_onnxscript(self):
         #     return Y
 
         def transpose_transpose_pattern(op, X):
-            XT = op.Transpose(X, outputs=["XT"])
-            Y = op.Transpose(XT, outputs=["Y"])
+            XT = op.Transpose(X, _outputs=["XT"])
+            Y = op.Transpose(XT, _outputs=["Y"])
             return Y
 
         def transpose_transpose_mapping(perm0, perm1):
diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py
index 6be58dd65..1adb03e16 100644
--- a/onnxscript/rewriter/llama_rule_sets.py
+++ b/onnxscript/rewriter/llama_rule_sets.py
@@ -155,7 +155,7 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool:
 
     @classmethod
     def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
-        return op.Split(x, num_outputs=2, axis=-1, outputs=2)
+        return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
 
 
 class TransposeIdentity(orp.RewriteRuleAsClass):
diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
index 83f263304..3a4444dbb 100644
--- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
+++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
@@ -29,7 +29,7 @@ def check(cls, context, x, y, cst) -> bool:
     def rewrite(cls, op, x, y, cst):
         value = cst.const_value.numpy()
         c = float(value[0] if value.shape == (1,) else value)
-        return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft")
+        return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft")
 
 
 class FusedMatMulDiv2(orp.RewriteRuleAsClass):
@@ -37,7 +37,7 @@ class FusedMatMulDiv2(orp.RewriteRuleAsClass):
 
     @classmethod
     def pattern(cls, op, x, y, cst):
-        return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst)
+        return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
 
     @classmethod
     def check(cls, context, x, y, cst) -> bool:
@@ -60,7 +60,7 @@ def rewrite(cls, op, x, y, cst):
             att = node.attributes.get(name)
             if att:
                 kwargs[name] = att.value
-        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
+        return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
 
 
 class _TransposeMatMulBase(orp.RewriteRuleAsClass):
@@ -83,7 +83,7 @@ def rewrite(cls, op, x, y):
                 kwargs[name] = att.value
         name = "transA" if cls._pos == 1 else "transB"
         kwargs[name] = 1 - kwargs.get(name, 0)
-        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
+        return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
 
 
 class TransposeMatMul1(_TransposeMatMulBase):
@@ -99,7 +99,7 @@ class TransposeFusedMatMul1(TransposeMatMul1):
 
     @classmethod
     def pattern(cls, op, x, y):
-        return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft")
+        return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")
 
 
 class TransposeMatMul2(_TransposeMatMulBase):
@@ -117,7 +117,7 @@ class TransposeFusedMatMul2(TransposeMatMul2):
 
     @classmethod
     def pattern(cls, op, x, y):
-        return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft")
+        return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")
 
 
 class MatMulTranspose(orp.RewriteRuleAsClass):
@@ -146,7 +146,7 @@ def rewrite(cls, op, x, y):
                 kwargs[name] = att.value
         for name in ["transA", "transB"]:
             kwargs[name] = 1 - kwargs.get(name, 0)
-        return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft")
+        return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")
 
 
 class FusedMatMulTranspose(MatMulTranspose):
@@ -154,7 +154,7 @@ class FusedMatMulTranspose(MatMulTranspose):
 
     @classmethod
    def pattern(cls, op, x, y):
-        return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft"))
+        return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))
 
 
 def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
index 843ad920b..7372ef6cf 100644
--- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
+++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
@@ -27,7 +27,7 @@ def group_normalization_and_silu_submodule(
         channels_last=1,
         epsilon=epsilon,
         groups=groups,
-        domain="com.microsoft",
+        _domain="com.microsoft",
     )
     transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
     return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(
@@ -51,7 +51,7 @@ def group_normalization_with_silu(
         channels_last=1,
         epsilon=epsilon,
         groups=groups,
-        domain="com.microsoft",
+        _domain="com.microsoft",
     )
     return op.Transpose(group_norm, perm=[0, 3, 1, 2])
 
diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
index bcd7c2d38..85b412b24 100644
--- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
+++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -142,7 +142,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
         channels_last=1,
         epsilon=epsilon,
         groups=groups,
-        domain="com.microsoft",
+        _domain="com.microsoft",
     )
     return op.Transpose(output, perm=[0, 3, 1, 2])
 
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 4c388c6ae..6f3613e5f 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -203,35 +203,35 @@ def __init__(
     def __call__(
         self,
         *args,
-        domain: str | None = None,
-        version: int | None = None,
-        outputs: int | list[str | None] = 1,
+        _domain: str | None = None,
+        _version: int | None = None,
+        _outputs: int | list[str | None] = 1,
         _allow_other_attributes: bool | None = None,
         **kwargs,
     ):
-        if version is not None:
+        if _version is not None:
             raise ValueError(
-                "The pattern builder does not support 'version' keyword argument. "
+                "The pattern builder does not support '_version' keyword argument. "
                 "Version restrictions should be handled by rewrite rules."
             )
-        if domain is None:
+        if _domain is None:
             opset_pattern = self.opset_pattern
-        elif isinstance(domain, str):
-            opset_pattern = OpsetPatternBuilder(domain)
+        elif isinstance(_domain, str):
+            opset_pattern = OpsetPatternBuilder(_domain)
         else:
-            # TODO(rama): allow OpsetPatternBuilder as domain.
-            raise TypeError("domain must be a string.")
+            # TODO(rama): allow OpsetPatternBuilder as _domain.
+            raise TypeError("_domain must be a string.")
 
-        if isinstance(outputs, int):
-            outputs = [None for _ in range(outputs)]
-        elif not isinstance(outputs, Sequence) or not all(
-            isinstance(x, (str, type(None))) for x in outputs
+        if isinstance(_outputs, int):
+            _outputs = [None for _ in range(_outputs)]
+        elif not isinstance(_outputs, Sequence) or not all(
+            isinstance(x, (str, type(None))) for x in _outputs
         ):
-            raise ValueError("outputs must be an int or a list[str|None].")
+            raise ValueError("_outputs must be an int or a list[str|None].")
         inputs = [_to_value_pattern(x) for x in args]
         attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
         node_pattern = NodePattern(
-            opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes
+            opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes
         )
         output_values = node_pattern.outputs
         # Unpack outputs if there is only one output, the common case.
@@ -805,9 +805,9 @@ def __getattr__(self, op_type: str) -> Any:
 
     def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
         # TODO(rama): some of the following logic should move into the tape.
-        domain = kwargs.pop("domain", "")
-        version = kwargs.pop("version", None)
-        outputs = kwargs.pop("outputs", 1)
+        domain = kwargs.pop("_domain", "")
+        version = kwargs.pop("_version", None)
+        outputs = kwargs.pop("_outputs", 1)
         if isinstance(outputs, Sequence):
             num_outputs = len(outputs)
         else:
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index 0b2748b1d..31985db5a 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -109,7 +109,7 @@ def fast_gelu_pattern1(op, x):
             return (1.0 + tanh) * (0.5 * x)
 
         def fast_gelu(op, x):
-            return op.FastGelu(x, domain="com.microsoft")
+            return op.FastGelu(x, _domain="com.microsoft")
 
         return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu)
 
@@ -130,7 +130,7 @@ def fast_gelu_pattern1_long(op, x):
             return op.Mul(one_plus_tanh, half_x)
 
         def fast_gelu(op, x):
-            return op.FastGelu(x, domain="com.microsoft")
+            return op.FastGelu(x, _domain="com.microsoft")
 
         return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu)
 
@@ -315,7 +315,7 @@ def add_same(op, x):
             return x + x
 
         def double(op, x):
-            return op.Double(x, domain="custom.domain", version=10)
+            return op.Double(x, _domain="custom.domain", _version=10)
 
         rule = pattern.RewriteRule(add_same, double)
 
@@ -339,7 +339,7 @@ def add_same(op, x):
             return x + x
 
         def double(op, x):
-            return op.Double(x, domain="custom.domain", version=10)
+            return op.Double(x, _domain="custom.domain", _version=10)
 
         rule = pattern.RewriteRule(add_same, double)
 
@@ -373,7 +373,7 @@ def test_optional_attribute(self):
 
         def concat_pattern(op, x, y):
             seq = op.SequenceConstruct(x, y)
-            result = op.ConcatFromSequence(seq, outputs=["result"])
+            result = op.ConcatFromSequence(seq, _outputs=["result"])
             return result
 
         def concat(op, x, y, result: ir.Value):
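
A minimal usage sketch (not part of the patch; names are illustrative): a rewrite rule written against the renamed builder kwargs, assuming the onnxscript.rewriter.pattern API exercised above. It mirrors the Sin/Cos test case touched in generic_pattern_test.py, with SinCos standing in for a fused op in the com.microsoft domain.

    from onnxscript.rewriter import pattern

    def sin_cos_pattern(op, x):
        # Target pattern: separate Sin and Cos nodes reading the same input x.
        return op.Sin(x), op.Cos(x)

    def sin_cos_replacement(op, x, **_):
        # _domain and _outputs are the renamed "special" kwargs consumed by the
        # builder; any other kwarg is still treated as an ONNX attribute.
        return op.SinCos(x, _domain="com.microsoft", _outputs=2)

    rule = pattern.RewriteRule(sin_cos_pattern, sin_cos_replacement)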