Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Change parameter names of builder methods for domain, version, outputs #1767

Merged
merged 5 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorial/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def erf_gelu_pattern_2(op, x):


def gelu(op, x: ir.Value):
return op.Gelu(x, domain="com.microsoft")
return op.Gelu(x, _domain="com.microsoft")


####################################
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
unsqueezed_inputs = []
for node_input in inputs:
unsqueezed_input = op.Unsqueeze(
node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"]
node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"]
)
unsqueezed_inputs.append(unsqueezed_input)
# Send unsqueezed outputs to Concat
Expand Down Expand Up @@ -427,13 +427,13 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
num_outputs = math.ceil(split_dimension_size / split_value.item())
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
split_values = op.Split(
input, axis=axis, num_outputs=num_outputs, outputs=split_outputs
input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
)
elif split_value.ndim == 1:
# split into 'size(split)' chunks
num_outputs = split_value.size
split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
split_values = op.Split(input, split, axis=axis, outputs=split_outputs)
split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
else:
return None

Expand All @@ -442,11 +442,11 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None
if keepdims == 0:
# squeeze the split dimension if keepdims is 0
axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"])
axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"])
squeezed_values = []
for i in range(num_outputs):
squeezed = op.Squeeze(
split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"]
split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"]
)
squeezed_values.append(squeezed)
split_values = squeezed_values
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# Replacement
def gelu(op, x):
return op.Gelu(x, domain="com.microsoft")
return op.Gelu(x, _domain="com.microsoft")

Check warning on line 24 in onnxscript/rewriter/erfgelu.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/erfgelu.py#L24

Added line #L24 was not covered by tests


rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
28 changes: 14 additions & 14 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def match_pattern(op, x, y, z):

def apply_pattern(op, x, y, z, **_):
"""Builds the replacement graph."""
return op.AddAdd(x, y, z, domain="ZZZ")
return op.AddAdd(x, y, z, _domain="ZZZ")

def validate_mapping(context, x, y, z, **_) -> bool:
"""Validates the mapping."""
Expand Down Expand Up @@ -127,7 +127,7 @@ def match_pattern(op, x, y, w, z):

def apply_pattern(op, x, y, w, z, **_):
"""Builds the pattern to match."""
return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2)
return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2)

def validate_mapping(context, **_) -> bool:
return True
Expand Down Expand Up @@ -262,7 +262,7 @@ def match_pattern(op, x):
return t1, t2

def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)
return op.SinCos(x, _domain="com.microsoft", _outputs=2)

rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo)
model_proto = onnx.parser.parse_model(
Expand Down Expand Up @@ -295,7 +295,7 @@ def match_pattern(op, x):
return t1, t2

def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)
return op.SinCos(x, _domain="com.microsoft", _outputs=2)

rule = pattern.RewriteRule(
match_pattern,
Expand Down Expand Up @@ -338,8 +338,8 @@ def match_pattern(op, x, pos_ids, axis):
output, _length = op.ConcatTraining(
transpose,
transpose,
domain="com.microsoft",
outputs=2,
_domain="com.microsoft",
_outputs=2,
)

sin = op.Sin(output)
Expand All @@ -365,8 +365,8 @@ def apply_pattern(op, x, pos_ids, axis, **_):
pos_ids,
cos_cache,
sin_cache,
domain="com.microsoft",
outputs=2,
_domain="com.microsoft",
_outputs=2,
)

rule = pattern.RewriteRule(
Expand Down Expand Up @@ -409,7 +409,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, _length = op.ConcatTraining(
transpose, transpose, domain="com.microsoft", outputs=2
transpose, transpose, _domain="com.microsoft", _outputs=2
)

sin = op.Sin(output)
Expand All @@ -431,7 +431,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
)
part1, part2 = op.RotaryEmbedding(
x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
)
return part1, part2

Expand Down Expand Up @@ -475,7 +475,7 @@ def rotary_match_pattern(op, x, pos_ids, axis):
matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, _length = op.ConcatTraining(
transpose, transpose, domain="com.microsoft", outputs=2
transpose, transpose, _domain="com.microsoft", _outputs=2
)

sin = op.Sin(output)
Expand All @@ -497,7 +497,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16))
)
part1, part2 = op.RotaryEmbedding(
x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2
)
return part1, part2

Expand Down Expand Up @@ -535,8 +535,8 @@ def test_transpose_transpose_onnxscript(self):
# return Y

def transpose_transpose_pattern(op, X):
XT = op.Transpose(X, outputs=["XT"])
Y = op.Transpose(XT, outputs=["Y"])
XT = op.Transpose(X, _outputs=["XT"])
Y = op.Transpose(XT, _outputs=["Y"])
return Y

def transpose_transpose_mapping(perm0, perm1):
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@

@classmethod
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
return op.Split(x, num_outputs=2, axis=-1, outputs=2)
return op.Split(x, num_outputs=2, axis=-1, _outputs=2)

Check warning on line 158 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L158

Added line #L158 was not covered by tests


class TransposeIdentity(orp.RewriteRuleAsClass):
Expand Down
16 changes: 8 additions & 8 deletions onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def check(cls, context, x, y, cst) -> bool:
def rewrite(cls, op, x, y, cst):
value = cst.const_value.numpy()
c = float(value[0] if value.shape == (1,) else value)
return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft")
return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft")


class FusedMatMulDiv2(orp.RewriteRuleAsClass):
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""

@classmethod
def pattern(cls, op, x, y, cst):
return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst)
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)

