diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index db0e2a638..dadaf5e8b 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -45,7 +45,7 @@ def match_pattern(op, x, y, z): def apply_pattern(op, x, y, z, **_): """Builds the replacement graph.""" - return op.AddAdd(x, y, z, domain="ZZZ") + return op.AddAdd(x, y, z, _domain="ZZZ") def validate_mapping(context, x, y, z, **_) -> bool: """Validates the mapping.""" @@ -127,7 +127,7 @@ def match_pattern(op, x, y, w, z): def apply_pattern(op, x, y, w, z, **_): """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2) + return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2) def validate_mapping(context, **_) -> bool: return True @@ -262,7 +262,7 @@ def match_pattern(op, x): return t1, t2 def apply_pattern(op, x, **_): - return op.SinCos(x, domain="com.microsoft", outputs=2) + return op.SinCos(x, _domain="com.microsoft", _outputs=2) rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) model_proto = onnx.parser.parse_model( @@ -295,7 +295,7 @@ def match_pattern(op, x): return t1, t2 def apply_pattern(op, x, **_): - return op.SinCos(x, domain="com.microsoft", outputs=2) + return op.SinCos(x, _domain="com.microsoft", _outputs=2) rule = pattern.RewriteRule( match_pattern, @@ -338,8 +338,8 @@ def match_pattern(op, x, pos_ids, axis): output, _length = op.ConcatTraining( transpose, transpose, - domain="com.microsoft", - outputs=2, + _domain="com.microsoft", + _outputs=2, ) sin = op.Sin(output) @@ -365,8 +365,8 @@ def apply_pattern(op, x, pos_ids, axis, **_): pos_ids, cos_cache, sin_cache, - domain="com.microsoft", - outputs=2, + _domain="com.microsoft", + _outputs=2, ) rule = pattern.RewriteRule( @@ -409,7 +409,7 @@ def rotary_match_pattern(op, x, pos_ids, axis): matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) output, _length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 + transpose, transpose, _domain="com.microsoft", _outputs=2 ) sin = op.Sin(output) @@ -431,7 +431,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) ) part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 ) return part1, part2 @@ -475,7 +475,7 @@ def rotary_match_pattern(op, x, pos_ids, axis): matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) output, _length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 + transpose, transpose, _domain="com.microsoft", _outputs=2 ) sin = op.Sin(output) @@ -497,7 +497,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) ) part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 ) return part1, part2 @@ -535,8 +535,8 @@ def test_transpose_transpose_onnxscript(self): # return Y def transpose_transpose_pattern(op, X): - XT = op.Transpose(X, outputs=["XT"]) - Y = op.Transpose(XT, outputs=["Y"]) + XT = op.Transpose(X, _outputs=["XT"]) + Y = op.Transpose(XT, _outputs=["Y"]) return Y def transpose_transpose_mapping(perm0, perm1): diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py index 83f263304..3a4444dbb 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py @@ -29,7 +29,7 @@ def check(cls, context, x, y, cst) -> bool: def rewrite(cls, op, x, y, cst): value = cst.const_value.numpy() c = float(value[0] if value.shape == (1,) else value) - return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft") + return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft") class FusedMatMulDiv2(orp.RewriteRuleAsClass): @@ -37,7 +37,7 @@ class FusedMatMulDiv2(orp.RewriteRuleAsClass): @classmethod def pattern(cls, op, x, y, cst): - return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst) + return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst) @classmethod def check(cls, context, x, y, cst) -> bool: @@ -60,7 +60,7 @@ def rewrite(cls, op, x, y, cst): att = node.attributes.get(name) if att: kwargs[name] = att.value - return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft") + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") class _TransposeMatMulBase(orp.RewriteRuleAsClass): @@ -83,7 +83,7 @@ def rewrite(cls, op, x, y): kwargs[name] = att.value name = "transA" if cls._pos == 1 else "transB" kwargs[name] = 1 - kwargs.get(name, 0) - return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft") + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") class TransposeMatMul1(_TransposeMatMulBase): @@ -99,7 +99,7 @@ class TransposeFusedMatMul1(TransposeMatMul1): @classmethod def pattern(cls, op, x, y): - return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft") + return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft") class TransposeMatMul2(_TransposeMatMulBase): @@ -117,7 +117,7 @@ class TransposeFusedMatMul2(TransposeMatMul2): @classmethod def pattern(cls, op, x, y): - return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft") + return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft") class MatMulTranspose(orp.RewriteRuleAsClass): @@ -146,7 +146,7 @@ def rewrite(cls, op, x, y): kwargs[name] = att.value for name in ["transA", "transB"]: kwargs[name] = 1 - kwargs.get(name, 0) - return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft") + return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft") class FusedMatMulTranspose(MatMulTranspose): @@ -154,7 +154,7 @@ class FusedMatMulTranspose(MatMulTranspose): @classmethod def pattern(cls, op, x, y): - return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft")) + return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft")) def fused_matmul_rule_sets() -> orp.RewriteRuleSet: diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py index 843ad920b..7372ef6cf 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py @@ -27,7 +27,7 @@ def group_normalization_and_silu_submodule( channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2]) return torch_module_op.submodule("torch_nn_modules_activation_SiLU")( @@ -51,7 +51,7 @@ def group_normalization_with_silu( channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(group_norm, perm=[0, 3, 1, 2]) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py index bcd7c2d38..85b412b24 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -142,7 +142,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(output, perm=[0, 3, 1, 2]) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 4c388c6ae..6f3613e5f 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -203,35 +203,35 @@ def __init__( def __call__( self, *args, - domain: str | None = None, - version: int | None = None, - outputs: int | list[str | None] = 1, + _domain: str | None = None, + _version: int | None = None, + _outputs: int | list[str | None] = 1, _allow_other_attributes: bool | None = None, **kwargs, ): - if version is not None: + if _version is not None: raise ValueError( - "The pattern builder does not support 'version' keyword argument. " + "The pattern builder does not support '_version' keyword argument. " "Version restrictions should be handled by rewrite rules." ) - if domain is None: + if _domain is None: opset_pattern = self.opset_pattern - elif isinstance(domain, str): - opset_pattern = OpsetPatternBuilder(domain) + elif isinstance(_domain, str): + opset_pattern = OpsetPatternBuilder(_domain) else: - # TODO(rama): allow OpsetPatternBuilder as domain. - raise TypeError("domain must be a string.") + # TODO(rama): allow OpsetPatternBuilder as _domain. + raise TypeError("_domain must be a string.") - if isinstance(outputs, int): - outputs = [None for _ in range(outputs)] - elif not isinstance(outputs, Sequence) or not all( - isinstance(x, (str, type(None))) for x in outputs + if isinstance(_outputs, int): + _outputs = [None for _ in range(_outputs)] + elif not isinstance(_outputs, Sequence) or not all( + isinstance(x, (str, type(None))) for x in _outputs ): - raise ValueError("outputs must be an int or a list[str|None].") + raise ValueError("_outputs must be an int or a list[str|None].") inputs = [_to_value_pattern(x) for x in args] attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} node_pattern = NodePattern( - opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes + opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes ) output_values = node_pattern.outputs # Unpack outputs if there is only one output, the common case. @@ -805,9 +805,9 @@ def __getattr__(self, op_type: str) -> Any: def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): # TODO(rama): some of the following logic should move into the tape. - domain = kwargs.pop("domain", "") - version = kwargs.pop("version", None) - outputs = kwargs.pop("outputs", 1) + domain = kwargs.pop("_domain", "") + version = kwargs.pop("_version", None) + outputs = kwargs.pop("_outputs", 1) if isinstance(outputs, Sequence): num_outputs = len(outputs) else: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0b2748b1d..31985db5a 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -109,7 +109,7 @@ def fast_gelu_pattern1(op, x): return (1.0 + tanh) * (0.5 * x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) @@ -130,7 +130,7 @@ def fast_gelu_pattern1_long(op, x): return op.Mul(one_plus_tanh, half_x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) @@ -315,7 +315,7 @@ def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -339,7 +339,7 @@ def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -373,7 +373,7 @@ def test_optional_attribute(self): def concat_pattern(op, x, y): seq = op.SequenceConstruct(x, y) - result = op.ConcatFromSequence(seq, outputs=["result"]) + result = op.ConcatFromSequence(seq, _outputs=["result"]) return result def concat(op, x, y, result: ir.Value):