Optimizer extensions (#2003)
Extend the optimizer to enable optimizations (such as elimination of
redundant Expand/Reshape) when symbolic dimensions are present. This
requires propagating symbolic shape values (tensors that carry shape
information that is not completely known at compile time) through the
optimizer. These rewrites also help fusion optimizations (without them, fusion
would require more pattern variations and more complex patterns).

Handles symbolic shape propagation through Abs, Gather, and Concat, and uses
the propagated values in Reshape/Expand. Abs shows up in the Expand translation
because Torch allows -1 to mean "no expansion" while ONNX uses 1; the Abs is
unnecessary when the input is a symbolic shape whose values are guaranteed to
be non-negative.
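
For illustration (PyTorch semantics; a hypothetical snippet, not part of this
commit): in torch.Tensor.expand, -1 means "keep this dimension unchanged":

    import torch

    x = torch.randn(4, 1)
    y = x.expand(-1, 256)  # -1 keeps dim 0 as-is; result shape: (4, 256)

ONNX Expand has no -1 convention (a 1 in the target shape means "no expansion"
via broadcasting), so exporters emit extra ops such as Abs to normalize the
target shape.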

Also fixes node-level shape inference to refine shapes by merging the best
information from the pre-existing shape and the inferred shape.

---------

Co-authored-by: Justin Chu <[email protected]>
gramalingam and justinchuby authored Jan 9, 2025
1 parent 646116c commit a942e95
Showing 2 changed files with 222 additions and 18 deletions.
169 changes: 151 additions & 18 deletions onnxscript/optimizer/_constant_folding.py
@@ -133,6 +133,18 @@ class Replacement:
    new_nodes: Sequence[ir.Node]


# The optimizer tracks an optional symbolic value for each value in the model.
# The symbolic value attached to a value X can be:
# - another IR value Y (indicating that X is equal to Y)
# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
# - a Shape object (indicating that X is a shape value)
# A Shape object as a symbolic value indicates that the corresponding value is a
# 1-D (or 0-D) tensor of INT64 values. The values in this object may be constants
# or symbolic dimension values (like "batch_size", "sequence_length", etc.).
# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
# TODO: Add support for negative symbolic dimensions.
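#
# For illustration (hypothetical example; assuming ir.Shape accepts strings for
# named symbolic dimensions): for a value X produced by Shape(x) where
# x: float[B, 256], the tracked symbolic value of X would be
# ir.Shape(["B", 256]), in which "B" is a named, non-negative symbolic dim.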


class OptimizerState:
    def __init__(self):
        self._sym_value_map: dict[ir.Value, Any] = {}
@@ -159,6 +171,18 @@ def add_initializer_input(self, value: ir.Value) -> None:
    def is_initializer_input(self, value: ir.Value) -> bool:
        return any(value in inputs for inputs in self._initializer_inputs)

    def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
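        """Return the symbolic shape value of `value`, if statically known.

        The result may come from a small INT64 constant tensor (interpreted as
        a shape) or from a previously tracked symbolic shape value.
        """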
        const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
        if const_value is not None:
            if const_value.ndim == 1:
                return ir.Shape(const_value.tolist())
            return None
        sym_value = self.get_sym_value(value)
        if isinstance(sym_value, ir.Shape):
            return sym_value
        # TODO: use the shape of `value` itself, if available
        return None


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
@@ -235,11 +259,33 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
register = registry.register


def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
    # Comparing shapes as tuples of dims works, except when a dimension is
    # unknown (a SymbolicDim whose value is None): two unknown dimensions are
    # never considered equal. Thus (Batch, 1024) equals (Batch, 1024), but
    # (None, 1024) does not equal (None, 1024).
    if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in shape1):
        return False
    return shape1.dims == shape2.dims
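
# Hedged illustration of the intended behavior (assuming ir.Shape accepts
# strings for named symbolic dims and None for unknown dims):
#
#     _same_shape(ir.Shape(["B", 1024]), ir.Shape(["B", 1024]))    # True
#     _same_shape(ir.Shape([None, 1024]), ir.Shape([None, 1024]))  # False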


def _get_numpy_value(
    val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
    """Returns the numpy value of a constant value, if available.

    Returns None if the value is not a constant, if its element dtype does not
    match the specified `dtype`, or if its size exceeds `size_limit`.
    """
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        if dtype is not None and const_value.dtype != dtype:
            return None
        if size_limit is not None and const_value.size > size_limit:
            return None
        try:
            array = const_value.numpy()
        except FileNotFoundError:
@@ -256,7 +302,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None:
    value = _get_numpy_value(val)
    if value is None:
        return None
    if value.size == 1 and value.dtype == bool:
        return value.item(0)
    return None

@@ -300,6 +346,54 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
    return default


@register("Abs")
def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace an Abs node by Identity when applicable.
Currently, addresses Abs applied to symbolic shapes.
"""
input = _get_input(node, 0)
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
if any(isinstance(d, int) and d < 0 for d in input_sym_value):
return None
# Abs applied to a symbolic shape of the form [1, 1, SequenceLength].
# We assume that SequenceLength is a non-negative integer.
# The Abs op is redundant in this case.
return op.Identity(input)


@register("Gather")
def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Gather node by a constant when applicable.
Currently, handles the case of Gathering from a shape tensor.
"""
input = _get_input(node, 0)
indices = _get_input(node, 1)
if input is None or indices is None:
return None
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
axis = _get_int_attribute(node, "axis", None)
if axis != 0:
return None
indices_numpy_value = _get_numpy_value(indices)
if indices_numpy_value is None:
return None
if indices_numpy_value.ndim != 1:
return None
gathered = [input_sym_value[i] for i in indices_numpy_value]
output = _get_output(node, 0)
if output is not None:
state.set_sym_value(output, ir.Shape(gathered))
if all(isinstance(d, int) for d in gathered):
return op.Constant(value_ints=gathered)
return None
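
# Hedged worked example: if `input` carries the symbolic shape value (B, 128)
# and `indices` is the constant [0], then `gathered` is [SymbolicDim("B")], so
# the output is tracked as the shape value (B,); no Constant is produced,
# since B is not a known integer.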


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Reshape node by Identity when applicable."""
@@ -310,15 +404,16 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    input_shape = input.shape
    if input_shape is None:
        return None
    shape_value = state.get_shape_value(shape)
    if shape_value is None:
        return None
    # No need to check for special target-shape values like -1 or 0 here:
    # the rewrite applies only if the target shape matches the input shape
    # dimension by dimension.
    if _same_shape(input_shape, shape_value):
        return op.Identity(input)
    return None
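
# Hedged worked example (mirrors test_reshape_identity_symdim below): for
# x: float[B, 256] reshaped with a target computed as Concat(Shape(y)[0:1],
# [256]) where y: float[B, 128], the tracked target shape is (B, 256);
# _same_shape matches it against x's shape, so the Reshape is rewritten to
# Identity(x).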

@@ -373,6 +468,9 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    start = _get_int_attribute(node, "start", 0)
    end = _get_int_attribute(node, "end", None)
    shape_slice = shape[start:end]
    output = _get_output(node, 0)
    if output is not None:
        state.set_sym_value(output, ir.Shape(shape_slice))
    if all(isinstance(d, int) for d in shape_slice):
        return op.Constant(value_ints=list(shape_slice))
    return None
@@ -459,6 +557,19 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    inputs = node.inputs
    if len(inputs) == 1:
        return op.Identity(inputs[0])
    # Track the symbolic value of tensors that carry a shape value:
    output = node.outputs[0]
    if output is None:
        return None
    # Shape-value tracking applies only when concatenating 1-D shape values
    # along axis 0.
    axis = _get_int_attribute(node, "axis", None)
    if axis != 0:
        return None
    shapes = [state.get_shape_value(input) for input in inputs]
    if any(shape is None for shape in shapes):
        return None
    concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims)  # type: ignore[union-attr]
    state.set_sym_value(output, concatenated)
    return None
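
# Hedged worked example: if the inputs carry the symbolic shape values (B,)
# and (256,), the output is tracked as the shape value (B, 256); the Concat
# node itself is left in place, since B is not a known constant.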


@@ -507,7 +618,10 @@ def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
        return None
    if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
        # Target shape is not a known constant; fall back to its symbolic
        # shape value, and rewrite to Identity if it matches the input shape.
        expanded_sym_shape = state.get_shape_value(node.inputs[1])
        if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape):
            return None
        return op.Identity(input)
    if expanded_shape.ndim != 1:
        # Target shape must be a 1-D tensor. Erroneous model.
        return None
@@ -658,6 +772,27 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    return None


def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
    def merge_dims(dim1, dim2):
        if dim1 == dim2:
            return dim1
        if not isinstance(dim1, ir.SymbolicDim):
            return dim1  # Prefer int value over symbolic dim
        if not isinstance(dim2, ir.SymbolicDim):
            return dim2
        if dim1.value is None:
            return dim2
        return dim1

    if shape1 is None:
        return shape2
    if shape2 is None:
        return shape1
    if len(shape1) != len(shape2):
        raise ValueError("Shapes must have the same rank.")
    return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
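
# Hedged illustration of the merge policy (assuming ir.Shape accepts strings
# for named symbolic dims and None for unknown dims):
#
#     _merge_shapes(ir.Shape(["B", 1024]), ir.Shape([8, 1024]))     # -> (8, 1024): int wins
#     _merge_shapes(ir.Shape(["B", 1024]), ir.Shape([None, 1024]))  # -> (B, 1024): named dim wins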


class ConstantFolder:
    opset_imports: dict[str, int]

@@ -723,7 +858,10 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
            if output.name in output_types:
                inferred_type = output_types[output.name]
                # TODO: merge types, check for conflicts
                inferred_shape = ir.serde.deserialize_type_proto_for_shape(inferred_type)
                output.shape = _merge_shapes(output.shape, inferred_shape)
                output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
        except Exception as e:
            logger.debug(
@@ -763,13 +901,8 @@ def new_constant(self, irvalue: ir.Value, value):
            value.shape,
        )

        attributes = ir.convenience.convert_attributes({"value": tensor})
        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
        return node

    def process_node(self, node: ir.Node):
71 changes: 71 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -486,6 +486,77 @@ def test_expand_identity(self):
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

    def test_expand_identity_symdim(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[B, 256] x) => (float[B, 256] z)
            {
                b = Shape <start=0, end=1> (x)
                const_256 = Constant <value_ints=[256]> ()
                shape = Concat <axis=0> (b, const_256)
                z = Expand (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

    def test_abs_symdim(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[B, 256] x) => (float[B, 256] z)
            {
                b = Shape <start=0, end=1> (x)
                const_256 = Constant <value_ints=[256]> ()
                b_256 = Concat <axis=0> (b, const_256)
                shape = Abs (b_256)
                z = Expand (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

    def test_reshape_identity(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[128, 256] x) => (float[128, 256] z)
            {
                shape = Constant <value_ints=[128, 256]> ()
                z = Reshape (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

    def test_reshape_identity_symdim(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
            {
                b = Shape <start=0, end=1> (y)
                const_256 = Constant <value_ints=[256]> ()
                shape = Concat <axis=0> (b, const_256)
                z = Reshape (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

    def test_gather_symdim(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
            {
                b_128 = Shape (y)
                index_0 = Constant <value_ints=[0]> ()
                b = Gather <axis=0> (b_128, index_0)
                const_256 = Constant <value_ints=[256]> ()
                shape = Concat <axis=0> (b, const_256)
                z = Reshape (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
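
    # A hedged additional sketch (not part of this commit): the symbolic
    # comparison should not rewrite Expand when the tracked target shape
    # differs from the input shape, so the Expand node must be preserved.
    def test_expand_not_identity_symdim(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[B, 1] x, float[B, 256] y) => (float[B, 256] z)
            {
                shape = Shape (y)
                z = Expand (x, shape)
            }
        """
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node(-1).op_type, "Expand")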


if __name__ == "__main__":
unittest.main()
