diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 1ecfa0911..8b4dbbfe5 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -133,6 +133,18 @@ class Replacement:
     new_nodes: Sequence[ir.Node]


+# The optimizer tracks an optional symbolic value for each value in the model.
+# The symbolic value attached to a value X can be:
+# - another IR value Y (indicating that X is equal to Y)
+# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
+# - a Shape object (indicating that X is a shape value)
+# A Shape object as a symbolic value indicates that the corresponding value is
+# a 1-D (or 0-D) tensor of INT64 values. The values in this object may be constants
+# or symbolic dimension values (like "batch_size", "sequence_length", etc.).
+# Currently, we assume that symbolic dimensions are guaranteed to be non-negative.
+# TODO: Add support for negative symbolic dimensions.
+
+
 class OptimizerState:
     def __init__(self):
         self._sym_value_map: dict[ir.Value, Any] = {}
@@ -159,6 +171,18 @@ def add_initializer_input(self, value: ir.Value) -> None:
     def is_initializer_input(self, value: ir.Value) -> bool:
         return any(value in inputs for inputs in self._initializer_inputs)

+    def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
+        const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
+        if const_value is not None:
+            if const_value.ndim == 1:
+                return ir.Shape(const_value.tolist())
+            return None
+        sym_value = self.get_sym_value(value)
+        if isinstance(sym_value, ir.Shape):
+            return sym_value
+        # TODO: use the shape of the value if available
+        return None
+

 # The "partial evaluators" below are non-standard evaluators. They are used to perform
 # partial evaluation and/or static program analysis (abstract interpretation).
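To make the symbolic-value convention and get_shape_value concrete, here is a minimal sketch (not part of the change) of how a shape value is recorded and later resolved. It assumes ir.Value can be constructed standalone with just a name, and that ir.Shape and ir.SymbolicDim compare by value, as their use elsewhere in this file suggests.

    from onnxscript import ir

    state = OptimizerState()
    # A 1-D INT64 tensor in some model, e.g. the output of a Shape op:
    shape_tensor = ir.Value(name="shape_tensor")

    # Record that shape_tensor holds the symbolic shape [B, 256].
    state.set_sym_value(shape_tensor, ir.Shape([ir.SymbolicDim("B"), 256]))

    # No INT64 constant is attached to the value, so get_shape_value falls
    # back to the recorded symbolic value.
    assert state.get_shape_value(shape_tensor) == ir.Shape([ir.SymbolicDim("B"), 256])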
+ """ if val is None: return None const_value = val.const_value if const_value is not None: + if dtype is not None and const_value.dtype != dtype: + return None + if size_limit is not None and const_value.size > size_limit: + return None try: array = const_value.numpy() except FileNotFoundError: @@ -256,7 +302,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None: value = _get_numpy_value(val) if value is None: return None - if value.size == 1 and value.dtype == np.bool_: + if value.size == 1 and value.dtype == bool: return value.item(0) return None @@ -300,6 +346,54 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return default +@register("Abs") +def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace an Abs node by Identity when applicable. + + Currently, addresses Abs applied to symbolic shapes. + """ + input = _get_input(node, 0) + input_sym_value = state.get_shape_value(input) + if input_sym_value is None: + return None + if any(isinstance(d, int) and d < 0 for d in input_sym_value): + return None + # Abs applied to a symbolic shape of the form [1, 1, SequenceLength]. + # We assume that SequenceLength is a non-negative integer. + # The Abs op is redundant in this case. + return op.Identity(input) + + +@register("Gather") +def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Gather node by a constant when applicable. + + Currently, handles the case of Gathering from a shape tensor. + """ + input = _get_input(node, 0) + indices = _get_input(node, 1) + if input is None or indices is None: + return None + input_sym_value = state.get_shape_value(input) + if input_sym_value is None: + return None + axis = _get_int_attribute(node, "axis", None) + if axis != 0: + return None + indices_numpy_value = _get_numpy_value(indices) + if indices_numpy_value is None: + return None + if indices_numpy_value.ndim != 1: + return None + gathered = [input_sym_value[i] for i in indices_numpy_value] + output = _get_output(node, 0) + if output is not None: + state.set_sym_value(output, ir.Shape(gathered)) + if all(isinstance(d, int) for d in gathered): + return op.Constant(value_ints=gathered) + return None + + @register("Reshape") def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Reshape node by Identity when applicable.""" @@ -310,15 +404,16 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input_shape = input.shape if input_shape is None: return None - input_shape_dims = list(input_shape.dims) - if any(not isinstance(dim, int) for dim in input_shape_dims): - return None - shape_value = _get_numpy_value(shape) + # input_shape_dims = list(input_shape.dims) + # if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims): + # return None + shape_value = state.get_shape_value(shape) if shape_value is None: return None - target_shape_dims = shape_value.tolist() - if input_shape_dims == target_shape_dims: - # No need to check for special values like -1, 0, etc. here + # target_shape_dims = list(shape_value.dims) + # if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. 
@@ -507,7 +618,10 @@ def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
         return None
     if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
         # Target shape is not known.
-        return None
+        expanded_sym_shape = state.get_shape_value(node.inputs[1])
+        if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape):
+            return None
+        return op.Identity(input)
     if expanded_shape.ndim != 1:
         # Target shape must be a 1D tensor. Erroneous model.
         return None
@@ -658,6 +772,27 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None


+def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+    def merge_dims(dim1, dim2):
+        if dim1 == dim2:
+            return dim1
+        if not isinstance(dim1, ir.SymbolicDim):
+            return dim1  # Prefer int value over symbolic dim
+        if not isinstance(dim2, ir.SymbolicDim):
+            return dim2
+        if dim1.value is None:
+            return dim2
+        return dim1
+
+    if shape1 is None:
+        return shape2
+    if shape2 is None:
+        return shape1
+    if len(shape1) != len(shape2):
+        raise ValueError("Shapes must have the same rank.")
+    return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
 class ConstantFolder:
     opset_imports: dict[str, int]
@@ -723,7 +858,10 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                     if output.name in output_types:
                         inferred_type = output_types[output.name]
                         # TODO: merge types, check for conflicts
-                        output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type)
+                        inferred_shape = ir.serde.deserialize_type_proto_for_shape(
+                            inferred_type
+                        )
+                        output.shape = _merge_shapes(output.shape, inferred_shape)
                         output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
             except Exception as e:
                 logger.debug(
@@ -763,13 +901,8 @@ def new_constant(self, irvalue: ir.Value, value):
             value.shape,
         )

-        node = ir.Node(
-            "",
-            "Constant",
-            inputs=[],
-            attributes=ir.convenience.convert_attributes({"value": tensor}),
-            num_outputs=1,
-        )
+        attributes = ir.convenience.convert_attributes({"value": tensor})
+        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
         return node

     def process_node(self, node: ir.Node):
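A quick illustration of the _merge_shapes policy added above: at each position, the more informative dimension wins, with a concrete int preferred over any symbolic dim, and a named symbolic dim preferred over an unnamed one. This sketch assumes ir.Shape accepts None and str entries for unknown and named dims, as in the code above:

    from onnxscript import ir

    existing = ir.Shape([ir.SymbolicDim("B"), None, 256])
    inferred = ir.Shape([None, 128, 256])

    merged = _merge_shapes(existing, inferred)
    # dim 0: the named "B" is kept over the unnamed inferred dim
    # dim 1: the concrete 128 is kept over the unnamed existing dim
    # dim 2: both agree on 256
    assert merged == ir.Shape(["B", 128, 256])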
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index 8f2dc0026..b0df4dd54 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -486,6 +486,77 @@ def test_expand_identity(self):
         optimized = self._fold(model)
         self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

+    def test_expand_identity_symdim(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[B, 256] x) => (float[B, 256] z)
+            {
+                b = Shape <start=0, end=1> (x)
+                const_256 = Constant <value_ints=[256]> ()
+                shape = Concat <axis=0> (b, const_256)
+                z = Expand (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+    def test_abs_symdim(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[B, 256] x) => (float[B, 256] z)
+            {
+                b = Shape <start=0, end=1> (x)
+                const_256 = Constant <value_ints=[256]> ()
+                b_256 = Concat <axis=0> (b, const_256)
+                shape = Abs (b_256)
+                z = Expand (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+    def test_reshape_identity(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[128, 256] x) => (float[128, 256] z)
+            {
+                shape = Constant <value_ints=[128, 256]> ()
+                z = Reshape (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+    def test_reshape_identity_symdim(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
+            {
+                b = Shape <start=0, end=1> (y)
+                const_256 = Constant <value_ints=[256]> ()
+                shape = Concat <axis=0> (b, const_256)
+                z = Reshape (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+    def test_gather_symdim(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
+            {
+                b_128 = Shape (y)
+                index_0 = Constant <value_ints=[0]> ()
+                b = Gather <axis=0> (b_128, index_0)
+                const_256 = Constant <value_ints=[256]> ()
+                shape = Concat <axis=0> (b, const_256)
+                z = Reshape (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+

 if __name__ == "__main__":
     unittest.main()
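For reference, the rewrite chain the tests exercise through self._fold can also be reproduced end-to-end. The sketch below assumes the public onnxscript.optimizer.optimize entry point applies the same constant folding, and parses one of the textual models above with onnx.parser:

    import onnx.parser
    from onnxscript import optimizer

    model = onnx.parser.parse_model("""
        <ir_version: 7, opset_import: [ "" : 17 ]>
        agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
        {
            b = Shape <start=0, end=1> (y)
            const_256 = Constant <value_ints=[256]> ()
            shape = Concat <axis=0> (b, const_256)
            z = Reshape (x, shape)
        }
    """)
    optimized = optimizer.optimize(model)
    # x already has shape [B, 256], so the Reshape should be rewritten to an
    # Identity, as the tests above assert.
    print(optimized.graph.node[-1].op_type)  # expected: "Identity"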