@classmethod
def check(cls, context, x, y, cst) -> bool:
Expand All @@ -60,7 +60,7 @@ def rewrite(cls, op, x, y, cst):
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class _TransposeMatMulBase(orp.RewriteRuleAsClass):
Expand All @@ -83,7 +83,7 @@ def rewrite(cls, op, x, y):
kwargs[name] = att.value
name = "transA" if cls._pos == 1 else "transB"
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class TransposeMatMul1(_TransposeMatMulBase):
Expand All @@ -99,7 +99,7 @@ class TransposeFusedMatMul1(TransposeMatMul1):

@classmethod
def pattern(cls, op, x, y):
return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft")
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")


class TransposeMatMul2(_TransposeMatMulBase):
Expand All @@ -117,7 +117,7 @@ class TransposeFusedMatMul2(TransposeMatMul2):

@classmethod
def pattern(cls, op, x, y):
return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft")
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")


class MatMulTranspose(orp.RewriteRuleAsClass):
Expand Down Expand Up @@ -146,15 +146,15 @@ def rewrite(cls, op, x, y):
kwargs[name] = att.value
for name in ["transA", "transB"]:
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft")
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")


class FusedMatMulTranspose(MatMulTranspose):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""

@classmethod
def pattern(cls, op, x, y):
return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft"))
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))


def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def group_normalization_and_silu_submodule(
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(
Expand All @@ -51,7 +51,7 @@ def group_normalization_with_silu(
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
return op.Transpose(group_norm, perm=[0, 3, 1, 2])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
channels_last=1,
epsilon=epsilon,
groups=groups,
domain="com.microsoft",
_domain="com.microsoft",
)
return op.Transpose(output, perm=[0, 3, 1, 2])

Expand Down
38 changes: 19 additions & 19 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,35 +203,35 @@
def __call__(
self,
*args,
domain: str | None = None,
version: int | None = None,
outputs: int | list[str | None] = 1,
_domain: str | None = None,
_version: int | None = None,
_outputs: int | list[str | None] = 1,
_allow_other_attributes: bool | None = None,
**kwargs,
):
if version is not None:
if _version is not None:
raise ValueError(
"The pattern builder does not support 'version' keyword argument. "
"The pattern builder does not support '_version' keyword argument. "
"Version restrictions should be handled by rewrite rules."
)
if domain is None:
if _domain is None:
opset_pattern = self.opset_pattern
elif isinstance(domain, str):
opset_pattern = OpsetPatternBuilder(domain)
elif isinstance(_domain, str):
opset_pattern = OpsetPatternBuilder(_domain)
else:
# TODO(rama): allow OpsetPatternBuilder as domain.
raise TypeError("domain must be a string.")
# TODO(rama): allow OpsetPatternBuilder as _domain.
raise TypeError("_domain must be a string.")

Check warning on line 223 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L223

Added line #L223 was not covered by tests

if isinstance(outputs, int):
outputs = [None for _ in range(outputs)]
elif not isinstance(outputs, Sequence) or not all(
isinstance(x, (str, type(None))) for x in outputs
if isinstance(_outputs, int):
_outputs = [None for _ in range(_outputs)]
elif not isinstance(_outputs, Sequence) or not all(
isinstance(x, (str, type(None))) for x in _outputs
):
raise ValueError("outputs must be an int or a list[str|None].")
raise ValueError("_outputs must be an int or a list[str|None].")

Check warning on line 230 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L230

Added line #L230 was not covered by tests
inputs = [_to_value_pattern(x) for x in args]
attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
node_pattern = NodePattern(
opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes
opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes
)
output_values = node_pattern.outputs
# Unpack outputs if there is only one output, the common case.
Expand Down Expand Up @@ -805,9 +805,9 @@

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
# TODO(rama): some of the following logic should move into the tape.
domain = kwargs.pop("domain", "")
version = kwargs.pop("version", None)
outputs = kwargs.pop("outputs", 1)
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def fast_gelu_pattern1(op, x):
return (1.0 + tanh) * (0.5 * x)

def fast_gelu(op, x):
return op.FastGelu(x, domain="com.microsoft")
return op.FastGelu(x, _domain="com.microsoft")

return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu)

Expand All @@ -130,7 +130,7 @@ def fast_gelu_pattern1_long(op, x):
return op.Mul(one_plus_tanh, half_x)

def fast_gelu(op, x):
return op.FastGelu(x, domain="com.microsoft")
return op.FastGelu(x, _domain="com.microsoft")

return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu)

Expand Down Expand Up @@ -315,7 +315,7 @@ def add_same(op, x):
return x + x

def double(op, x):
return op.Double(x, domain="custom.domain", version=10)
return op.Double(x, _domain="custom.domain", _version=10)

rule = pattern.RewriteRule(add_same, double)

Expand All @@ -339,7 +339,7 @@ def add_same(op, x):
return x + x

def double(op, x):
return op.Double(x, domain="custom.domain", version=10)
return op.Double(x, _domain="custom.domain", _version=10)

rule = pattern.RewriteRule(add_same, double)

Expand Down Expand Up @@ -373,7 +373,7 @@ def test_optional_attribute(self):

def concat_pattern(op, x, y):
seq = op.SequenceConstruct(x, y)
result = op.ConcatFromSequence(seq, outputs=["result"])
result = op.ConcatFromSequence(seq, _outputs=["result"])
return result

def concat(op, x, y, result: ir.Value):
Expand Down
Loading