diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 1af6223b1..a664b59ee 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -50,6 +50,7 @@ "serialize_tensor_into", "serialize_tensor", "serialize_type_into", + "serialize_type", "serialize_value_into", "serialize_value", "SerdeError", @@ -1511,6 +1512,12 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc raise TypeError(f"Unsupported type: {from_}") +def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto: + type_proto = onnx.TypeProto() + serialize_type_into(type_proto, from_=type_protocol) + return type_proto + + @_capture_errors(lambda type_proto, from_: repr(from_)) def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None: value_field = type_proto.WhichOneof("value") diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py new file mode 100644 index 000000000..6140b06f7 --- /dev/null +++ b/onnxscript/optimizer/_constant_folding.py @@ -0,0 +1,731 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files. + +from __future__ import annotations + +import dataclasses +import logging +import math +from typing import Any, Callable, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops + +import onnxscript.ir as ir +import onnxscript.ir._convenience as _convenience +import onnxscript.optimizer.constant_folding as constant_folding +import onnxscript.rewriter.pattern as orp + + +def is_control_flow_op(node: ir.Node) -> bool: + return any( + isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values() + ) + + +def is_non_deterministic_op(node: ir.Node) -> bool: + return ( + node.op_type in constant_folding.non_deterministic_ops + and constant_folding.is_onnx_domain(node.domain) + ) + + +def is_constant_op(node: ir.Node) -> bool: + return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( + node.domain + ) + + +_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT + +logger = logging.getLogger(__name__) + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + return evaluator(*args, **kwargs) + + +_reference_evaluator = ReferenceEvaluator() + + +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + + +class OptimizerState: + def __init__(self): + self._sym_value_map: dict[ir.Value, Any] = {} + + def get_sym_value(self, value: ir.Value | None) -> Any: + if value is None: + return None + return self._sym_value_map.get(value) + + def set_sym_value(self, value: ir.Value, sym_value: Any) -> None: + self._sym_value_map[value] = sym_value + + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + +# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns +# a Replacement for the node or None (if no replacement is needed). It may also return just +# the ir.Value or ir.Values to replace the output values of the node, when the new nodes +# can be inferred from the RewriterContext used to build the new nodes. + +ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] +PartialEvaluatorFunction = Callable[ + [ir.Node, orp.RewriterContext, OptimizerState], ReturnValue +] + + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. + """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) + ] + + def register( + self, opname: str, domain: str = "", version=None + ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]: + if (domain, opname) in self.op_evaluators: + evaluator_list = self.op_evaluators[(domain, opname)] + else: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + + +def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: + if val is None: + return None + const_value = val.const_value + if const_value is not None: + return const_value.numpy() + return None + + +def _get_bool_value(val: ir.Value | None) -> bool | None: + if val is None: + return None + value = _get_numpy_value(val) + if value is None: + return None + # TODO: cleanup following checks, which seem redundant. But need to also ensure + # the invariant when setting the value (and also use clearly defined representation + # types in evaluators, such a reference-evaluator). + if isinstance(value, bool): + return value + if isinstance(value, np.bool_): + return bool(value) + if isinstance(value, np.ndarray) and value.size == 1 and value.dtype == bool: + return value.item(0) + return None + + +def _get_input(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + + +def _get_output(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.outputs): + return node.outputs[index] + return None + + +def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None: + if type is not None: + # TODO: merge types + value.type = type + + +def _get_input_element_type(node: ir.Node, index: int) -> int: + input = _get_input(node, index) + if input is not None and input.type is not None: + return input.type.dtype.value + return ir.DataType.UNDEFINED.value + + +def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, int): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +# TODO(rama): The following should not be necessary. Generic incremental shape-inference +# should handle this. This essentially implements type/shape-inference for Cast op. +@register("Cast") +def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = _get_input(node, 0) + output = _get_output(node, 0) + if input is not None and output is not None: + _update_type(output, input.type) + return None + + +@register("CastLike") +def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input0 = node.inputs[0] + source_element_type = _get_input_element_type(node, 0) + target_element_type = _get_input_element_type(node, 1) + + if target_element_type == ir.DataType.UNDEFINED: + return None + if source_element_type == target_element_type: + return op.Identity(input0) + return op.Cast(input0, to=target_element_type) + + +@register("Shape") +def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + if input is None: + return None + shape = input.shape + if shape is None: + return None + start = _get_int_attribute(node, "start", 0) + end = _get_int_attribute(node, "end", None) + shape_slice = shape[start:end] + if all(isinstance(d, int) for d in shape_slice): + return op.Constant(value_ints=list(shape_slice)) + return None + + +@register("Size") +def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = _get_input(node, 0) + if input is None: + return None + shape = input.shape + if shape is None: + return None + size = 1 + for d in shape: + if not isinstance(d, int): + return None + size *= d + return op.Constant(value_int=size) + + +@register("If") +def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + cond_input = _get_input(node, 0) + cond = _get_bool_value(cond_input) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph_attr = node.attributes.get(branch, None) + if not isinstance(graph_attr, ir.AttrGraph): + return None + graph: ir.Graph = graph_attr.value + formal_outs = graph.outputs + actual_outs = node.outputs + renamings = { + formal.name: actual.name + for formal, actual in zip(formal_outs, actual_outs) + if actual is not None + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + graph_nodes = list(graph) + graph.remove(graph_nodes) + for sub_node in graph_nodes: + # TODO: handle renaming inside subgraphs in nodes + for v in sub_node.outputs: + v.name = rename(v.name) + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! + return Replacement(formal_outs, graph_nodes) + return None + + +@register("Identity") +def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + del op + input = node.inputs[0] + output = node.outputs[0] + if input is not None and output is not None: + state.set_sym_value(output, input) + return None + + +@register("SequenceConstruct") +def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + del op + output = node.outputs[0] + if output is not None: + state.set_sym_value(output, list(node.inputs)) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + inputs = state.get_sym_value(input) + if any(x is None for x in inputs): + return None + new_axis = _get_int_attribute(node, "new_axis", 0) + axis = _get_int_attribute(node, "axis", None) + if axis is None: + return None + if input is not None and isinstance(inputs, list): + if new_axis == 0: + logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) + return op.Concat(*inputs, axis=axis) + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis_value = op.Constant(value_int=axis) + unsqueezed_inputs = [] + for node_input in inputs: + unsqueezed_input = op.Unsqueeze( + node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"] + ) + unsqueezed_inputs.append(unsqueezed_input) + # Send unsqueezed outputs to Concat + logger.debug( + "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] + ) + return op.Concat(*unsqueezed_inputs, axis=axis) + return None + + +@register("SplitToSequence") +def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = node.inputs[0] + split = node.inputs[1] + output = node.outputs[0] + + if input is None or split is None or output is None: + return None + + axis = _get_int_attribute(node, "axis", 0) + if axis is None: + return None + shape = input.shape + if shape is None: + return None + rank = len(shape) + if axis < 0: + axis = axis + rank + if axis < 0 or axis >= rank: + return None + split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + + split_value = _get_numpy_value(split) + if split_value is None: + return None + assert isinstance(split_value, np.ndarray) + + if split_value.ndim == 0: + # split into chunks all of size 'split' if possible. + num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, outputs=split_outputs + ) + elif split_value.ndim == 1: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, split, axis=axis, outputs=split_outputs) + else: + return None + + keepdims = _get_int_attribute(node, "keepdims", 1) + if keepdims is None: + return None + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) + squeezed_values = [] + for i in range(num_outputs): + squeezed = op.Squeeze( + split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"] + ) + squeezed_values.append(squeezed) + split_values = squeezed_values + + logger.debug("SplitToSequence => Split + SequenceConstruct") + + if isinstance(split_values, ir.Value): + split_values = [split_values] + return op.SequenceConstruct(*split_values) + + +@register("SequenceAt") +def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + position = node.inputs[1] + output = node.outputs[0] + if input is not None and position is not None: + input_vals = state.get_sym_value(input) + position_val = _get_numpy_value(position) + if isinstance(input_vals, list) and position_val is not None: + if position_val.size != 1: + return None + position_val = position_val.item() + try: + result = input_vals[position_val] # type: ignore[index] + except IndexError: + return None + state.set_sym_value(output, result) + logger.debug("SequenceAt %s => %s", input.name, result.name) + return op.Identity(result) + return None + + +class ConstantFolder: + opset_imports: dict[str, int] + + def __init__( + self, + external_data_folder: str, + do_shape_inference: bool, + ) -> None: + self._external_data_folder = external_data_folder + self._do_shape_inference = do_shape_inference + self._init() + + def _init(self) -> None: + self.counts: dict[str, int] = {} + self.sizes: dict[str, int] = {} + self.modified = False + self._state = OptimizerState() + + def _do_inference(self, node: ir.Node) -> None: + output_types = {} + + # TODO: handle optional inputs + def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: + value = _get_numpy_value(x) + if isinstance(value, np.ndarray) and value.size < 20: + return onnx.numpy_helper.from_array(value, x.name) + return None + + def get_type(value: ir.Value) -> onnx.TypeProto | None: + if value.type is not None: + type_proto = ir.serde.serialize_type(value.type) + if value.shape is not None: + ir.serde.serialize_shape_into(type_proto, value.shape) + return type_proto + return None + + input_types = {x.name: get_type(x) for x in node.inputs if x is not None} + input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %s due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self.opset_imports[node.domain], node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, + ir.serde.serialize_node(node), + input_types, # type: ignore[arg-type] + input_data, # type: ignore[arg-type] + ) + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: merge types, check for conflicts + output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) + output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) + except Exception as e: + logger.debug( + "Skipping shape inference for node %s due to exception: %s", + node.name, + e, + ) + + def new_constant(self, irvalue: ir.Value, value): + # TODO(rama): Why do we need the conversion below? + if isinstance(value, (int, float, np.ScalarType)): + value = np.array(value) + + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. + logger.info( + "Skip storing constant folded value %s due to unsupported type %s.", + irvalue.name, + type(value), + ) + return None + + irvalue.const_value = _convenience.tensor(value) + + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + logger.info( + "Skip storing constant folded nvalue %s due to large size %s.", + irvalue.name, + value.nbytes, + ) + return None + + tensor = onnx.numpy_helper.from_array(value, irvalue.name) + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + irvalue.name, + value.dtype, + value.shape, + ) + + attributes = _convenience.convert_attributes({"value": tensor}) + node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) + return node + + def process_node(self, node: ir.Node): + for i, value in enumerate(node.inputs): + sym_value = self._state.get_sym_value(value) + if isinstance(sym_value, ir.Value): + node.replace_input_with(i, sym_value) + # TODO(rama): consider merging type/other info from both values + + # Do incremental shape inference + if self._do_shape_inference and not is_control_flow_op(node): + self._do_inference(node) + + if node.domain not in self.opset_imports: + return None + version = self.opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = orp.RewriterContext() + output = optimizer(node, context, self._state) + if output is not None: + if isinstance(output, Replacement): + return output + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + + if is_control_flow_op(node) or is_non_deterministic_op(node): + return None + + input_values = [_get_numpy_value(x) for x in node.inputs] + if any(x is None for x in input_values): + return None + + # Filter out bfloat16 cases? + def convert(av): + if isinstance(av, ir.AttrTensor): + return ir.serde.serialize_tensor(av.value) + return av.value + + attr_values = {name: convert(attr) for name, attr in node.attributes.items()} + outputs = _reference_evaluator.evaluate( + node.domain, node.op_type, version, *input_values, **attr_values + ) + + if outputs is None: + return None + if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node.outputs[0], outputs) + if is_constant_op(node) or replacement is None: + return None + return Replacement(replacement.outputs, [replacement]) + else: + logger.warning( + "Skipping constant folding for op %s with multiple outputs.", node.op_type + ) + return None + + def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): + logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + + # TODO: what about new opset_imports? + old_values = node.outputs + new_values = replacement.new_outputs + for old_value, new_value in zip(old_values, new_values): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps we should merge old and new types. As of now, new + # values don't have type information. Note that this could be a problem + # for semantics-altering rewrite-rules: we should allow users to override + # this for such rules. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + + # Reconnect the users of the deleted node to use the new outputs + _convenience.replace_all_uses_with(old_values, new_values) + # Update graph/function outputs if the node generates output + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(root.outputs): + if graph_or_function_output in replacement_mapping: + root.outputs[idx] = replacement_mapping[graph_or_function_output] + + # insert new nodes after the index node + root.insert_after(node, replacement.new_nodes) + root.remove(node, safe=True) + + # TODO: track statistics about replaced nodes and sizes of new constants + + def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: + if isinstance(attr, ir.Attr): + if attr.type == ir.AttributeType.GRAPH: + self.visit_graph(attr.value) # type: ignore[arg-type] + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + self.visit_graph(graph) # type: ignore[arg-type] + + def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): + replacement = self.process_node(node) + if replacement is None: + # No change. Process attributes. + for attr in node.attributes.values(): + self.visit_attribute(attr) + return None + else: + self.replace_node(node, replacement, root) + + def visit_graph(self, graph: ir.Graph) -> None: + for node in graph: + self.visit_node(node, graph) + + def visit_model(self, model: ir.Model) -> None: + self._init() + self.opset_imports = model.opset_imports + self.visit_graph(model.graph) + # TODO(rama): handle functions + # Pending decision on whether we want to specialize functions or not. + + +def fold_constants( + model: ir.Model, + external_data_folder: str = "", + *, + onnx_shape_inference: bool = False, +) -> bool: + """ + Applies constant folding optimization to the model. + Returns true iff the model was modified. + """ + folder = ConstantFolder( + external_data_folder, + onnx_shape_inference, + ) + folder.visit_model(model) + for op in folder.counts: + logger.info( + "Constant-folded '%s' %s times, with %s size.", + op, + folder.counts[op], + folder.sizes[op], + ) + return folder.modified diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py index 8fc7fe4a0..7629653d4 100644 --- a/onnxscript/optimizer/constant_folding_test.py +++ b/onnxscript/optimizer/constant_folding_test.py @@ -3,12 +3,29 @@ import unittest import onnx +import parameterized import pytest -from onnxscript import optimizer +import onnxscript.optimizer as optimizer +from onnxscript.ir import serde +from onnxscript.optimizer import _constant_folding, constant_folding +@parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) class FoldConstantsTest(unittest.TestCase): + def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False): + if self.using_ir: + ir_model = serde.deserialize_model(model) + _constant_folding.fold_constants( + ir_model, onnx_shape_inference=onnx_shape_inference + ) + optimizer.remove_unused_nodes(ir_model) + return serde.serialize_model(ir_model) + else: + constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference) + optimizer.remove_unused_nodes(model) + return model + def test_fold_add(self): model = onnx.parser.parse_model( """ @@ -20,7 +37,7 @@ def test_fold_add(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -36,7 +53,7 @@ def test_fold_cast_like(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -53,7 +70,7 @@ def test_fold_shape(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -70,7 +87,7 @@ def test_fold_shape_slice(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -91,7 +108,7 @@ def test_fold_if_cond(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].output[0], "z") self.assertEqual(optimized.graph.node[0].op_type, "Mul") @@ -117,7 +134,7 @@ def test_fold_inside_if_branch(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 1) then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch") self.assertEqual(len(then_graph.node), 2) @@ -144,7 +161,7 @@ def test_fold_if_propagate(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) print(onnx.printer.to_text(optimized)) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "m_square") @@ -161,7 +178,7 @@ def test_fold_redundant_cast(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) self.assertEqual(len(optimized.graph.node), 2) def test_fold_redundant_cast2(self): @@ -174,7 +191,7 @@ def test_fold_redundant_cast2(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].op_type, "Identity") self.assertEqual(optimized.graph.node[0].output[0], "z") @@ -196,7 +213,7 @@ def test_fold_undefined_vars(self): """ ) # No optimizations expected. Just make sure it doesn't crash. - optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + optimized = self._fold(model, onnx_shape_inference=False) self.assertEqual(len(optimized.graph.node), 6) def test_shape_inference(self): @@ -222,7 +239,7 @@ def test_shape_inference(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) print(onnx.printer.to_text(optimized)) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "C") @@ -274,7 +291,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_ split_3 = SequenceAt (splits, int64_3) } """ - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(len(optimized.graph.node[-2].output), 4) self.assertEqual(optimized.graph.node[-2].op_type, "Split") @@ -301,7 +318,7 @@ def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_sp } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 3) self.assertEqual(len(optimized.graph.node[-2].output), 3) self.assertEqual(optimized.graph.node[-2].op_type, "Split") @@ -328,77 +345,12 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_ } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 7) self.assertEqual(len(optimized.graph.node[1].output), 3) self.assertEqual(optimized.graph.node[1].op_type, "Split") self.assertEqual(len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3) - def test_static_split_to_sequence_with_uneven_split(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], - producer_name: "pytorch", - producer_version: "2.2.0" -> -main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) - < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> -{ - _val_1 = Constant () - _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) - _val_3 = Constant () - getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) - _val_5 = Constant () - getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) - return_val = Concat (getitem_1, getitem) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_split (self, split_size) => (return_val) -{ - return_val = SplitToSequence (self, split_size) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_getitem (self, i) => (return_val) -{ - return_val = SequenceAt (self, i) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -Rank (input) => (return_val) -{ - tmp = Shape (input) - return_val = Size (tmp) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -IsScalar (input) => (return_val) -{ - tmp = Shape (input) - tmp_0 = Size (tmp) - tmp_1 = Constant () - return_val = Equal (tmp_0, tmp_1) -} - """ - ) - optimized = optimizer.optimize(model, onnx_shape_inference=False) - - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[0].output), 2) - self.assertEqual(optimized.graph.node[0].op_type, "Split") - def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( self, ): @@ -408,14 +360,14 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( ir_version: 8, opset_import: ["" : 18] > -func (float[1,3] x) => ( return_val) { +func (float[1,3] x) => (float[1,3] return_val) { const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 3) self.assertEqual(optimized.graph.node[2].op_type, "Concat") onnx.checker.check_model(optimized) @@ -429,14 +381,14 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( ir_version: 8, opset_import: ["" : 18] > -func (float[1,3] x) => ( return_val) { +func (float[1,3] x) => (float[1,3] return_val) { const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 7) self.assertEqual(optimized.graph.node[6].op_type, "Concat") onnx.checker.check_model(optimized) diff --git a/onnxscript/optimizer/optimizer_test.py b/onnxscript/optimizer/optimizer_test.py new file mode 100644 index 000000000..57f6f3a80 --- /dev/null +++ b/onnxscript/optimizer/optimizer_test.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx + +import onnxscript.optimizer as optimizer + + +class OptimizerTest(unittest.TestCase): + def test_static_split_to_sequence_with_uneven_split(self): + model = onnx.parser.parse_model( + """ + < + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], + producer_name: "pytorch", + producer_version: "2.2.0" + > + main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) + < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> + { + _val_1 = Constant () + _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) + _val_3 = Constant () + getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) + _val_5 = Constant () + getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) + return_val = Concat (getitem_1, getitem) + } + + + aten_split (self, split_size) => (return_val) + { + return_val = SplitToSequence (self, split_size) + } + + + aten_getitem (self, i) => (return_val) + { + return_val = SequenceAt (self, i) + } + + + Rank (input) => (return_val) + { + tmp = Shape (input) + return_val = Size (tmp) + } + + + IsScalar (input) => (return_val) + { + tmp = Shape (input) + tmp_0 = Size (tmp) + tmp_1 = Constant () + return_val = Equal (tmp_0, tmp_1) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + + +if __name__ == "__main__": + unittest.main()