Simplify get_numpy_from_ir_value (#1561)
Fix #1550 

NOTE: from #1553 (comment), I think we can hide the None check in this function.
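
For context, a minimal sketch of the call-site pattern the diff below adopts (the wrapper function and its name are illustrative, not part of this commit): propagate the constant, read const_value off the ir.Value, check for None explicitly, and only then convert to NumPy.

    # Illustrative sketch only -- not code from this commit.
    import numpy as np

    from onnxscript import ir
    from onnxscript.rewriter import _ir_utils

    def constant_as_numpy(value: ir.Value) -> np.ndarray | None:  # hypothetical helper name
        """Return the constant backing `value` as an ndarray, or None if it is not constant."""
        value = _ir_utils.propagate_const_value(value)
        const_value = value.const_value
        if const_value is None:
            # After this commit, each call site handles the non-constant case itself
            # (return False, raise FunctionRewriteError, etc.) instead of a shared helper.
            return None
        return const_value.numpy()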
titaiwangms authored May 21, 2024
1 parent a5ed079 commit fca7401
Showing 6 changed files with 48 additions and 31 deletions.
9 changes: 0 additions & 9 deletions onnxscript/rewriter/_ir_utils.py
@@ -42,12 +42,3 @@ def propagate_const_value(ir_value: ir.Value) -> ir.Value:
    ir_value.shape = const_value.shape  # type: ignore
    ir_value.dtype = const_value.dtype
    return ir_value
-
-
-def get_numpy_from_ir_value(value: ir.Value) -> np.ndarray | None:
-    constant_value = value.const_value
-    if constant_value is not None:
-        if isinstance(constant_value, ir.serde.TensorProtoTensor):
-            return constant_value.numpy()
-        return np.array(constant_value)
-    return constant_value
28 changes: 20 additions & 8 deletions onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -40,11 +40,17 @@ def check_if_simulated_instance_norm_is_used(
    Returns:
        bool: True if the simulated instance normalization is used, False otherwise.
    """
-    weight_for_norm = _ir_utils.propagate_const_value(weight_for_norm)
-    weight_for_norm = _ir_utils.get_numpy_from_ir_value(weight_for_norm)
+    weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm)
+    weight_for_norm_const_value = weight_for_norm_prop.const_value
+    if weight_for_norm_const_value is None:
+        return False
+    weight_for_norm = weight_for_norm_const_value.numpy()

-    bias_for_norm = _ir_utils.propagate_const_value(bias_for_norm)
-    bias_for_norm = _ir_utils.get_numpy_from_ir_value(bias_for_norm)
+    bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm)
+    bias_for_norm_const_value = bias_for_norm_prop.const_value
+    if bias_for_norm_const_value is None:
+        return False
+    bias_for_norm = bias_for_norm_const_value.numpy()

    if not np.all(weight_for_norm == 1):
        return False
@@ -69,16 +75,22 @@ def check_if_simulated_instance_norm_is_used(
        return False

    adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape)
-    adjusted_input_shape = _ir_utils.get_numpy_from_ir_value(adjusted_input_shape)
+    adjusted_input_shape_const_value = adjusted_input_shape.const_value

    g = weight_for_norm.shape[0]
-    if adjusted_input_shape is None or adjusted_input_shape.tolist() != [0, g, -1]:
+    if (
+        adjusted_input_shape_const_value is None
+        or adjusted_input_shape_const_value.numpy().tolist() != [0, g, -1]
+    ):
        return False

    # NOTE: Restrict the rule to only support constant shape
    original_input_shape = _ir_utils.propagate_const_value(original_input_shape)
-    original_input_shape = _ir_utils.get_numpy_from_ir_value(original_input_shape)
-    if original_input_shape is None or original_input_shape.tolist() != input_x.shape:
+    original_input_shape_const_value = original_input_shape.const_value
+    if (
+        original_input_shape_const_value is None
+        or original_input_shape_const_value.numpy().tolist() != input_x.shape
+    ):
        return False

    return True
5 changes: 3 additions & 2 deletions onnxscript/rewriter/onnxruntime/transformers/layernorm.py
@@ -22,9 +22,10 @@ def _fusion(self, function: ir.Function) -> ir.Function:
            raise function_rule.FunctionRewriteError("Could not find Add node")

        eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1])
-        eps_numpy_value = _ir_utils.get_numpy_from_ir_value(eps_ir_value)
-        if eps_numpy_value is None:
+        eps_const_value = eps_ir_value.const_value
+        if eps_const_value is None:
            raise function_rule.FunctionRewriteError("Could not find eps")
+        eps_numpy_value = eps_const_value.numpy()
        eps = eps_numpy_value.item()
        logger.info("eps: %s", eps)

7 changes: 6 additions & 1 deletion
@@ -109,7 +109,12 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
            constant_node.op_type == "Constant"
        ), "Expected the second input to Reshape to be a Constant node."
        value = _ir_utils.propagate_const_value(reshape_node.inputs[1])
-        constant_numpy_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
+        if constant_value is None:
+            raise function_rule.FunctionRewriteError(
+                "Failed to propagate constant value for Reshape node."
+            )
+        constant_numpy_value = constant_value.numpy()
        if constant_numpy_value.shape[0] == 4:
            num_attention_heads = constant_numpy_value[2]
            head_size = constant_numpy_value[3]
23 changes: 15 additions & 8 deletions onnxscript/rewriter/pattern.py
@@ -576,18 +576,24 @@ def value(self) -> int | float:

    def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
        value = _ir_utils.propagate_const_value(value)
-        constant_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
        if constant_value is None:
            return match.fail(f"Value is not a constant, expecting {self.value}.")

+        constant_value_numpy = constant_value.numpy()
        # TODO (rama): allow users to specify shape requirement, if desired.
-        if constant_value.size != 1:
+        if constant_value_numpy.size != 1:
            return match.fail(f"Value is not a scalar, expecting {self.value}.")

        if not math.isclose(
-            constant_value.item(), self._value, rel_tol=self._rel_tol, abs_tol=self._abs_tol
+            constant_value_numpy.item(),
+            self._value,
+            rel_tol=self._rel_tol,
+            abs_tol=self._abs_tol,
        ):
-            match.fail(f"Value mismatch: expected {self._value}, got {constant_value.item()}.")
+            match.fail(
+                f"Value mismatch: expected {self._value}, got {constant_value_numpy.item()}."
+            )

        # Note: If the value is produced by a Constant node, we could include
        # the Constant node in the return_value list. However, we don't do that.
@@ -893,26 +899,27 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool:
        node if it is not used elsewhere.
        """
        value = _ir_utils.propagate_const_value(value)
-        constant_value = _ir_utils.get_numpy_from_ir_value(value)
+        constant_value = value.const_value
        if constant_value is None:
            return self.fail(
                f"Value {value.name} is not a constant, expecting {pattern_constant.value}.",
            )

+        constant_value_numpy = constant_value.numpy()
        # TODO (rama): allow users to specify shape requirement, if desired.
-        if constant_value.size != 1:
+        if constant_value_numpy.size != 1:
            return self.fail(
                f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.",
            )

        if not math.isclose(
-            constant_value.item(),
+            constant_value_numpy.item(),
            pattern_constant._value,
            rel_tol=pattern_constant._rel_tol,
            abs_tol=pattern_constant._abs_tol,
        ):
            return self.fail(
-                f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value.item()}.",
+                f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.",
            )

        return True
7 changes: 4 additions & 3 deletions onnxscript/rewriter/pattern_test.py
@@ -3,7 +3,6 @@
import logging
import unittest

-import numpy as np
import onnx.checker
import onnx.parser

@@ -258,9 +257,11 @@ def identity(op, x, newshape):
        def check_for_redundant_reshape(context, x, newshape):
            oldshape = x.shape
            newshape = _ir_utils.propagate_const_value(newshape)
-            newshape = _ir_utils.get_numpy_from_ir_value(newshape)
-            if not isinstance(newshape, np.ndarray):
+            newshape_const_value = newshape.const_value
+            if newshape_const_value is None:
                return False
+
+            newshape = newshape_const_value.numpy()
            newshape = newshape.tolist()

            if len(oldshape) != len(newshape):
