diff --git a/onnxscript/optimizer/remove_unused.py b/onnxscript/optimizer/remove_unused.py
index 2b8cd6789..06d1e0717 100644
--- a/onnxscript/optimizer/remove_unused.py
+++ b/onnxscript/optimizer/remove_unused.py
@@ -26,11 +26,23 @@ def remove_unused_optional_outputs(
         op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain)
     except Exception:
         return
-    # TODO: If current node is a BatchNormalization node,
-    # based on training_mode atrribute, number of optional outputs and
-    # how they are handled varies, handle both training_modes
+
     if n.op_type == "BatchNormalization":
-        return
+        # The BatchNormalization op has three outputs: Y, running_mean, running_var.
+        # If neither running_mean nor running_var is used, remove them and the training_mode attribute.
+        def is_used_output(i: int) -> bool:
+            if i < len(n.output):
+                return n.output[i] in used
+            return False
+
+        if is_used_output(1) or is_used_output(2):
+            return
+        del n.output[1:]
+        for j, attr in enumerate(n.attribute):
+            if attr.name == "training_mode":
+                del n.attribute[j]
+                break
+
     optional_info = []
     for o in op_schema.outputs:
         # Current ops do not have optional outputs if they have variable number of outputs
diff --git a/onnxscript/optimizer/remove_unused_test.py b/onnxscript/optimizer/remove_unused_test.py
index 656d808a9..8d6aa2525 100644
--- a/onnxscript/optimizer/remove_unused_test.py
+++ b/onnxscript/optimizer/remove_unused_test.py
@@ -170,6 +170,40 @@ def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
         self.assertEqual(len(model.graph.node[2].output), 3)
 
+    def test_remove_trailing_unused_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B, float[3] mean, float[3] var) => (float[1, 3, 5, 5] z) {
+                z, mean_out, var_out = BatchNormalization <training_mode = 1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Both the mean/var outputs and the training_mode attribute should be removed.
+        self.assertEqual(len(model.graph.node[0].output), 1)
+        self.assertEqual(len(model.graph.node[0].attribute), 0)
+
+    def test_avoid_remove_used_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B, float[3] mean, float[3] var) => (float[1, 3, 5, 5] z, float[3] mean_out) {
+                z, mean_out, var_out = BatchNormalization <training_mode = 1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Neither the mean/var outputs nor the training_mode attribute should be removed.
+        self.assertEqual(len(model.graph.node[0].output), 3)
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+
 
 if __name__ == "__main__":
     unittest.main()
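
For reviewers who want to exercise the new path locally, here is a minimal driver sketch. It is not part of the patch; it assumes the onnxscript checkout is importable and that `optimizer.remove_unused_nodes` is the public entry point, as the tests above use.

    import onnx.parser

    from onnxscript import optimizer

    # Same shape as the new test: the optional running_mean/running_var
    # outputs are produced by the node but never consumed downstream.
    model = onnx.parser.parse_model(
        """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B,
                float[3] mean, float[3] var) => (float[1, 3, 5, 5] z) {
            z, mean_out, var_out = BatchNormalization <training_mode = 1> (x, scale, B, mean, var)
        }
        """
    )

    optimizer.remove_unused_nodes(model)

    node = model.graph.node[0]
    print(list(node.output))                 # expected: ['z']
    print([a.name for a in node.attribute])  # expected: [] (training_mode dropped)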