From 2627ab06990bfe9522fd997d8905dc5542135fc3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 13:31:51 -0700 Subject: [PATCH] [rewriter] Create the Dropout->Identity rules (#1813) Fix #1776 --- onnxscript/rewriter/no_op.py | 12 ++++++++++++ onnxscript/rewriter/no_op_test.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 7a4b00798..21cee515d 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -23,6 +23,14 @@ def div_by_1(op, x): return x / 1 +def dropout_zero(op, x): + return op.Dropout(x, ratio=0.0) + + +def dropout_inference(op, x): + return op.Dropout(x, training_mode=False) + + # Replacement def identity(op, x): return op.Identity(x) @@ -32,6 +40,8 @@ def identity(op, x): add_0_rule = pattern.RewriteRule(add_0, identity) sub_0_rule = pattern.RewriteRule(sub_0, identity) div_by_1_rule = pattern.RewriteRule(div_by_1, identity) +dropout_zero_rule = pattern.RewriteRule(dropout_zero, identity) +dropout_inference_rule = pattern.RewriteRule(dropout_inference, identity) # TODO: Include Mul by 0, 0 by Mul, 0 by Div? Those would be 0s, but not no-ops rules = pattern.RewriteRuleSet( @@ -40,5 +50,7 @@ def identity(op, x): *add_0_rule.commute(), sub_0_rule, div_by_1_rule, + dropout_zero_rule, + dropout_inference_rule, ] ) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py index 92172ec1f..4e509e7f3 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/no_op_test.py @@ -177,6 +177,26 @@ def test_div_one_should_become_no_op_with_initializer( """ ) + @parameterized.parameterized.expand( + [ + ("dropout zero ratio", "ratio=0.0"), + ("dropout inference", "training_mode=0"), + ("dropout inference with positive ratio", "ratio=0.42, training_mode=0"), + ("dropout training with zero ratio", "ratio=0.0, training_mode=1"), + ] + ) + def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: str): + self._check( + f""" + + agraph (float16[M] input) => (float16[M] output) + {{ + output = Dropout<{attribute}>(input) + }} + """ + ) + # TODO: Test the negative cases + if __name__ == "__main__": unittest.main()