From fca7401966f5a6964531436209406ef312a0d041 Mon Sep 17 00:00:00 2001
From: Ti-Tai Wang
Date: Tue, 21 May 2024 16:27:21 -0700
Subject: [PATCH] Simplify get_numpy_from_ir_value (#1561)

Fix #1550

NOTE: from https://github.com/microsoft/onnxscript/pull/1553#discussion_r1604205079,
I think we can hide the None check in this function.
---
 onnxscript/rewriter/_ir_utils.py          |  9 ------
 .../instance_to_group_normalization.py    | 28 +++++++++++++------
 .../onnxruntime/transformers/layernorm.py |  5 ++--
 .../transformers/multihead_attention.py   |  7 ++++-
 onnxscript/rewriter/pattern.py            | 23 +++++++++------
 onnxscript/rewriter/pattern_test.py       |  7 +++--
 6 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py
index 9bfc4ac5a..702e5a3f9 100644
--- a/onnxscript/rewriter/_ir_utils.py
+++ b/onnxscript/rewriter/_ir_utils.py
@@ -42,12 +42,3 @@ def propagate_const_value(ir_value: ir.Value) -> ir.Value:
         ir_value.shape = const_value.shape  # type: ignore
         ir_value.dtype = const_value.dtype
     return ir_value
-
-
-def get_numpy_from_ir_value(value: ir.Value) -> np.ndarray | None:
-    constant_value = value.const_value
-    if constant_value is not None:
-        if isinstance(constant_value, ir.serde.TensorProtoTensor):
-            return constant_value.numpy()
-        return np.array(constant_value)
-    return constant_value
diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
index ca06917b5..559033a7c 100644
--- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
+++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -40,11 +40,17 @@ def check_if_simulated_instance_norm_is_used(
     Returns:
         bool: True if the simulated instance normalization is used, False otherwise.
""" - weight_for_norm = _ir_utils.propagate_const_value(weight_for_norm) - weight_for_norm = _ir_utils.get_numpy_from_ir_value(weight_for_norm) + weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm) + weight_for_norm_const_value = weight_for_norm_prop.const_value + if weight_for_norm_const_value is None: + return False + weight_for_norm = weight_for_norm_const_value.numpy() - bias_for_norm = _ir_utils.propagate_const_value(bias_for_norm) - bias_for_norm = _ir_utils.get_numpy_from_ir_value(bias_for_norm) + bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm) + bias_for_norm_const_value = bias_for_norm_prop.const_value + if bias_for_norm_const_value is None: + return False + bias_for_norm = bias_for_norm_const_value.numpy() if not np.all(weight_for_norm == 1): return False @@ -69,16 +75,22 @@ def check_if_simulated_instance_norm_is_used( return False adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape) - adjusted_input_shape = _ir_utils.get_numpy_from_ir_value(adjusted_input_shape) + adjusted_input_shape_const_value = adjusted_input_shape.const_value g = weight_for_norm.shape[0] - if adjusted_input_shape is None or adjusted_input_shape.tolist() != [0, g, -1]: + if ( + adjusted_input_shape_const_value is None + or adjusted_input_shape_const_value.numpy().tolist() != [0, g, -1] + ): return False # NOTE: Restrict the rule to only support constant shape original_input_shape = _ir_utils.propagate_const_value(original_input_shape) - original_input_shape = _ir_utils.get_numpy_from_ir_value(original_input_shape) - if original_input_shape is None or original_input_shape.tolist() != input_x.shape: + original_input_shape_const_value = original_input_shape.const_value + if ( + original_input_shape_const_value is None + or original_input_shape_const_value.numpy().tolist() != input_x.shape + ): return False return True diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py index 54ccfa86b..d6e5fe1d5 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py @@ -22,9 +22,10 @@ def _fusion(self, function: ir.Function) -> ir.Function: raise function_rule.FunctionRewriteError("Could not find Add node") eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1]) - eps_numpy_value = _ir_utils.get_numpy_from_ir_value(eps_ir_value) - if eps_numpy_value is None: + eps_const_value = eps_ir_value.const_value + if eps_const_value is None: raise function_rule.FunctionRewriteError("Could not find eps") + eps_numpy_value = eps_const_value.numpy() eps = eps_numpy_value.item() logger.info("eps: %s", eps) diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py index 9c16ef975..1ed949d4b 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py @@ -109,7 +109,12 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig: constant_node.op_type == "Constant" ), "Expected the second input to Reshape to be a Constant node." 
         value = _ir_utils.propagate_const_value(reshape_node.inputs[1])
-        constant_numpy_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
+        if constant_value is None:
+            raise function_rule.FunctionRewriteError(
+                "Failed to propagate constant value for Reshape node."
+            )
+        constant_numpy_value = constant_value.numpy()
         if constant_numpy_value.shape[0] == 4:
             num_attention_heads = constant_numpy_value[2]
             head_size = constant_numpy_value[3]
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 504cfdeea..337e9cd43 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -576,18 +576,24 @@ def value(self) -> int | float:
 
     def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
         value = _ir_utils.propagate_const_value(value)
-        constant_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
         if constant_value is None:
             return match.fail(f"Value is not a constant, expecting {self.value}.")
 
+        constant_value_numpy = constant_value.numpy()
         # TODO (rama): allow users to specify shape requirement, if desired.
-        if constant_value.size != 1:
+        if constant_value_numpy.size != 1:
             return match.fail(f"Value is not a scalar, expecting {self.value}.")
 
         if not math.isclose(
-            constant_value.item(), self._value, rel_tol=self._rel_tol, abs_tol=self._abs_tol
+            constant_value_numpy.item(),
+            self._value,
+            rel_tol=self._rel_tol,
+            abs_tol=self._abs_tol,
         ):
-            match.fail(f"Value mismatch: expected {self._value}, got {constant_value.item()}.")
+            match.fail(
+                f"Value mismatch: expected {self._value}, got {constant_value_numpy.item()}."
+            )
 
         # Note: If the value is produced by a Constant node, we could include
         # the Constant node in the return_value list. However, we don't do that.
@@ -893,26 +899,27 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool:
         node if it is not used elsewhere.
         """
         value = _ir_utils.propagate_const_value(value)
-        constant_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
         if constant_value is None:
             return self.fail(
                 f"Value {value.name} is not a constant, expecting {pattern_constant.value}.",
             )
 
+        constant_value_numpy = constant_value.numpy()
         # TODO (rama): allow users to specify shape requirement, if desired.
-        if constant_value.size != 1:
+        if constant_value_numpy.size != 1:
             return self.fail(
                 f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.",
             )
 
         if not math.isclose(
-            constant_value.item(),
+            constant_value_numpy.item(),
             pattern_constant._value,
             rel_tol=pattern_constant._rel_tol,
             abs_tol=pattern_constant._abs_tol,
         ):
             return self.fail(
-                f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value.item()}.",
+                f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.",
             )
 
         return True
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index 1ccddcc31..fde2c3b06 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -3,7 +3,6 @@
 import logging
 import unittest
 
-import numpy as np
 import onnx.checker
 import onnx.parser
 
@@ -258,9 +257,11 @@ def identity(op, x, newshape):
         def check_for_redundant_reshape(context, x, newshape):
             oldshape = x.shape
             newshape = _ir_utils.propagate_const_value(newshape)
-            newshape = _ir_utils.get_numpy_from_ir_value(newshape)
-            if not isinstance(newshape, np.ndarray):
+            newshape_const_value = newshape.const_value
+            if newshape_const_value is None:
                 return False
+
+            newshape = newshape_const_value.numpy()
             newshape = newshape.tolist()
 
             if len(oldshape) != len(newshape):
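
Taken together, the diff replaces the deleted helper with one explicit idiom at
every call site: propagate the constant, check const_value for None, and only
then call .numpy(). A minimal standalone sketch of that idiom follows; the
function _is_scalar_constant and its size check are illustrative, not part of
this patch.

    import numpy as np

    from onnxscript import ir
    from onnxscript.rewriter import _ir_utils


    def _is_scalar_constant(value: ir.Value) -> bool:
        """Hypothetical check using the call-site idiom this patch adopts."""
        value = _ir_utils.propagate_const_value(value)
        const_value = value.const_value
        if const_value is None:
            # No statically known constant: the rule cannot inspect the value.
            return False
        array: np.ndarray = const_value.numpy()
        return array.size == 1

Keeping the None check visible at each call site, rather than hidden inside
get_numpy_from_ir_value, lets each caller pick its own failure behavior, as the
touched files do: the group-normalization check returns False, the transformer
rewrites raise FunctionRewriteError, and the pattern matcher reports
match.fail(...).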