Skip to content

Commit

Permalink
[rewriter] Create the Dropout->Identity rules (#1813)
Browse files Browse the repository at this point in the history
Fix #1776
  • Loading branch information
justinchuby authored Sep 3, 2024
1 parent 74ae4cc commit 2627ab0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
12 changes: 12 additions & 0 deletions onnxscript/rewriter/no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def div_by_1(op, x):
return x / 1


def dropout_zero(op, x):
return op.Dropout(x, ratio=0.0)


def dropout_inference(op, x):
return op.Dropout(x, training_mode=False)


# Replacement
def identity(op, x):
return op.Identity(x)
Expand All @@ -32,6 +40,8 @@ def identity(op, x):
add_0_rule = pattern.RewriteRule(add_0, identity)
sub_0_rule = pattern.RewriteRule(sub_0, identity)
div_by_1_rule = pattern.RewriteRule(div_by_1, identity)
dropout_zero_rule = pattern.RewriteRule(dropout_zero, identity)
dropout_inference_rule = pattern.RewriteRule(dropout_inference, identity)
# TODO: Include Mul by 0, 0 by Mul, 0 by Div? Those would be 0s, but not no-ops

rules = pattern.RewriteRuleSet(
Expand All @@ -40,5 +50,7 @@ def identity(op, x):
*add_0_rule.commute(),
sub_0_rule,
div_by_1_rule,
dropout_zero_rule,
dropout_inference_rule,
]
)
20 changes: 20 additions & 0 deletions onnxscript/rewriter/no_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,26 @@ def test_div_one_should_become_no_op_with_initializer(
"""
)

@parameterized.parameterized.expand(
[
("dropout zero ratio", "ratio=0.0"),
("dropout inference", "training_mode=0"),
("dropout inference with positive ratio", "ratio=0.42, training_mode=0"),
("dropout training with zero ratio", "ratio=0.0, training_mode=1"),
]
)
def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: str):
self._check(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float16[M] input) => (float16[M] output)
{{
output = Dropout<{attribute}>(input)
}}
"""
)
# TODO: Test the negative cases


if __name__ == "__main__":
unittest.main()

0 comments on commit 2627ab0

Please sign in to comment.