From 5a3595882cbbc95f5cd23f7a024bd5096ced63dc Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 14 Nov 2024 13:33:40 -0800 Subject: [PATCH] Handle input initializers correctly in constant folding (#1944) Values that are both inputs and initializers of a model/graph should not be treated as constants (and cannot be used for constant-folding). Unfortunately, the single `const_value` field in class Value is used both to indicate constant-values of proper constants as well as initializer values of initializers. Ideally, the IR should provide an easy way to distinguish this at the value level (with either an extra boolean flag to indicate the value is an input-value or by using distinct fields for "initializer_value" and "const_value"). Meanwhile, this PR introduces a workaround to handle the main issue. --- onnxscript/optimizer/_constant_folding.py | 26 +++++++++++++++++++ .../optimizer/_constant_folding_test.py | 24 +++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 418593ff4..a5141c6bc 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -137,6 +137,7 @@ class Replacement: class OptimizerState: def __init__(self): self._sym_value_map: dict[ir.Value, Any] = {} + self._initializer_inputs: list[set[ir.Value]] = [] def get_sym_value(self, value: ir.Value | None) -> Any: if value is None: @@ -146,6 +147,19 @@ def get_sym_value(self, value: ir.Value | None) -> Any: def set_sym_value(self, value: ir.Value, sym_value: Any) -> None: self._sym_value_map[value] = sym_value + def push_initializer_inputs(self) -> None: + self._initializer_inputs.append(set()) + + def pop_initializer_inputs(self) -> None: + self._initializer_inputs.pop() + + def add_initializer_input(self, value: ir.Value) -> None: + assert self._initializer_inputs + self._initializer_inputs[-1].add(value) + + def is_initializer_input(self, value: 
ir.Value) -> bool: + return any(value in inputs for inputs in self._initializer_inputs) + # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). @@ -754,6 +768,9 @@ def process_node(self, node: ir.Node): if any(x is None for x in input_values): return None + if any(self._state.is_initializer_input(x) for x in node.inputs): # type: ignore[arg-type] + return None + if any(input.nbytes > self._input_size_limit for input in input_values): # type: ignore[union-attr] if logger.isEnabledFor(logging.DEBUG): input_sizes = [input.size for input in input_values] # type: ignore[union-attr] @@ -817,9 +834,18 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): self.replace_node(node, replacement, root) def visit_graph(self, graph: ir.Graph) -> None: + # Track inputs that have a const_value (which is really a default-value, and should not + # be used for constant-folding). + self._state.push_initializer_inputs() + for input in graph.inputs: + if input.const_value is not None: + self._state.add_initializer_input(input) + for node in graph: self.visit_node(node, graph) + self._state.pop_initializer_inputs() + def visit_function(self, function: ir.Function) -> None: for node in function: self.visit_node(node, function) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 52e06bd56..d6a799116 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -6,6 +6,7 @@ import parameterized import pytest +import onnxscript.ir as ir import onnxscript.optimizer as optimizer from onnxscript.ir import serde from onnxscript.optimizer import _constant_folding @@ -434,5 +435,28 @@ def test_concat_identity(self): self.assertEqual(optimized.graph.node[0].op_type, "Identity") +class FoldConstantsIrTest(unittest.TestCase): + def _fold(self, model_text: str, 
onnx_shape_inference=False) -> ir.Model: + model_proto = onnx.parser.parse_model(model_text) + model = serde.deserialize_model(model_proto) + _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference) + optimizer.remove_unused_nodes(model) + return model + + def test_initializer_input_not_folded(self): + model_text = """ + + agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z) + { + # c is not a constant, and following should not be folded. + two_c = Add (c, c) + z = Mul (x, two_c) + } + """ + optimized = self._fold(model_text) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph.node(0).op_type, "Add") + + if __name__ == "__main__": unittest.main()