Skip to content

Commit

Permalink
Change names of builder methods for domain, version, outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Jul 30, 2024
1 parent a72f048 commit 18f45d8
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 49 deletions.
28 changes: 14 additions & 14 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def match_pattern(op, x, y, z):

def apply_pattern(op, x, y, z, **_):
"""Builds the replacement graph."""
return op.AddAdd(x, y, z, domain="ZZZ")
return op.AddAdd(x, y, z, _domain="ZZZ")

def validate_mapping(context, x, y, z, **_) -> bool:
"""Validates the mapping."""
Expand Down Expand Up @@ -127,7 +127,7 @@ def match_pattern(op, x, y, w, z):

def apply_pattern(op, x, y, w, z, **_):
"""Builds the pattern to match."""
return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2)
return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2)

def validate_mapping(context, **_) -> bool:
return True
Expand Down Expand Up @@ -262,7 +262,7 @@ def match_pattern(op, x):
return t1, t2

def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)
return op.SinCos(x, _domain="com.microsoft", _outputs=2)

rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo)
model_proto = onnx.parser.parse_model(
Expand Down Expand Up @@ -295,7 +295,7 @@ def match_pattern(op, x):
return t1, t2

def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)
return op.SinCos(x, _domain="com.microsoft", _outputs=2)

rule = pattern.RewriteRule(
match_pattern,
Expand Down Expand Up @@ -338,8 +338,8 @@ def match_pattern(op, x, pos_ids, axis):
output, _length = op.ConcatTraining(
transpose,
transpose,
domain="com.microsoft",
outputs=2,
_domain="com.microsoft",
_outputs=2,
)

sin = op.Sin(output)
Expand All @@ -365,8 +365,8 @@ def apply_pattern(op, x, pos_ids, axis, **_):
pos_ids,
cos_cache,
sin_cache,
domain="com.microsoft",
outputs=2,
_domain="com.microsoft",
_outputs=2,
)

rule = pattern.RewriteRule(
Expand Down Expand Up @@ -409,7 +409,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, _length = op.ConcatTraining(
transpose, transpose, domain="com.microsoft", outputs=2
transpose, transpose, _domain="com.microsoft", _outputs=2
)

sin = op.Sin(output)
Expand All @@ -431,7 +431,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
)
part1, part2 = op.RotaryEmbedding(
x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
)
return part1, part2

Expand Down Expand Up @@ -475,7 +475,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, _length = op.ConcatTraining(
transpose, transpose, domain="com.microsoft", outputs=2
transpose, transpose, _domain="com.microsoft", _outputs=2
)

sin = op.Sin(output)
Expand All @@ -497,7 +497,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
)
part1, part2 = op.RotaryEmbedding(
x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
)
return part1, part2

Expand Down Expand Up @@ -535,8 +535,8 @@ def test_transpose_transpose_onnxscript(self):
# return Y

def transpose_transpose_pattern(op, X):
XT = op.Transpose(X, outputs=["XT"])
Y = op.Transpose(XT, outputs=["Y"])
XT = op.Transpose(X, _outputs=["XT"])
Y = op.Transpose(XT, _outputs=["Y"])
return Y

def transpose_transpose_mapping(perm0, perm1):
Expand Down
16 changes: 8 additions & 8 deletions onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def check(cls, context, x, y, cst) -> bool:
def rewrite(cls, op, x, y, cst):
value = cst.const_value.numpy()
c = float(value[0] if value.shape == (1,) else value)
return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft")
return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft")


class FusedMatMulDiv2(orp.RewriteRuleAsClass):
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""

@classmethod
def pattern(cls, op, x, y, cst):
return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst)
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)

@classmethod
def check(cls, context, x, y, cst) -> bool:
Expand All @@ -60,7 +60,7 @@ def rewrite(cls, op, x, y, cst):
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class _TransposeMatMulBase(orp.RewriteRuleAsClass):
Expand All @@ -83,7 +83,7 @@ def rewrite(cls, op, x, y):
kwargs[name] = att.value
name = "transA" if cls._pos == 1 else "transB"
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class TransposeMatMul1(_TransposeMatMulBase):
Expand All @@ -99,7 +99,7 @@ class TransposeFusedMatMul1(TransposeMatMul1):

@classmethod
def pattern(cls, op, x, y):
return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft")
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")


class TransposeMatMul2(_TransposeMatMulBase):
Expand All @@ -117,7 +117,7 @@ class TransposeFusedMatMul2(TransposeMatMul2):

@classmethod
def pattern(cls, op, x, y):
return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft")
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")


