From ed28222099649380f1c2e9e981c0f06c074e13ab Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Fri, 11 Oct 2024 08:25:43 -0700
Subject: [PATCH] A few fixes relating to constant propagation (#1892)

Fixes a few different issues. Helps resolve an issue relating to IR-based optimization for the Blender model in the benchmark.

* Move the utility for evaluating the `Constant` op into the IR, and make `const_value` automatically perform the related computation.
* Eliminate the dependence on the reference implementation for evaluating the Constant op.
* There are still a couple of issues relating to the use of the reference implementation (e.g., when we have tensor-valued attributes in external-data format, and the use of float16) which will need to be addressed separately, but the above bypasses this issue for the Constant op (and the Blender model).
* Make the optimizer robust to external-data tensors whose files are not available.

A minimal usage sketch of the new `basic_constant_propagation` entry point follows the diff.
---
 .../rewriter/examples/broadcast_matmul.py    |  4 +-
 onnxscript/optimizer/__init__.py             |  3 +
 onnxscript/optimizer/_constant_folding.py    | 69 ++++++++++++++++++-
 onnxscript/rewriter/_ir_utils.py             | 47 ++-----------
 onnxscript/rewriter/broadcast_to_matmul.py   |  4 +-
 .../instance_to_group_normalization.py       | 10 +--
 .../onnxruntime/transformers/layernorm.py    |  8 ++-
 .../transformers/multihead_attention.py      |  8 ++-
 onnxscript/rewriter/pattern.py               | 12 ++--
 onnxscript/rewriter/pattern_test.py          |  3 +-
 10 files changed, 100 insertions(+), 68 deletions(-)

diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py
index e529f39d0..de919cf9c 100644
--- a/docs/tutorial/rewriter/examples/broadcast_matmul.py
+++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py
@@ -15,7 +15,7 @@
 
 import onnxscript
 from onnxscript import FLOAT, ir, opset18, script
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import pattern
 
 logger = logging.getLogger(__name__)
 
@@ -83,8 +83,6 @@ def check_if_not_need_reshape(
 
     input_a_shape = input_a.shape
     input_b_shape = input_b.shape
-    # TODO: Get a helper func to get const_value
-    _ir_utils.propagate_const_value(shape_c)
     shape_c_tensor = shape_c.const_value
     if shape_c_tensor is None:
         logger.info("The value 'shape_c' is not statically known.")
diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py
index f6e2715ab..b35f70a52 100644
--- a/onnxscript/optimizer/__init__.py
+++ b/onnxscript/optimizer/__init__.py
@@ -126,9 +126,12 @@ def optimize_ir(
     remove_unused_nodes(model)
 
 
+basic_constant_propagation = _constant_folding.basic_constant_propagation
+
 __all__ = [
     "fold_constants",
     "remove_unused_nodes",
     "optimize",
     "optimize_ir",
+    "basic_constant_propagation",
 ]
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index a93bc3927..818fd95e1 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -8,7 +8,8 @@
 import dataclasses
 import logging
 import math
-from typing import Any, Callable, Sequence, Union
+import typing
+from typing import Any, Callable, Iterable, Sequence, Union
 
 import numpy as np
 import onnx
@@ -32,6 +33,10 @@ def is_non_deterministic_op(node: ir.Node) -> bool:
     )
 
 
+def is_onnx_op(node: ir.Node, op_type: str) -> bool:
+    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+
+
 def is_constant_op(node: ir.Node) -> bool:
     return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
         node.domain
@@ -48,6 +53,50 @@ def 
is_constant_op(node: ir.Node) -> bool: # use ORT's implementation if we want to. +def _process_constant_node(node: ir.Node) -> None: + """Sets const_value of output value of a Constant op node.""" + if node.op_type != "Constant" or node.domain not in {"", "ai.onnx"}: + return + if len(node.attributes) != 1: + return + attr_name, attr_value = next(iter(node.attributes.items())) + if len(node.outputs) != 1: + return + ir_value = node.outputs[0] + + if attr_value is None or not isinstance(attr_value, ir.Attr): + return + + const_value: ir.TensorProtocol + if attr_name in {"value_float", "value_floats"}: + const_value = ir.Tensor( + np.array(attr_value.value, dtype=np.float32), name=ir_value.name + ) + elif attr_name in {"value_int", "value_ints"}: + const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name) + elif attr_name in {"value_string", "value_strings"}: + const_value = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name + ) + elif attr_name == "value": + const_value = typing.cast(ir.TensorProtocol, attr_value.value) + else: + return + + ir_value.const_value = const_value + ir_value.shape = const_value.shape # type: ignore + ir_value.dtype = const_value.dtype + + +def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None: + """Performs basic constant propagation for a sequence of nodes. + + Just marks the output values of Constant op nodes with their const_value. + """ + for node in nodes: + _process_constant_node(node) + + class ReferenceEvaluator: def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: try: @@ -168,7 +217,11 @@ def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None const_value = val.const_value if const_value is not None: - return const_value.numpy() + try: + return const_value.numpy() + except FileNotFoundError: + # External data is not available. + return None return None @@ -604,6 +657,12 @@ def process_node(self, node: ir.Node): for i, value in enumerate(node.inputs): sym_value = self._state.get_sym_value(value) if isinstance(sym_value, ir.Value): + logger.debug( + "Node [%s]: Replacing input %s with %s", + node.name, + value.name, # type: ignore[union-attr] + sym_value.name, + ) node.replace_input_with(i, sym_value) # TODO(rama): consider merging type/other info from both values @@ -629,6 +688,10 @@ def process_node(self, node: ir.Node): if is_control_flow_op(node) or is_non_deterministic_op(node): return None + if is_onnx_op(node, "Constant"): + _process_constant_node(node) + return None + input_values = [_get_numpy_value(x) for x in node.inputs] if any(x is None for x in input_values): return None @@ -648,7 +711,7 @@ def convert(av): return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): replacement = self.new_constant(node.outputs[0], outputs) - if is_constant_op(node) or replacement is None: + if is_onnx_op(node, "ConstantOfShape") or replacement is None: return None return Replacement(replacement.outputs, [replacement]) else: diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index c7a7b7ad0..bd353f388 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -1,46 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-"""This is a temporary utility to assist new IR while it's still under development.""" - from __future__ import annotations -import typing - -import numpy as np - -from onnxscript import ir - -GRAPH_OUTPUT_META_KEY = "pkg.onnxscript.rewriter.generic_pattern.graph_output" - - -def propagate_const_value(ir_value: ir.Value) -> ir.Value: - """Temporary method to propagate a constant value to the IR value.""" - node = ir_value.producer() - if node is None: - return ir_value - if node.op_type != "Constant": - return ir_value - attr_name, attr_value = next(iter(node.attributes.items())) - if attr_value is None or not isinstance(attr_value, ir.Attr): - return ir_value +import onnxscript.ir as ir +from onnxscript.optimizer import basic_constant_propagation - const_value: ir.TensorProtocol - if attr_name in {"value_float", "value_floats"}: - const_value = ir.Tensor( - np.array(attr_value.value, dtype=np.float32), name=ir_value.name - ) - elif attr_name in {"value_int", "value_ints"}: - const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name) - elif attr_name in {"value_string", "value_strings"}: - const_value = ir.StringTensor( - np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name - ) - elif attr_name == "value": - const_value = typing.cast(ir.TensorProtocol, attr_value.value) - else: - return ir_value - ir_value.const_value = const_value - ir_value.shape = const_value.shape # type: ignore - ir_value.dtype = const_value.dtype - return ir_value +def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: + node = value.producer() + if node is not None: + basic_constant_propagation([node]) + return value.const_value diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index 3ae5562cd..df216d977 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -5,7 +5,7 @@ import logging from onnxscript import ir -from onnxscript.rewriter import _ir_utils, pattern +from onnxscript.rewriter import pattern logger = logging.getLogger(__name__) @@ -30,8 +30,6 @@ def check_if_not_need_reshape( input_a_shape = input_a.shape input_b_shape = input_b.shape - # TODO: Get a helper func to get const_value - _ir_utils.propagate_const_value(shape_c) shape_c_tensor = shape_c.const_value if shape_c_tensor is None: logger.info("The value 'shape_c' is not statically known.") diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py index 85b412b24..fa0f67c5e 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -7,7 +7,7 @@ import numpy as np import onnx -from onnxscript.rewriter import _ir_utils, pattern +from onnxscript.rewriter import pattern torch_module_op = pattern.torch_module_op @@ -42,14 +42,12 @@ def check_if_simulated_instance_norm_is_used( Returns: bool: True if the simulated instance normalization is used, False otherwise. 
""" - weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm) - weight_for_norm_const_value = weight_for_norm_prop.const_value + weight_for_norm_const_value = weight_for_norm.const_value if weight_for_norm_const_value is None: return False weight_for_norm = weight_for_norm_const_value.numpy() - bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm) - bias_for_norm_const_value = bias_for_norm_prop.const_value + bias_for_norm_const_value = bias_for_norm.const_value if bias_for_norm_const_value is None: return False bias_for_norm = bias_for_norm_const_value.numpy() @@ -76,7 +74,6 @@ def check_if_simulated_instance_norm_is_used( if not all(dim == 1 for dim in bias_full_shape[1:]): return False - adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape) adjusted_input_shape_const_value = adjusted_input_shape.const_value g = weight_for_norm.shape[0] @@ -87,7 +84,6 @@ def check_if_simulated_instance_norm_is_used( return False # NOTE: Restrict the rule to only support constant shape - original_input_shape = _ir_utils.propagate_const_value(original_input_shape) original_input_shape_const_value = original_input_shape.const_value if ( original_input_shape_const_value is None diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py index edbfa4e02..fb56c9f6c 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py @@ -5,8 +5,10 @@ import logging import onnxscript +import onnxscript.ir.convenience +import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import _ir_utils, function_rule +from onnxscript.rewriter import function_rule logger = logging.getLogger(__name__) @@ -23,8 +25,8 @@ def _fusion(self, function: ir.Function) -> ir.Function: if aten_add_node is None: raise function_rule.FunctionRewriteError("Could not find Add node") - eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1]) - eps_const_value = eps_ir_value.const_value + eps_ir_value = aten_add_node.inputs[1] + eps_const_value = _ir_utils.get_const_value(eps_ir_value) if eps_const_value is None: raise function_rule.FunctionRewriteError("Could not find eps") eps_numpy_value = eps_const_value.numpy() diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py index 85053479f..7fff108f6 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py @@ -56,8 +56,10 @@ from onnx import helper as onnx_helper import onnxscript +import onnxscript.ir.convenience +import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import _ir_utils, function_rule +from onnxscript.rewriter import function_rule logger = logging.getLogger(__name__) @@ -110,8 +112,8 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig: assert ( constant_node.op_type == "Constant" ), "Expected the second input to Reshape to be a Constant node." - value = _ir_utils.propagate_const_value(reshape_node.inputs[1]) - constant_value = value.const_value + value = reshape_node.inputs[1] + constant_value = _ir_utils.get_const_value(value) if constant_value is None: raise function_rule.FunctionRewriteError( "Failed to propagate constant value for Reshape node." 
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 1f00840d4..d49e503f1 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -21,9 +21,9 @@ Union, ) +import onnxscript.optimizer from onnxscript import ir from onnxscript.ir import _convenience, _tape -from onnxscript.rewriter import _ir_utils T = TypeVar("T") @@ -618,7 +618,6 @@ def value(self) -> int | float: return self._value def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: - value = _ir_utils.propagate_const_value(value) constant_value = value.const_value if constant_value is None: return match.fail(f"Value is not a constant, expecting {self.value}.") @@ -915,14 +914,16 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool: if subgraph replacement happens. But subsequent DCE will remove the constant node if it is not used elsewhere. """ - value = _ir_utils.propagate_const_value(value) constant_value = value.const_value if constant_value is None: return self.fail( f"Value {value.name} is not a constant, expecting {pattern_constant.value}.", ) - constant_value_numpy = constant_value.numpy() + try: + constant_value_numpy = constant_value.numpy() + except FileNotFoundError: + return self.fail(f"Constant value of {value.name} not available.") # TODO (rama): allow users to specify shape requirement, if desired. if constant_value_numpy.size != 1: return self.fail( @@ -1372,6 +1373,7 @@ def _apply_to_graph_or_function( # for inserted nodes in the case of patterns with multiple output-nodes. The following # is sufficient for patterns with a single output-node "node", which can serve as the # insertion-point. + onnxscript.optimizer.basic_constant_propagation(delta.new_nodes) _convenience.replace_nodes_and_values( graph_or_function, node, @@ -1386,8 +1388,10 @@ def _apply_to_graph_or_function( def apply_to_model(self, model: ir.Model, verbose: int | None = None) -> int: assert isinstance(model, ir.Model) + onnxscript.optimizer.basic_constant_propagation(model.graph) count = self._apply_to_graph_or_function(model, model.graph, verbose=verbose) for function in model.functions.values(): + onnxscript.optimizer.basic_constant_propagation(function) count += self._apply_to_graph_or_function(model, function, verbose=verbose) return count diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 6c9497d7a..0247949f5 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -10,7 +10,7 @@ from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op -from onnxscript.rewriter import _ir_utils, cast_constant_of_shape, pattern +from onnxscript.rewriter import cast_constant_of_shape, pattern logger = logging.getLogger(__name__) @@ -259,7 +259,6 @@ def identity(op, x, newshape): def check_for_redundant_reshape(context, x, newshape): oldshape = x.shape - newshape = _ir_utils.propagate_const_value(newshape) newshape_const_value = newshape.const_value if newshape_const_value is None: return False
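
Below is the usage sketch referenced in the commit message: an illustration only, not part of the patch, of how the newly exported `onnxscript.optimizer.basic_constant_propagation` and the reworked `_ir_utils.get_const_value` helper might be called. The "model.onnx" path is a placeholder assumption.

    # Minimal sketch: run basic constant propagation over a model and read back
    # the const_value computed for Constant-node outputs.
    import onnx

    import onnxscript.optimizer
    from onnxscript import ir
    from onnxscript.rewriter import _ir_utils

    # "model.onnx" is a placeholder path, not part of this change.
    model = ir.serde.deserialize_model(onnx.load("model.onnx"))

    # Marks the outputs of Constant nodes with their const_value, without
    # invoking the ONNX reference evaluator.
    onnxscript.optimizer.basic_constant_propagation(model.graph)

    for node in model.graph:
        for output in node.outputs:
            # get_const_value runs the same propagation on the producer node and
            # returns the constant as an ir.TensorProtocol, or None if unknown.
            tensor = _ir_utils.get_const_value(output)
            if tensor is not None:
                print(output.name, tensor.shape, tensor.dtype)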