Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed May 2, 2024
1 parent 8876bb9 commit 0dbb84d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
12 changes: 6 additions & 6 deletions docs/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1
return result


model = original_model.to_model_proto()
onnx.checker.check_model(model)
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
Expand All @@ -55,7 +55,7 @@ def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shap
# =====================


def matmul(op, input_a: ir.Value, input_b: ir.Value, **_):
def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):

Check warning

Code scanning / lintrunner

PYLINT/W0621 Warning documentation

Redefining name 'op' from outer scope (line 43) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name
return op.MatMul(input_a, input_b)


Expand Down Expand Up @@ -181,7 +181,7 @@ def apply_rewrite(model):
# Create rewrite rules
two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
two_reshapes_matmul_reshape_pattern, # target pattern
matmul, # replacement pattern
matmul_pattern, # replacement pattern
check_if_need_reshape, # match_condition function
)
# Create a Rewrite Rule Set
Expand All @@ -194,5 +194,5 @@ def apply_rewrite(model):
return model_with_rewrite


model_with_rewrite = apply_rewrite(model)
onnx.checker.check_model(model_with_rewrite)
_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
6 changes: 3 additions & 3 deletions docs/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
return result


model = original_model.to_model_proto()
onnx.checker.check_model(model)
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
Expand Down Expand Up @@ -144,7 +144,7 @@ def apply_rewrite_with_commute(model):


# Rewrite-Simple
model_with_rewrite = apply_rewrite(model)
model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(model_with_rewrite)

# Rewrite-Single-Patterns
Expand Down
2 changes: 1 addition & 1 deletion docs/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Extending the previous [simple example](heading-target-simple), assumming a scen

In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched.

![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center}
![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center}


If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched.
Expand Down
5 changes: 1 addition & 4 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def identity(op, x, newshape):
del newshape # Unused
return op.Identity(x)

def _check_for_redundant_reshape(x, newshape):
def check_for_redundant_reshape(x, newshape):
oldshape = x.shape
newshape = _ir_utils.propagate_const_value(newshape)
newshape = _ir_utils.get_numpy_from_ir_value(newshape)
Expand All @@ -257,9 +257,6 @@ def _check_for_redundant_reshape(x, newshape):
return False
return all(not (d1 != d2 and d2 != -1) for d1, d2 in zip(oldshape, newshape)) # pylint: disable=consider-using-in

def check_for_redundant_reshape(bindings):
return _check_for_redundant_reshape(**bindings)

rule = pattern.RewriteRule(reshape, identity, check_for_redundant_reshape)

model_proto = onnx.parser.parse_model(
Expand Down

0 comments on commit 0dbb84d

Please sign in to comment.