class MatMulTranspose(orp.RewriteRuleAsClass):
Expand Down Expand Up @@ -146,15 +146,15 @@ def rewrite(cls, op, x, y):
kwargs[name] = att.value
for name in ["transA", "transB"]:
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft")
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")


class FusedMatMulTranspose(MatMulTranspose):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""

@classmethod
def pattern(cls, op, x, y):
return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft"))
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))


def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def group_normalization_and_silu_submodule(
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(
Expand All @@ -51,7 +51,7 @@ def group_normalization_with_silu(
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
return op.Transpose(group_norm, perm=[0, 3, 1, 2])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
return op.Transpose(output, perm=[0, 3, 1, 2])

Expand Down
38 changes: 19 additions & 19 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,35 +203,35 @@ def __init__(
def __call__(
self,
*args,
domain: str | None = None,
version: int | None = None,
outputs: int | list[str | None] = 1,
_domain: str | None = None,
_version: int | None = None,
_outputs: int | list[str | None] = 1,
_allow_other_attributes: bool | None = None,
**kwargs,
):
if version is not None:
if _version is not None:
raise ValueError(
"The pattern builder does not support 'version' keyword argument. "
"The pattern builder does not support '_version' keyword argument. "
"Version restrictions should be handled by rewrite rules."
)
if domain is None:
if _domain is None:
opset_pattern = self.opset_pattern
elif isinstance(domain, str):
opset_pattern = OpsetPatternBuilder(domain)
elif isinstance(_domain, str):
opset_pattern = OpsetPatternBuilder(_domain)
else:
# TODO(rama): allow OpsetPatternBuilder as domain.
raise TypeError("domain must be a string.")
# TODO(rama): allow OpsetPatternBuilder as _domain.
raise TypeError("_domain must be a string.")

Check warning on line 223 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L223

Added line #L223 was not covered by tests

if isinstance(outputs, int):
outputs = [None for _ in range(outputs)]
elif not isinstance(outputs, Sequence) or not all(
isinstance(x, (str, type(None))) for x in outputs
if isinstance(_outputs, int):
_outputs = [None for _ in range(_outputs)]
elif not isinstance(_outputs, Sequence) or not all(
isinstance(x, (str, type(None))) for x in _outputs
):
raise ValueError("outputs must be an int or a list[str|None].")
raise ValueError("_outputs must be an int or a list[str|None].")

Check warning on line 230 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L230

Added line #L230 was not covered by tests
inputs = [_to_value_pattern(x) for x in args]
attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
node_pattern = NodePattern(
opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes
opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes
)
output_values = node_pattern.outputs
# Unpack outputs if there is only one output, the common case.
Expand Down Expand Up @@ -805,9 +805,9 @@ def __getattr__(self, op_type: str) -> Any:

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
# TODO(rama): some of the following logic should move into the tape.
domain = kwargs.pop("domain", "")
version = kwargs.pop("version", None)
outputs = kwargs.pop("outputs", 1)
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def fast_gelu_pattern1(op, x):
return (1.0 + tanh) * (0.5 * x)

def fast_gelu(op, x):
return op.FastGelu(x, domain="com.microsoft")
return op.FastGelu(x, _domain="com.microsoft")

return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu)

Expand All @@ -130,7 +130,7 @@ def fast_gelu_pattern1_long(op, x):
return op.Mul(one_plus_tanh, half_x)

def fast_gelu(op, x):
return op.FastGelu(x, domain="com.microsoft")
return op.FastGelu(x, _domain="com.microsoft")

return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu)

Expand Down Expand Up @@ -315,7 +315,7 @@ def add_same(op, x):
return x + x

def double(op, x):
return op.Double(x, domain="custom.domain", version=10)
return op.Double(x, _domain="custom.domain", _version=10)

rule = pattern.RewriteRule(add_same, double)

Expand All @@ -339,7 +339,7 @@ def add_same(op, x):
return x + x

def double(op, x):
return op.Double(x, domain="custom.domain", version=10)
return op.Double(x, _domain="custom.domain", _version=10)

rule = pattern.RewriteRule(add_same, double)

Expand Down Expand Up @@ -373,7 +373,7 @@ def test_optional_attribute(self):

def concat_pattern(op, x, y):
seq = op.SequenceConstruct(x, y)
result = op.ConcatFromSequence(seq, outputs=["result"])
result = op.ConcatFromSequence(seq, _outputs=["result"])
return result

def concat(op, x, y, result: ir.Value):
Expand Down

0 comments on commit 18f45d8

Please sign in to comment.