From 4437044d8e3364bd60df2fd8ca90113306cb6cf8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 20 Jun 2024 17:16:40 -0700 Subject: [PATCH 01/22] First version of updated partial evaluators --- onnxscript/optimizer/evaluator_ir.py | 416 +++++++++++++++++++++++++++ 1 file changed, 416 insertions(+) create mode 100644 onnxscript/optimizer/evaluator_ir.py diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py new file mode 100644 index 000000000..63676b830 --- /dev/null +++ b/onnxscript/optimizer/evaluator_ir.py @@ -0,0 +1,416 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- + +from __future__ import annotations + +import dataclasses +import logging +import math +from typing import Any, Callable, Protocol, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops + +import onnxscript.ir as ir +from onnxscript.utils.utils import ( + get_node_attr_value, +) + +logger = logging.getLogger(__name__) + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + return evaluator(*args, **kwargs) + + +reference_evaluator = ReferenceEvaluator() + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + + +# class IRContext(Protocol): +# """A class that represents the context for partial evaluation. + +# This is a placeholder, subject to simplification when a proper IR is defined. +# """ + +# def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + +# def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + +# def input_const_value(self, node: onnx.NodeProto, index: int) -> ir.ConcreteValue: ... + +# def input_shape( +# self, node: onnx.NodeProto, index: int +# ) -> onnx.TensorShapeProto | None: ... + +# def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... + +# def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... + +# def lookup_version(self, domain: str) -> int: ... + +# def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... + +# def new_constant(self, name: str, value: Any) -> Sequence[onnx.NodeProto] | None: ... + + +# A partial-evaluator function takes an IRContext and a node, and returns a list of +# replacement nodes or None (if no replacement is needed). We return None instead +# of [input node] so the caller is aware that the node is not replaced. If the node +# is replaced, the caller will recursively visit the replacement nodes to process them. 
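+#
+# For illustration only (a minimal sketch, not part of this module): a caller is
+# expected to dispatch registered evaluators roughly as follows, splicing any
+# returned nodes into the graph in place of `node` and then visiting them again:
+#
+#     for evaluate in registry.lookup_evaluators(node.domain, node.op_type, version):
+#         replacement = evaluate(node)
+#         if replacement is not None:
+#             ...  # splice `replacement` in place of `node`, then recurse
+#             break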
+ +PartialEvaluatorFunction = Callable[[ir.Node], Union[Sequence[ir.Node], None]] + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. + """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) + ] + + def register(self, opname: str, domain: str = "", version=None): + if (domain, opname) not in self.op_evaluators: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + else: + evaluator_list = self.op_evaluators[(domain, opname)] + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + +def get_numpy_value(val: ir.Value) -> np.ndarray | None: + const_value = val.const_value + if hasattr(const_value, "numpy"): + return const_value.numpy() + return None + +def get_bool_value(val) -> bool | None: + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): + size = 1 + for d in type.tensor_type.shape.dim: + size *= d.dim_value + return np.array(size, dtype=np.int64) + return None + + +def get_dim_info(type: ir.Type, dim: int) -> int | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + rank = len(type.tensor_type.shape.dim) + dim = dim if dim >= 0 else dim + rank + if dim < 0 or dim >= rank: + return None + if type.tensor_type.shape.dim[dim].HasField("dim_value"): + return type.tensor_type.shape.dim[dim].dim_value + return None + +def getInput(node:ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + + +@register("Cast") +def cast(node: ir.Node) -> Sequence[ir.Node] | None: + if context.input_shape(node, 0) is not None: + output_value = context.get_output(node, 0) + output_value.type = onnx.TypeProto() + output_value.type.CopyFrom(context.input_type(node, 0)) + 
output_value.type.tensor_type.elem_type = node.attribute[0].i + return None + + +@register("CastLike") +def cast_like(op, node: ir.Node): + input0 = node.inputs[0] + input1 = node.inputs[1] + source_element_type = input0.type.dtype.value + target_element_type = input1.type.dtype.value + + if target_element_type is None: + return None + if source_element_type == target_element_type: + return op.Identity(input0) + return op.Cast(input0, to=target_element_type) + + +@register("Shape") +def shape(op, node: ir.Node): + del op + input = node.inputs[0] + shape = input.shape + if shape is None: + return None + start = node.attributes.get("start", 0) + end = node.attributes.get("end", None) + shape_slice = shape.dim[start:end] + if all(d.HasField("dim_value") for d in shape_slice): + return np.array([d.dim_value for d in shape_slice], dtype=np.int64) + return None + + +@register("Size") +def size(op, node: ir.Node): + del op + shape = node.inputs[0].shape + if shape is None: + return None + size = 1 + for d in shape: + if not isinstance(d, int): + return None + size *= d + return np.array(size, dtype=np.int64) + +@register("If") +def if_op(context: IRContext, node: onnx.NodeProto): + cond = context.input_const_value(node, 0) + if cond is ir.NotConstant: + # Visitor will recursively visit subgraphs to constant-fold them. + return None + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = onnx.helper.get_node_attr_value(node, branch) + + formal_outs = list(graph.output) + actual_outs = node.output + renamings = { + formal.name: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + for sub_node in graph.node: + # TODO: handle renaming inside subgraphs in nodes + sub_node.input[:] = [rename(name) for name in sub_node.input] + sub_node.output[:] = [rename(name) for name in sub_node.output] + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! 
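+        # Returning the (renamed) branch nodes lets the caller splice them into the
+        # graph in place of the If node; they are then revisited for further folding.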
+ return list(graph.node) + return None + + +@register("Identity") +def identity(op, node: ir.Node): + del op + input = node.inputs[0] + output = node.outputs[0] + if input is not None and output is not None: + output.symbolic_value = input + return None + + +@register("SequenceConstruct") +def sequence_construct(op, node: ir.Node): + del op + output = node.outputs[0] + if output is not None: + output.symbolic_value = list(node.inputs) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence(op, node: ir.Node): + input = node.inputs[0] + inputs = input.symbolic_value + if any(x is None for x in inputs): + return None + new_axis = node.attributes.get("new_axis", 0) + axis = node.attributes["axis"] + if input is not None and isinstance(input.symbolic_value, list): + if new_axis == 0: + logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) + return op.Concat(*inputs, axis=axis) + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis_value = op.Constant(value_int=axis) + unsqueezed_inputs = [] + for node_input in inputs: + unsqueezed_input = op.Unsqueeze(node_input, axis_value, output=[f"{node_input.name}_unsqueeze"]) + unsqueezed_inputs.append(unsqueezed_input) + # Send unsqueezed outputs to Concat + logger.debug( + "ConcatFromSequence => Concat %s", + [x.name for x in unsqueezed_inputs] + ) + return op.Concat(*unsqueezed_inputs, axis=axis) + return None + + +@register("SplitToSequence") +def split_to_sequence(op, node: ir.Node): + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = node.inputs[0] + split = node.inputs[1] + output = node.outputs[0] + + if input is None or split is None or output is None: + return None + + axis = node.attributes.get("axis", 0) + shape = input.shape + if shape is None: + return None + rank = len(shape) + if axis < 0: + axis = axis + rank + if axis < 0 or axis >= rank: + return None + split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + + split_value = get_numpy_value(split) + if split_value is None: + return None + assert isinstance(split_value, np.ndarray) + + if split_value.ndim == 0: + # split into chunks all of size 'split' if possible. 
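+        # For example, split_dimension_size = 7 and split = 2 gives
+        # num_outputs = ceil(7 / 2) = 4, i.e. chunks of sizes 2, 2, 2 and 1
+        # (the last chunk may be smaller when the split is uneven).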
+ num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, axis=axis, num_outputs=num_outputs, output=split_outputs) + elif split_value.ndim == 1: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, split, axis=axis, output=split_outputs) + else: + return None + + keepdims = node.attributes.get("keepdims", 1) + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) + squeezed_values = [] + for i in range(num_outputs): + squeezed = op.Squeeze(split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"]) + squeezed_values.append(squeezed) + split_values = squeezed_values + + logger.debug("SplitToSequence => Split + SequenceConstruct") + + return op.SequenceConstruct(*split_values) + # return [split_node, *squeeze_nodes, node] + + +@register("SequenceAt") +def sequence_at(op, node: ir.Node): + input = node.inputs[0] + position = node.inputs[1] + output = node.outputs[0] + if input is not None and position is not None: + input_vals = input.symbolic_value + position_val = get_numpy_value(position) + if isinstance(input_vals, list) and position_val is not None: + if position_val.size != 1: + return None + position_val = position_val.item() + try: + result = input_vals[position_val] + except IndexError: + return None + output.symbolic_value = result + logger.debug("SequenceAt %s => %s", input.name, result.name) + return op.Identity(result) + return None From 32ff7155f3ded6b2140355d98e17ea2d6c6e3291 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 21 Jun 2024 17:46:25 -0700 Subject: [PATCH 02/22] Migrate constant-folder to new IR --- onnxscript/optimizer/evaluator_ir.py | 220 ++++++++++++++++++++++----- 1 file changed, 179 insertions(+), 41 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 63676b830..416286f56 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -15,10 +15,18 @@ import onnx.reference.ops import onnxscript.ir as ir +import onnxscript.ir._convenience as _convenience +import onnxscript.optimizer.constant_folding as constant_folding +import onnxscript.rewriter.pattern from onnxscript.utils.utils import ( get_node_attr_value, ) +is_control_flow_op = constant_folding.is_control_flow_op +is_non_deterministic_op = constant_folding.is_non_deterministic_op +is_constant_op = constant_folding.is_constant_op +_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT + logger = logging.getLogger(__name__) # "Standard" evaluators are used to perform constant-folding. 
@@ -147,7 +155,12 @@ def get_numpy_value(val: ir.Value) -> np.ndarray | None: return const_value.numpy() return None -def get_bool_value(val) -> bool | None: +def get_bool_value(val: ir.Value | None) -> bool | None: + if val is None: + return None + val = get_numpy_value(val) + if val is None: + return None if isinstance(val, bool): return val if isinstance(val, np.bool_): @@ -157,39 +170,27 @@ def get_bool_value(val) -> bool | None: return None -def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): - size = 1 - for d in type.tensor_type.shape.dim: - size *= d.dim_value - return np.array(size, dtype=np.int64) - return None - - -def get_dim_info(type: ir.Type, dim: int) -> int | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - rank = len(type.tensor_type.shape.dim) - dim = dim if dim >= 0 else dim + rank - if dim < 0 or dim >= rank: - return None - if type.tensor_type.shape.dim[dim].HasField("dim_value"): - return type.tensor_type.shape.dim[dim].dim_value - return None - def getInput(node:ir.Node, index: int) -> ir.Value | None: if index < len(node.inputs): return node.inputs[index] return None +def getOutput(node:ir.Node, index: int) -> ir.Value | None: + if index < len(node.outputs): + return node.outputs[index] + return None + +def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: + # TODO: merge types + value.type = type @register("Cast") def cast(node: ir.Node) -> Sequence[ir.Node] | None: - if context.input_shape(node, 0) is not None: - output_value = context.get_output(node, 0) - output_value.type = onnx.TypeProto() - output_value.type.CopyFrom(context.input_type(node, 0)) - output_value.type.tensor_type.elem_type = node.attribute[0].i + # This should not be necessary. Generic incremental shape-inference should handle this. + input = getInput(node, 0) + output = getOutput(node, 0) + if input is not None and output is not None: + updateType(output, input.type) return None @@ -236,38 +237,36 @@ def size(op, node: ir.Node): return np.array(size, dtype=np.int64) @register("If") -def if_op(context: IRContext, node: onnx.NodeProto): - cond = context.input_const_value(node, 0) - if cond is ir.NotConstant: - # Visitor will recursively visit subgraphs to constant-fold them. - return None +def if_op(op, node: ir.Node): + cond = getInput(node, 0) cond = get_bool_value(cond) if cond is not None: # cond is a constant-value: inline the branch branch = "then_branch" if cond else "else_branch" - graph = onnx.helper.get_node_attr_value(node, branch) - - formal_outs = list(graph.output) - actual_outs = node.output + graph = node.attributes.get(branch, None) + if graph is None: + return None + formal_outs = graph.outputs + actual_outs = node.outputs renamings = { - formal.name: actual + formal.name: actual.name for formal, actual in zip(formal_outs, actual_outs) - if actual != "" + if actual is not None } # TODO: Extend renaming to intermediate values. def rename(name): return renamings.get(name, name) - for sub_node in graph.node: + for sub_node in graph: # TODO: handle renaming inside subgraphs in nodes - sub_node.input[:] = [rename(name) for name in sub_node.input] - sub_node.output[:] = [rename(name) for name in sub_node.output] + for v in sub_node.outputs: + v.name = rename(v.name) # Avoid name collision. sub_node.name = f"{node.name}_{sub_node.name}" # TODO: we should handle initializers as well! 
- return list(graph.node) + return list(graph) return None @@ -414,3 +413,142 @@ def sequence_at(op, node: ir.Node): logger.debug("SequenceAt %s => %s", input.name, result.name) return op.Identity(result) return None + + +class ConstantFolder: + opset_imports: dict[str, int] + + def new_constant(self, irvalue: ir.Value, value): + # TODO(rama): Why do we need the conversion below? + if isinstance(value, (int, float, np.ScalarType)): + value = np.array(value) + + irvalue.const_value = value + + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. + logger.info( + "Skip storing constant folded value %s due to unsupported type %s.", + irvalue.name, + type(value), + ) + return None + + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + logger.info( + "Skip storing constant folded nvalue %s due to large size %s.", + irvalue.name, + value.nbytes, + ) + return None + + tensor = onnx.numpy_helper.from_array(value, name) + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + irvalue.name, + value.dtype, + value.shape, + ) + + # TODO(rama) + # irvalue.type = onnx.helper.make_tensor_type_proto( + # onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape + # ) + attributes = _convenience.convert_attributes({"value": tensor}) + node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) + return [node] + + def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): + for i, value in enumerate(node.inputs): + if value is not None and value.symbolic_value is not None: + sym_value = value.symbolic_value + if isinstance(sym_value, ir.Value): + node.replace_input_with(i, sym_value) + # TODO(rama): consider merging type/other info from both values + + # Do incremental shape inference + + if node.domain not in self.opset_imports: + return None + version = self.opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = onnxscript.rewriter.pattern.RewriterContext() + output = optimizer(context, node) + if output is not None: + return output + + if is_control_flow_op(node) or is_non_deterministic_op(node): + return None + + if any((x is not None and x.const_value is None) for x in node.inputs): + return None + + input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] + # Filter out bfloat16 cases? + outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **node.attributes) + if outputs is None: + return None + if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node.outputs[0], outputs) + if is_constant_op(node): + return None + # self.add_count(op, outputs.size) + return replacement + else: + logger.warning("Skipping constant folding for op %s with multiple outputs.", node.op_type) + return None + + def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): + # TODO: apply delta! what about opset_imports? + + for old_value, new_value in zip(old_values, new_values): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps we should merge old and new types. As of now, new + # values don't have type information. 
Note that this could be a problem + # for semantics-altering rewrite-rules: we should allow users to override + # this for such rules. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + + # Reconnect the users of the deleted node to use the new outputs + _convenience.replace_all_uses_with(old_values, new_values) + # Update graph/function outputs if the node generates output + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(root.outputs): + if graph_or_function_output in replacement_mapping: + root.outputs[idx] = replacement_mapping[graph_or_function_output] + + # insert new nodes after the index node + root.insert_after(node, delta.new_nodes) + root.remove(node, safe=True) + + # if isinstance(output, list): + # return output + # else: + # # Currently handles single output only + # self.add_count(node.op_type, output.size) + # return self.new_constant(node.output[0], output) + + def visit_node(self, node: ir.Node): + replacement = self.process_node(node) + # logger.debug( + # "visit_node: %s::%s %s replacement %s", + # node.domain, + # node.op_type, + # node.name, + # "found" if replacement is not None else "missed", + # ) + if replacement is None: + # No change. Process attributes. + for attr in node.attribute: + self.visit_attribute(attr) + return None + else: + self.replace_node(node, replacement) From 3a374033d3a9cbe7046533119c579463227046f6 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 3 Jul 2024 16:45:17 -0700 Subject: [PATCH 03/22] Some cleanup --- onnxscript/optimizer/evaluator_ir.py | 72 +++++++++------------------- 1 file changed, 22 insertions(+), 50 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 416286f56..90b13b282 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -17,7 +17,7 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience import onnxscript.optimizer.constant_folding as constant_folding -import onnxscript.rewriter.pattern +import onnxscript.rewriter.pattern as orp from onnxscript.utils.utils import ( get_node_attr_value, ) @@ -56,40 +56,11 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). +# A partial-evaluator function takes an RewriterContext and a node, and returns the ir.Value +# or ir.Values to replace the output values of the node or None (if no replacement is needed). -# class IRContext(Protocol): -# """A class that represents the context for partial evaluation. - -# This is a placeholder, subject to simplification when a proper IR is defined. -# """ - -# def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - -# def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - -# def input_const_value(self, node: onnx.NodeProto, index: int) -> ir.ConcreteValue: ... - -# def input_shape( -# self, node: onnx.NodeProto, index: int -# ) -> onnx.TensorShapeProto | None: ... - -# def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... - -# def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... - -# def lookup_version(self, domain: str) -> int: ... 
- -# def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... - -# def new_constant(self, name: str, value: Any) -> Sequence[onnx.NodeProto] | None: ... - - -# A partial-evaluator function takes an IRContext and a node, and returns a list of -# replacement nodes or None (if no replacement is needed). We return None instead -# of [input node] so the caller is aware that the node is not replaced. If the node -# is replaced, the caller will recursively visit the replacement nodes to process them. - -PartialEvaluatorFunction = Callable[[ir.Node], Union[Sequence[ir.Node], None]] +ReturnValue = Union[Sequence[ir.Value], ir.Value, None] +PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] @dataclasses.dataclass class PartialEvaluator: @@ -184,9 +155,10 @@ def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: # TODO: merge types value.type = type +# TODO(rama): The following should not be necessary. Generic incremental shape-inference +# should handle this. This essentially implements type/shape-inference for Cast op. @register("Cast") -def cast(node: ir.Node) -> Sequence[ir.Node] | None: - # This should not be necessary. Generic incremental shape-inference should handle this. +def cast(node: ir.Node) -> ReturnValue: input = getInput(node, 0) output = getOutput(node, 0) if input is not None and output is not None: @@ -195,7 +167,7 @@ def cast(node: ir.Node) -> Sequence[ir.Node] | None: @register("CastLike") -def cast_like(op, node: ir.Node): +def cast_like(op, node: ir.Node) -> ReturnValue: input0 = node.inputs[0] input1 = node.inputs[1] source_element_type = input0.type.dtype.value @@ -209,7 +181,7 @@ def cast_like(op, node: ir.Node): @register("Shape") -def shape(op, node: ir.Node): +def shape(op, node: ir.Node) -> ReturnValue: del op input = node.inputs[0] shape = input.shape @@ -219,12 +191,12 @@ def shape(op, node: ir.Node): end = node.attributes.get("end", None) shape_slice = shape.dim[start:end] if all(d.HasField("dim_value") for d in shape_slice): - return np.array([d.dim_value for d in shape_slice], dtype=np.int64) + return op.Constant(value_ints = [d.dim_value for d in shape_slice]) return None @register("Size") -def size(op, node: ir.Node): +def size(op, node: ir.Node) -> ReturnValue: del op shape = node.inputs[0].shape if shape is None: @@ -234,10 +206,10 @@ def size(op, node: ir.Node): if not isinstance(d, int): return None size *= d - return np.array(size, dtype=np.int64) + return op.Constant(value_int = size) @register("If") -def if_op(op, node: ir.Node): +def if_op(op, node: ir.Node) -> ReturnValue: cond = getInput(node, 0) cond = get_bool_value(cond) if cond is not None: @@ -266,12 +238,12 @@ def rename(name): sub_node.name = f"{node.name}_{sub_node.name}" # TODO: we should handle initializers as well! 
- return list(graph) + return formal_outs return None @register("Identity") -def identity(op, node: ir.Node): +def identity(op, node: ir.Node) -> ReturnValue: del op input = node.inputs[0] output = node.outputs[0] @@ -281,7 +253,7 @@ def identity(op, node: ir.Node): @register("SequenceConstruct") -def sequence_construct(op, node: ir.Node): +def sequence_construct(op, node: ir.Node) -> ReturnValue: del op output = node.outputs[0] if output is not None: @@ -290,7 +262,7 @@ def sequence_construct(op, node: ir.Node): @register("ConcatFromSequence") -def concat_from_sequence(op, node: ir.Node): +def concat_from_sequence(op, node: ir.Node) -> ReturnValue: input = node.inputs[0] inputs = input.symbolic_value if any(x is None for x in inputs): @@ -318,7 +290,7 @@ def concat_from_sequence(op, node: ir.Node): @register("SplitToSequence") -def split_to_sequence(op, node: ir.Node): +def split_to_sequence(op, node: ir.Node) -> ReturnValue: """Rewriting pattern. From @@ -390,11 +362,10 @@ def split_to_sequence(op, node: ir.Node): logger.debug("SplitToSequence => Split + SequenceConstruct") return op.SequenceConstruct(*split_values) - # return [split_node, *squeeze_nodes, node] @register("SequenceAt") -def sequence_at(op, node: ir.Node): +def sequence_at(op, node: ir.Node) -> ReturnValue: input = node.inputs[0] position = node.inputs[1] output = node.outputs[0] @@ -477,9 +448,10 @@ def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) for optimizer in op_optimizers: assert optimizer - context = onnxscript.rewriter.pattern.RewriterContext() + context = orp.RewriterContext() output = optimizer(context, node) if output is not None: + # TODO(rama): return nodes, values return output if is_control_flow_op(node) or is_non_deterministic_op(node): From e868b952bc92c7c985f434fd6a54c58f3118e61d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 4 Jul 2024 20:38:22 -0700 Subject: [PATCH 04/22] Complete refactoring --- onnxscript/optimizer/evaluator_ir.py | 122 +++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 6 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 90b13b282..7389f584f 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -16,6 +16,7 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience +import onnxscript.ir.serde as serde import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp from onnxscript.utils.utils import ( @@ -385,10 +386,80 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: return op.Identity(result) return None +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] class ConstantFolder: opset_imports: dict[str, int] + def __init__( + self, + external_data_folder: str, + *, + do_shape_inference: bool, + ) -> None: + self._external_data_folder = external_data_folder + self._do_shape_inference = do_shape_inference + self._init() + + def _init(self) -> None: + self.counts = {} + self.sizes = {} + self.modified = False + + def _do_inference(self, node: ir.Node) -> None: + output_types = {} + + # TODO: handle optional inputs + def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: + value = get_numpy_value(x) + if isinstance(value, np.ndarray) and value.size < 20: + return 
onnx.numpy_helper.from_array(value, node.inputs[i].name) + return None + + def get_type(value: ir.Value) -> onnx.TypeProto | None: + if value.type is not None: + type_proto = onnx.TypeProto() + serde.serialize_type_into(type_proto, value.type) + if value.shape is not None: + serde.serialize_shape_into(type_proto, value.shape) + return type_proto + return None + + input_types = {x.name: get_type(x) for x in node.inputs if x is not None} + input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %s due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self.opset_imports[node.domain], node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, node, input_types, input_data + ) + except Exception as e: + logger.debug( + "Skipping shape inference for node %s due to exception: %s", + node.name, + e, + ) + + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: merge types, check for conflicts + output.shape = serde.deserialize_type_proto_for_shape(inferred_type) + output.type = serde.deserialize_type_proto_for_type(inferred_type) + def new_constant(self, irvalue: ir.Value, value): # TODO(rama): Why do we need the conversion below? if isinstance(value, (int, float, np.ScalarType)): @@ -430,7 +501,7 @@ def new_constant(self, irvalue: ir.Value, value): # ) attributes = _convenience.convert_attributes({"value": tensor}) node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) - return [node] + return node def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): for i, value in enumerate(node.inputs): @@ -441,6 +512,8 @@ def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): # TODO(rama): consider merging type/other info from both values # Do incremental shape inference + if self.do_shape_inference and not is_control_flow_op(node): + self._do_inference(node) if node.domain not in self.opset_imports: return None @@ -452,7 +525,9 @@ def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): output = optimizer(context, node) if output is not None: # TODO(rama): return nodes, values - return output + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) if is_control_flow_op(node) or is_non_deterministic_op(node): return None @@ -467,17 +542,18 @@ def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): return None if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): replacement = self.new_constant(node.outputs[0], outputs) - if is_constant_op(node): + if is_constant_op(node) or replacement is None: return None # self.add_count(op, outputs.size) - return replacement + return Replacement(replacement.outputs, [replacement]) else: logger.warning("Skipping constant folding for op %s with multiple outputs.", node.op_type) return None def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): # TODO: apply delta! what about opset_imports? - + old_values = node.outputs + new_values = replacement.new_outputs for old_value, new_value in zip(old_values, new_values): # Propagate relevant info from old value to new value # TODO(Rama): Perhaps we should merge old and new types. 
As of now, new @@ -498,7 +574,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) root.outputs[idx] = replacement_mapping[graph_or_function_output] # insert new nodes after the index node - root.insert_after(node, delta.new_nodes) + root.insert_after(node, replacement.new_nodes) root.remove(node, safe=True) # if isinstance(output, list): @@ -524,3 +600,37 @@ def visit_node(self, node: ir.Node): return None else: self.replace_node(node, replacement) + + def visit_graph(self, graph: ir.Graph) -> None: + for node in graph: + self.visit_node(node) + + def visit_model(self, model: ir.Model) -> None: + self._init() + self.opset_imports = model.opset_imports + self.visit_graph(model.graph) + # TODO(rama): handle functions + +def fold_constants( + model: ir.Model, + external_data_folder: str = "", + *, + onnx_shape_inference: bool = False, +) -> bool: + """ + Applies constant folding optimization to the model. + Returns true iff the model was modified. + """ + folder = ConstantFolder( + external_data_folder, + onnx_shape_inference, + ) + folder.visit_model(model) + for op in folder.counts: + logger.info( + "Constant-folded '%s' %s times, with %s size.", + op, + folder.counts[op], + folder.sizes[op], + ) + return folder.modified \ No newline at end of file From f237cd6abde0fa1d714dc90ab70731c4977b0e45 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 5 Jul 2024 14:33:53 -0700 Subject: [PATCH 05/22] Fix various bugs and issues --- onnxscript/optimizer/evaluator_ir.py | 77 ++++++++++++++++++---------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 7389f584f..fefd7abed 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -23,9 +23,15 @@ get_node_attr_value, ) -is_control_flow_op = constant_folding.is_control_flow_op -is_non_deterministic_op = constant_folding.is_non_deterministic_op -is_constant_op = constant_folding.is_constant_op +def is_control_flow_op(node: ir.Node) -> bool: + return any(isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values()) + +def is_non_deterministic_op(node: ir.Node) -> bool: + return node.op_type in constant_folding.non_deterministic_ops and constant_folding.is_onnx_domain(node.domain) + +def is_constant_op(node: ir.Node) -> bool: + return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain(node.domain) + _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT logger = logging.getLogger(__name__) @@ -121,6 +127,13 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: register = registry.register +def get_sym_value(val: ir.Value | None) -> ir.Value | None: + if val is None: + return None + if hasattr(val, "symbolic_value"): + return val.symbolic_value + return None + def get_numpy_value(val: ir.Value) -> np.ndarray | None: const_value = val.const_value if hasattr(const_value, "numpy"): @@ -398,7 +411,6 @@ class ConstantFolder: def __init__( self, external_data_folder: str, - *, do_shape_inference: bool, ) -> None: self._external_data_folder = external_data_folder @@ -446,6 +458,12 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: output_types = onnx.shape_inference.infer_node_outputs( schema, node, input_types, input_data ) + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: merge types, check for 
conflicts + output.shape = serde.deserialize_type_proto_for_shape(inferred_type) + output.type = serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: logger.debug( "Skipping shape inference for node %s due to exception: %s", @@ -453,20 +471,13 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) - for output in node.outputs: - if output.name in output_types: - inferred_type = output_types[output.name] - # TODO: merge types, check for conflicts - output.shape = serde.deserialize_type_proto_for_shape(inferred_type) - output.type = serde.deserialize_type_proto_for_type(inferred_type) + def new_constant(self, irvalue: ir.Value, value): # TODO(rama): Why do we need the conversion below? if isinstance(value, (int, float, np.ScalarType)): value = np.array(value) - irvalue.const_value = value - if not isinstance(value, np.ndarray): # ONNX does not have a way to represent non-tensor constants, eg. a sequence. # So, a constant-value of type sequence is not folded, but it can be used @@ -478,6 +489,8 @@ def new_constant(self, irvalue: ir.Value, value): ) return None + irvalue.const_value = _convenience.tensor(value) + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: logger.info( "Skip storing constant folded nvalue %s due to large size %s.", @@ -486,7 +499,7 @@ def new_constant(self, irvalue: ir.Value, value): ) return None - tensor = onnx.numpy_helper.from_array(value, name) + tensor = onnx.numpy_helper.from_array(value, irvalue.name) logger.debug( "New constant for value %s dtype: %s shape: %s", @@ -503,16 +516,15 @@ def new_constant(self, irvalue: ir.Value, value): node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) return node - def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): + def process_node(self, node: ir.Node): for i, value in enumerate(node.inputs): - if value is not None and value.symbolic_value is not None: - sym_value = value.symbolic_value - if isinstance(sym_value, ir.Value): - node.replace_input_with(i, sym_value) - # TODO(rama): consider merging type/other info from both values + sym_value = get_sym_value(value) + if isinstance(sym_value, ir.Value): + node.replace_input_with(i, sym_value) + # TODO(rama): consider merging type/other info from both values # Do incremental shape inference - if self.do_shape_inference and not is_control_flow_op(node): + if self._do_shape_inference and not is_control_flow_op(node): self._do_inference(node) if node.domain not in self.opset_imports: @@ -537,10 +549,15 @@ def process_node(self, node: ir.Node, root: ir.Graph | ir.Function): input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] # Filter out bfloat16 cases? 
- outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **node.attributes) + def convert(av): + if isinstance(av, ir.AttrTensor): + return serde.serialize_tensor(av.value) + return av.value + attr_values = { name: convert(attr) for name, attr in node.attributes.items() } + outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) if outputs is None: return None - if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): + if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): replacement = self.new_constant(node.outputs[0], outputs) if is_constant_op(node) or replacement is None: return None @@ -584,7 +601,14 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) # self.add_count(node.op_type, output.size) # return self.new_constant(node.output[0], output) - def visit_node(self, node: ir.Node): + def visit_attribute(self, attr: ir.Attr) -> None: + if isinstance(attr, ir.AttrGraph): + self.visit_graph(attr.value) + elif isinstance(attr, ir.AttrGraphs): + for graph in attr.value: + self.visit_graph(graph) + + def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): replacement = self.process_node(node) # logger.debug( # "visit_node: %s::%s %s replacement %s", @@ -595,15 +619,16 @@ def visit_node(self, node: ir.Node): # ) if replacement is None: # No change. Process attributes. - for attr in node.attribute: + for attr in node.attributes.values(): self.visit_attribute(attr) return None + else: - self.replace_node(node, replacement) + self.replace_node(node, replacement, root) def visit_graph(self, graph: ir.Graph) -> None: for node in graph: - self.visit_node(node) + self.visit_node(node, graph) def visit_model(self, model: ir.Model) -> None: self._init() From 937885ed216c439c0f2c0c892ee398dfa06adcb1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 5 Jul 2024 15:50:06 -0700 Subject: [PATCH 06/22] More fixes --- onnxscript/optimizer/evaluator_ir.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index fefd7abed..4be34891c 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -555,6 +555,7 @@ def convert(av): return av.value attr_values = { name: convert(attr) for name, attr in node.attributes.items() } outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) + if outputs is None: return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): From a206709fb50ebeba479978496742865e477c851f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 5 Jul 2024 16:21:44 -0700 Subject: [PATCH 07/22] More fixes --- onnxscript/optimizer/evaluator_ir.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 4be34891c..54295a694 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -17,6 +17,7 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience import onnxscript.ir.serde as serde +import onnxscript.ir._enums as _enums import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp from onnxscript.utils.utils import ( @@ -169,10 +170,16 @@ def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: # TODO: merge types value.type 
= type +def getInputElementType(node: ir.Node, index: int) -> int: + input = getInput(node, index) + if input is not None and input.type is not None: + return input.type.dtype.value + return _enums.DataType.UNDEFINED.value + # TODO(rama): The following should not be necessary. Generic incremental shape-inference # should handle this. This essentially implements type/shape-inference for Cast op. @register("Cast") -def cast(node: ir.Node) -> ReturnValue: +def cast(op, node: ir.Node) -> ReturnValue: input = getInput(node, 0) output = getOutput(node, 0) if input is not None and output is not None: @@ -183,11 +190,10 @@ def cast(node: ir.Node) -> ReturnValue: @register("CastLike") def cast_like(op, node: ir.Node) -> ReturnValue: input0 = node.inputs[0] - input1 = node.inputs[1] - source_element_type = input0.type.dtype.value - target_element_type = input1.type.dtype.value + source_element_type = getInputElementType(node, 0) + target_element_type = getInputElementType(node, 1) - if target_element_type is None: + if target_element_type is _enums.DataType.UNDEFINED.value: return None if source_element_type == target_element_type: return op.Identity(input0) @@ -196,22 +202,20 @@ def cast_like(op, node: ir.Node) -> ReturnValue: @register("Shape") def shape(op, node: ir.Node) -> ReturnValue: - del op input = node.inputs[0] shape = input.shape if shape is None: return None start = node.attributes.get("start", 0) end = node.attributes.get("end", None) - shape_slice = shape.dim[start:end] - if all(d.HasField("dim_value") for d in shape_slice): - return op.Constant(value_ints = [d.dim_value for d in shape_slice]) + shape_slice = shape[start:end] + if all(isinstance(d,int) for d in shape_slice): + return op.Constant(value_ints = [d for d in shape_slice]) return None @register("Size") def size(op, node: ir.Node) -> ReturnValue: - del op shape = node.inputs[0].shape if shape is None: return None @@ -555,7 +559,7 @@ def convert(av): return av.value attr_values = { name: convert(attr) for name, attr in node.attributes.items() } outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) - + if outputs is None: return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): From c689aca740324572e16088efdc15248b3c405663 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 5 Jul 2024 18:17:55 -0700 Subject: [PATCH 08/22] More fixes --- onnxscript/optimizer/evaluator_ir.py | 44 ++++++++++++++++++---------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 54295a694..caae4a0d6 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -61,13 +61,21 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: reference_evaluator = ReferenceEvaluator() +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). -# A partial-evaluator function takes an RewriterContext and a node, and returns the ir.Value -# or ir.Values to replace the output values of the node or None (if no replacement is needed). 
+# A partial-evaluator function takes an RewriterContext and a node, and returns a Replacement +# for the node or None (if no replacement is needed). It may also return just the ir.Value +# or ir.Values to replace the output values of the node, when the new nodes can be inferred +# from the RewriterContext used to build the new nodes. -ReturnValue = Union[Sequence[ir.Value], ir.Value, None] +ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] @dataclasses.dataclass @@ -207,7 +215,11 @@ def shape(op, node: ir.Node) -> ReturnValue: if shape is None: return None start = node.attributes.get("start", 0) + if start != 0: + start = start.value end = node.attributes.get("end", None) + if end is not None: + end = end.value shape_slice = shape[start:end] if all(isinstance(d,int) for d in shape_slice): return op.Constant(value_ints = [d for d in shape_slice]) @@ -233,9 +245,10 @@ def if_op(op, node: ir.Node) -> ReturnValue: if cond is not None: # cond is a constant-value: inline the branch branch = "then_branch" if cond else "else_branch" - graph = node.attributes.get(branch, None) - if graph is None: + graph_attr = node.attributes.get(branch, None) + if not isinstance(graph_attr, ir.AttrGraph): return None + graph : ir.Graph = graph_attr.value formal_outs = graph.outputs actual_outs = node.outputs renamings = { @@ -247,8 +260,9 @@ def if_op(op, node: ir.Node) -> ReturnValue: def rename(name): return renamings.get(name, name) - - for sub_node in graph: + graph_nodes = list(graph) + graph.remove(graph_nodes) + for sub_node in graph_nodes: # TODO: handle renaming inside subgraphs in nodes for v in sub_node.outputs: v.name = rename(v.name) @@ -256,7 +270,7 @@ def rename(name): sub_node.name = f"{node.name}_{sub_node.name}" # TODO: we should handle initializers as well! 
- return formal_outs + return Replacement(formal_outs, graph_nodes) return None @@ -337,6 +351,8 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: return None axis = node.attributes.get("axis", 0) + if axis != 0: + axis = axis.value shape = input.shape if shape is None: return None @@ -403,11 +419,6 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: return op.Identity(result) return None -@dataclasses.dataclass -class Replacement: - """A replacement for a node in the graph.""" - new_outputs: Sequence[ir.Value] - new_nodes: Sequence[ir.Node] class ConstantFolder: opset_imports: dict[str, int] @@ -433,7 +444,7 @@ def _do_inference(self, node: ir.Node) -> None: def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: value = get_numpy_value(x) if isinstance(value, np.ndarray) and value.size < 20: - return onnx.numpy_helper.from_array(value, node.inputs[i].name) + return onnx.numpy_helper.from_array(value, x.name) return None def get_type(value: ir.Value) -> onnx.TypeProto | None: @@ -460,7 +471,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: node.op_type, self.opset_imports[node.domain], node.domain ) output_types = onnx.shape_inference.infer_node_outputs( - schema, node, input_types, input_data + schema, serde.serialize_node(node), input_types, input_data ) for output in node.outputs: if output.name in output_types: @@ -540,7 +551,8 @@ def process_node(self, node: ir.Node): context = orp.RewriterContext() output = optimizer(context, node) if output is not None: - # TODO(rama): return nodes, values + if isinstance(output, Replacement): + return output if isinstance(output, ir.Value): output = [output] return Replacement(output, context.nodes) From 46c773f45ee6107555945cbfbcb53a2ec4ef86ca Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 7 Jul 2024 20:51:44 -0700 Subject: [PATCH 09/22] More typo fixes --- onnxscript/optimizer/evaluator_ir.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index caae4a0d6..61444f9d4 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -184,6 +184,14 @@ def getInputElementType(node: ir.Node, index: int) -> int: return input.type.dtype.value return _enums.DataType.UNDEFINED.value +def getIntAttribute(node: ir.Node, name: str, default: int | None = None) -> int | None: + if name in node.attributes: + attr = node.attributes[name] + if isinstance(attr, ir.AttrInt64): + return attr.value + return None + return default + # TODO(rama): The following should not be necessary. Generic incremental shape-inference # should handle this. This essentially implements type/shape-inference for Cast op. 
@register("Cast") @@ -300,7 +308,11 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: if any(x is None for x in inputs): return None new_axis = node.attributes.get("new_axis", 0) - axis = node.attributes["axis"] + if new_axis != 0: + new_axis = new_axis.value + if "axis" not in node.attributes: + return None + axis = node.attributes["axis"].value if input is not None and isinstance(input.symbolic_value, list): if new_axis == 0: logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) @@ -310,7 +322,7 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: axis_value = op.Constant(value_int=axis) unsqueezed_inputs = [] for node_input in inputs: - unsqueezed_input = op.Unsqueeze(node_input, axis_value, output=[f"{node_input.name}_unsqueeze"]) + unsqueezed_input = op.Unsqueeze(node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"]) unsqueezed_inputs.append(unsqueezed_input) # Send unsqueezed outputs to Concat logger.debug( @@ -374,27 +386,31 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: # split into chunks all of size 'split' if possible. num_outputs = math.ceil(split_dimension_size / split_value.item()) split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, axis=axis, num_outputs=num_outputs, output=split_outputs) + split_values = op.Split(input, axis=axis, num_outputs=num_outputs, outputs=split_outputs) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, split, axis=axis, output=split_outputs) + split_values = op.Split(input, split, axis=axis, outputs=split_outputs) else: return None - keepdims = node.attributes.get("keepdims", 1) + keepdims = getIntAttribute(node, "keepdims", 1) + if keepdims is None: + return None if keepdims == 0: # squeeze the split dimension if keepdims is 0 axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) squeezed_values = [] for i in range(num_outputs): - squeezed = op.Squeeze(split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"]) + squeezed = op.Squeeze(split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"]) squeezed_values.append(squeezed) split_values = squeezed_values logger.debug("SplitToSequence => Split + SequenceConstruct") + if isinstance(split_values, ir.Value): + split_values = [split_values] return op.SequenceConstruct(*split_values) From ea4b1ca06d178c301771238b3601c0c854bff725 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 8 Jul 2024 08:53:49 -0700 Subject: [PATCH 10/22] Various minor fixes --- onnxscript/optimizer/evaluator_ir.py | 152 ++++++++++++++++----------- 1 file changed, 88 insertions(+), 64 deletions(-) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 61444f9d4..08497085f 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -8,7 +8,7 @@ import dataclasses import logging import math -from typing import Any, Callable, Protocol, Sequence, Union +from typing import Any, Callable, Sequence, Union import numpy as np import onnx @@ -16,22 +16,30 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience -import onnxscript.ir.serde as serde import onnxscript.ir._enums as _enums +import onnxscript.ir.serde as serde import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp -from 
onnxscript.utils.utils import ( - get_node_attr_value, -) + def is_control_flow_op(node: ir.Node) -> bool: - return any(isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values()) + return any( + isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values() + ) + def is_non_deterministic_op(node: ir.Node) -> bool: - return node.op_type in constant_folding.non_deterministic_ops and constant_folding.is_onnx_domain(node.domain) + return ( + node.op_type in constant_folding.non_deterministic_ops + and constant_folding.is_onnx_domain(node.domain) + ) + def is_constant_op(node: ir.Node) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain(node.domain) + return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( + node.domain + ) + _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT @@ -61,12 +69,15 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: reference_evaluator = ReferenceEvaluator() + @dataclasses.dataclass class Replacement: """A replacement for a node in the graph.""" + new_outputs: Sequence[ir.Value] new_nodes: Sequence[ir.Node] + # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). @@ -78,6 +89,7 @@ class Replacement: ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] + @dataclasses.dataclass class PartialEvaluator: """A class that represents a partial-evaluator for a particular op. @@ -136,23 +148,26 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: register = registry.register -def get_sym_value(val: ir.Value | None) -> ir.Value | None: + +def _get_sym_value(val: ir.Value | None) -> ir.Value | None: if val is None: return None if hasattr(val, "symbolic_value"): return val.symbolic_value return None -def get_numpy_value(val: ir.Value) -> np.ndarray | None: + +def _get_numpy_value(val: ir.Value) -> np.ndarray | None: const_value = val.const_value if hasattr(const_value, "numpy"): return const_value.numpy() return None -def get_bool_value(val: ir.Value | None) -> bool | None: + +def _get_bool_value(val: ir.Value | None) -> bool | None: if val is None: return None - val = get_numpy_value(val) + val = _get_numpy_value(val) if val is None: return None if isinstance(val, bool): @@ -164,50 +179,55 @@ def get_bool_value(val: ir.Value | None) -> bool | None: return None -def getInput(node:ir.Node, index: int) -> ir.Value | None: +def _get_input(node: ir.Node, index: int) -> ir.Value | None: if index < len(node.inputs): return node.inputs[index] return None -def getOutput(node:ir.Node, index: int) -> ir.Value | None: + +def _get_output(node: ir.Node, index: int) -> ir.Value | None: if index < len(node.outputs): return node.outputs[index] return None -def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: + +def _update_type(value: ir.Value, type: ir.TypeProtocol) -> None: # TODO: merge types value.type = type -def getInputElementType(node: ir.Node, index: int) -> int: - input = getInput(node, index) + +def _get_input_element_type(node: ir.Node, index: int) -> int: + input = _get_input(node, index) if input is not None and input.type is not None: return input.type.dtype.value return _enums.DataType.UNDEFINED.value -def getIntAttribute(node: ir.Node, 
name: str, default: int | None = None) -> int | None: + +def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: if name in node.attributes: attr = node.attributes[name] if isinstance(attr, ir.AttrInt64): return attr.value return None return default - + + # TODO(rama): The following should not be necessary. Generic incremental shape-inference # should handle this. This essentially implements type/shape-inference for Cast op. @register("Cast") def cast(op, node: ir.Node) -> ReturnValue: - input = getInput(node, 0) - output = getOutput(node, 0) + input = _get_input(node, 0) + output = _get_output(node, 0) if input is not None and output is not None: - updateType(output, input.type) + _update_type(output, input.type) return None @register("CastLike") def cast_like(op, node: ir.Node) -> ReturnValue: input0 = node.inputs[0] - source_element_type = getInputElementType(node, 0) - target_element_type = getInputElementType(node, 1) + source_element_type = _get_input_element_type(node, 0) + target_element_type = _get_input_element_type(node, 1) if target_element_type is _enums.DataType.UNDEFINED.value: return None @@ -222,15 +242,11 @@ def shape(op, node: ir.Node) -> ReturnValue: shape = input.shape if shape is None: return None - start = node.attributes.get("start", 0) - if start != 0: - start = start.value - end = node.attributes.get("end", None) - if end is not None: - end = end.value + start = _get_int_attribute(node, "start", 0) + end = _get_int_attribute(node, "end", None) shape_slice = shape[start:end] - if all(isinstance(d,int) for d in shape_slice): - return op.Constant(value_ints = [d for d in shape_slice]) + if all(isinstance(d, int) for d in shape_slice): + return op.Constant(value_ints=[d for d in shape_slice]) return None @@ -244,19 +260,20 @@ def size(op, node: ir.Node) -> ReturnValue: if not isinstance(d, int): return None size *= d - return op.Constant(value_int = size) + return op.Constant(value_int=size) + @register("If") def if_op(op, node: ir.Node) -> ReturnValue: - cond = getInput(node, 0) - cond = get_bool_value(cond) + cond = _get_input(node, 0) + cond = _get_bool_value(cond) if cond is not None: # cond is a constant-value: inline the branch branch = "then_branch" if cond else "else_branch" graph_attr = node.attributes.get(branch, None) if not isinstance(graph_attr, ir.AttrGraph): return None - graph : ir.Graph = graph_attr.value + graph: ir.Graph = graph_attr.value formal_outs = graph.outputs actual_outs = node.outputs renamings = { @@ -268,6 +285,7 @@ def if_op(op, node: ir.Node) -> ReturnValue: def rename(name): return renamings.get(name, name) + graph_nodes = list(graph) graph.remove(graph_nodes) for sub_node in graph_nodes: @@ -307,9 +325,7 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: inputs = input.symbolic_value if any(x is None for x in inputs): return None - new_axis = node.attributes.get("new_axis", 0) - if new_axis != 0: - new_axis = new_axis.value + new_axis = _get_int_attribute(node, "new_axis", 0) if "axis" not in node.attributes: return None axis = node.attributes["axis"].value @@ -322,12 +338,13 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: axis_value = op.Constant(value_int=axis) unsqueezed_inputs = [] for node_input in inputs: - unsqueezed_input = op.Unsqueeze(node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"]) + unsqueezed_input = op.Unsqueeze( + node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"] + ) unsqueezed_inputs.append(unsqueezed_input) # Send unsqueezed 
outputs to Concat logger.debug( - "ConcatFromSequence => Concat %s", - [x.name for x in unsqueezed_inputs] + "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] ) return op.Concat(*unsqueezed_inputs, axis=axis) return None @@ -362,9 +379,7 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: if input is None or split is None or output is None: return None - axis = node.attributes.get("axis", 0) - if axis != 0: - axis = axis.value + axis = _get_int_attribute(node, "axis", 0) shape = input.shape if shape is None: return None @@ -377,7 +392,7 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: if not isinstance(split_dimension_size, int): return None - split_value = get_numpy_value(split) + split_value = _get_numpy_value(split) if split_value is None: return None assert isinstance(split_value, np.ndarray) @@ -386,7 +401,9 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: # split into chunks all of size 'split' if possible. num_outputs = math.ceil(split_dimension_size / split_value.item()) split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, axis=axis, num_outputs=num_outputs, outputs=split_outputs) + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, outputs=split_outputs + ) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size @@ -395,7 +412,7 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: else: return None - keepdims = getIntAttribute(node, "keepdims", 1) + keepdims = _get_int_attribute(node, "keepdims", 1) if keepdims is None: return None if keepdims == 0: @@ -403,7 +420,9 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) squeezed_values = [] for i in range(num_outputs): - squeezed = op.Squeeze(split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"]) + squeezed = op.Squeeze( + split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"] + ) squeezed_values.append(squeezed) split_values = squeezed_values @@ -421,7 +440,7 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: output = node.outputs[0] if input is not None and position is not None: input_vals = input.symbolic_value - position_val = get_numpy_value(position) + position_val = _get_numpy_value(position) if isinstance(input_vals, list) and position_val is not None: if position_val.size != 1: return None @@ -430,7 +449,7 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: result = input_vals[position_val] except IndexError: return None - output.symbolic_value = result + output.symbolic_value = result logger.debug("SequenceAt %s => %s", input.name, result.name) return op.Identity(result) return None @@ -458,11 +477,11 @@ def _do_inference(self, node: ir.Node) -> None: # TODO: handle optional inputs def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: - value = get_numpy_value(x) + value = _get_numpy_value(x) if isinstance(value, np.ndarray) and value.size < 20: return onnx.numpy_helper.from_array(value, x.name) return None - + def get_type(value: ir.Value) -> onnx.TypeProto | None: if value.type is not None: type_proto = onnx.TypeProto() @@ -492,7 +511,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: for output in node.outputs: if output.name in output_types: inferred_type = output_types[output.name] - # TODO: merge types, check for conflicts + # TODO: merge types, check for conflicts output.shape = 
serde.deserialize_type_proto_for_shape(inferred_type) output.type = serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: @@ -502,8 +521,6 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) - - def new_constant(self, irvalue: ir.Value, value): # TODO(rama): Why do we need the conversion below? if isinstance(value, (int, float, np.ScalarType)): @@ -546,10 +563,10 @@ def new_constant(self, irvalue: ir.Value, value): attributes = _convenience.convert_attributes({"value": tensor}) node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) return node - + def process_node(self, node: ir.Node): for i, value in enumerate(node.inputs): - sym_value = get_sym_value(value) + sym_value = _get_sym_value(value) if isinstance(sym_value, ir.Value): node.replace_input_with(i, sym_value) # TODO(rama): consider merging type/other info from both values @@ -560,7 +577,7 @@ def process_node(self, node: ir.Node): if node.domain not in self.opset_imports: return None - version = self.opset_imports[node.domain] + version = self.opset_imports[node.domain] op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) for optimizer in op_optimizers: assert optimizer @@ -580,13 +597,17 @@ def process_node(self, node: ir.Node): return None input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] + # Filter out bfloat16 cases? def convert(av): if isinstance(av, ir.AttrTensor): return serde.serialize_tensor(av.value) return av.value - attr_values = { name: convert(attr) for name, attr in node.attributes.items() } - outputs = reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) + + attr_values = {name: convert(attr) for name, attr in node.attributes.items()} + outputs = reference_evaluator.evaluate( + node.domain, node.op_type, version, *input_values, **attr_values + ) if outputs is None: return None @@ -597,7 +618,9 @@ def convert(av): # self.add_count(op, outputs.size) return Replacement(replacement.outputs, [replacement]) else: - logger.warning("Skipping constant folding for op %s with multiple outputs.", node.op_type) + logger.warning( + "Skipping constant folding for op %s with multiple outputs.", node.op_type + ) return None def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): @@ -655,7 +678,7 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): for attr in node.attributes.values(): self.visit_attribute(attr) return None - + else: self.replace_node(node, replacement, root) @@ -669,6 +692,7 @@ def visit_model(self, model: ir.Model) -> None: self.visit_graph(model.graph) # TODO(rama): handle functions + def fold_constants( model: ir.Model, external_data_folder: str = "", @@ -691,4 +715,4 @@ def fold_constants( folder.counts[op], folder.sizes[op], ) - return folder.modified \ No newline at end of file + return folder.modified From bdcd70f8afcbcad853e809d5f09993fcc7b501da Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 8 Jul 2024 10:31:42 -0700 Subject: [PATCH 11/22] Update test cases --- ...evaluator_ir.py => constant_folding_ir.py} | 0 onnxscript/optimizer/constant_folding_test.py | 56 ++++++++++++------- 2 files changed, 36 insertions(+), 20 deletions(-) rename onnxscript/optimizer/{evaluator_ir.py => constant_folding_ir.py} (100%) diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/constant_folding_ir.py similarity index 100% rename from onnxscript/optimizer/evaluator_ir.py rename to 
onnxscript/optimizer/constant_folding_ir.py diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py index 8fc7fe4a0..eb7adc711 100644 --- a/onnxscript/optimizer/constant_folding_test.py +++ b/onnxscript/optimizer/constant_folding_test.py @@ -1,14 +1,30 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import parameterized import unittest import onnx import pytest -from onnxscript import optimizer +import onnxscript.optimizer as optimizer +from onnxscript.optimizer import constant_folding, constant_folding_ir +from onnxscript.ir import serde +@parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) class FoldConstantsTest(unittest.TestCase): + + def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False): + if self.using_ir: + ir_model = serde.deserialize_model(model) + constant_folding_ir.fold_constants(ir_model, onnx_shape_inference=onnx_shape_inference) + optimizer.remove_unused_nodes(ir_model) + return serde.serialize_model(ir_model) + else: + constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference) + optimizer.remove_unused_nodes(model) + return model + def test_fold_add(self): model = onnx.parser.parse_model( """ @@ -20,7 +36,7 @@ def test_fold_add(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -36,7 +52,7 @@ def test_fold_cast_like(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -53,7 +69,7 @@ def test_fold_shape(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -70,7 +86,7 @@ def test_fold_shape_slice(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "four") @@ -91,7 +107,7 @@ def test_fold_if_cond(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].output[0], "z") self.assertEqual(optimized.graph.node[0].op_type, "Mul") @@ -117,7 +133,7 @@ def test_fold_inside_if_branch(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 1) then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch") self.assertEqual(len(then_graph.node), 2) @@ -144,7 +160,7 @@ def test_fold_if_propagate(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) print(onnx.printer.to_text(optimized)) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "m_square") @@ -161,7 +177,7 @@ def test_fold_redundant_cast(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) self.assertEqual(len(optimized.graph.node), 2) def test_fold_redundant_cast2(self): @@ -174,7 +190,7 @@ def test_fold_redundant_cast2(self): } """ ) - optimized = 
optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].op_type, "Identity") self.assertEqual(optimized.graph.node[0].output[0], "z") @@ -196,7 +212,7 @@ def test_fold_undefined_vars(self): """ ) # No optimizations expected. Just make sure it doesn't crash. - optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + optimized = self._fold(model, onnx_shape_inference=False) self.assertEqual(len(optimized.graph.node), 6) def test_shape_inference(self): @@ -222,7 +238,7 @@ def test_shape_inference(self): } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model, onnx_shape_inference=True) print(onnx.printer.to_text(optimized)) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(optimized.graph.node[0].output[0], "C") @@ -274,7 +290,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_ split_3 = SequenceAt (splits, int64_3) } """ - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 2) self.assertEqual(len(optimized.graph.node[-2].output), 4) self.assertEqual(optimized.graph.node[-2].op_type, "Split") @@ -301,7 +317,7 @@ def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_sp } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 3) self.assertEqual(len(optimized.graph.node[-2].output), 3) self.assertEqual(optimized.graph.node[-2].op_type, "Split") @@ -328,7 +344,7 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_ } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 7) self.assertEqual(len(optimized.graph.node[1].output), 3) self.assertEqual(optimized.graph.node[1].op_type, "Split") @@ -392,7 +408,7 @@ def test_static_split_to_sequence_with_uneven_split(self): } """ ) - optimized = optimizer.optimize(model, onnx_shape_inference=False) + optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) print(onnx.printer.to_text(optimized)) self.assertEqual(len(optimized.graph.node), 2) @@ -408,14 +424,14 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( ir_version: 8, opset_import: ["" : 18] > -func (float[1,3] x) => ( return_val) { +func (float[1,3] x) => (float[1,3] return_val) { const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 3) self.assertEqual(optimized.graph.node[2].op_type, "Concat") onnx.checker.check_model(optimized) @@ -429,14 +445,14 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( ir_version: 8, opset_import: ["" : 18] > -func (float[1,3] x) => ( return_val) { +func (float[1,3] x) => (float[1,3] return_val) { const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) } """ ) - optimized = optimizer.optimize(model, num_iterations=1) + optimized = self._fold(model) self.assertEqual(len(optimized.graph.node), 7) self.assertEqual(optimized.graph.node[6].op_type, "Concat") onnx.checker.check_model(optimized) From 
08592ae2428228808d0a1d2b531170cfeb286720 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 8 Jul 2024 11:42:10 -0700 Subject: [PATCH 12/22] Refactor tests --- onnxscript/optimizer/constant_folding_test.py | 76 ++----------------- onnxscript/optimizer/optimizer_test.py | 68 +++++++++++++++++ 2 files changed, 74 insertions(+), 70 deletions(-) create mode 100644 onnxscript/optimizer/optimizer_test.py diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py index eb7adc711..bd4d5b65d 100644 --- a/onnxscript/optimizer/constant_folding_test.py +++ b/onnxscript/optimizer/constant_folding_test.py @@ -1,30 +1,31 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import parameterized import unittest import onnx +import parameterized import pytest import onnxscript.optimizer as optimizer -from onnxscript.optimizer import constant_folding, constant_folding_ir from onnxscript.ir import serde +from onnxscript.optimizer import constant_folding, constant_folding_ir @parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) class FoldConstantsTest(unittest.TestCase): - def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False): if self.using_ir: ir_model = serde.deserialize_model(model) - constant_folding_ir.fold_constants(ir_model, onnx_shape_inference=onnx_shape_inference) + constant_folding_ir.fold_constants( + ir_model, onnx_shape_inference=onnx_shape_inference + ) optimizer.remove_unused_nodes(ir_model) return serde.serialize_model(ir_model) else: constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference) optimizer.remove_unused_nodes(model) return model - + def test_fold_add(self): model = onnx.parser.parse_model( """ @@ -350,71 +351,6 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_ self.assertEqual(optimized.graph.node[1].op_type, "Split") self.assertEqual(len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3) - def test_static_split_to_sequence_with_uneven_split(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], - producer_name: "pytorch", - producer_version: "2.2.0" -> -main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) - < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> -{ - _val_1 = Constant () - _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) - _val_3 = Constant () - getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) - _val_5 = Constant () - getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) - return_val = Concat (getitem_1, getitem) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_split (self, split_size) => (return_val) -{ - return_val = SplitToSequence (self, split_size) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_getitem (self, i) => (return_val) -{ - return_val = SequenceAt (self, i) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -Rank (input) => (return_val) -{ - tmp = Shape (input) - return_val = Size (tmp) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -IsScalar (input) => (return_val) -{ - tmp = Shape (input) - tmp_0 = Size (tmp) - tmp_1 = Constant () - return_val = Equal (tmp_0, tmp_1) -} - """ - ) - optimized = optimizer.optimize(model, num_iterations=1, 
onnx_shape_inference=False) - - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[0].output), 2) - self.assertEqual(optimized.graph.node[0].op_type, "Split") - def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( self, ): diff --git a/onnxscript/optimizer/optimizer_test.py b/onnxscript/optimizer/optimizer_test.py new file mode 100644 index 000000000..54d72fc01 --- /dev/null +++ b/onnxscript/optimizer/optimizer_test.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx +import onnxscript.optimizer as optimizer + + + +class OptimizerTest(unittest.TestCase): + def test_static_split_to_sequence_with_uneven_split(self): + model = onnx.parser.parse_model( + """ + < + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], + producer_name: "pytorch", + producer_version: "2.2.0" + > + main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) + < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> + { + _val_1 = Constant () + _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) + _val_3 = Constant () + getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) + _val_5 = Constant () + getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) + return_val = Concat (getitem_1, getitem) + } + + + aten_split (self, split_size) => (return_val) + { + return_val = SplitToSequence (self, split_size) + } + + + aten_getitem (self, i) => (return_val) + { + return_val = SequenceAt (self, i) + } + + + Rank (input) => (return_val) + { + tmp = Shape (input) + return_val = Size (tmp) + } + + + IsScalar (input) => (return_val) + { + tmp = Shape (input) + tmp_0 = Size (tmp) + tmp_1 = Constant () + return_val = Equal (tmp_0, tmp_1) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + +if __name__ == "__main__": + unittest.main() From d2632247ef26ec8d272e696f44d8d1b6d65659f2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 9 Jul 2024 17:24:56 -0700 Subject: [PATCH 13/22] Address PR feedback --- onnxscript/ir/serde.py | 8 +- onnxscript/optimizer/evaluator_ir.py | 663 +++++++++++++++++++++++++ onnxscript/optimizer/optimizer_test.py | 19 +- 3 files changed, 680 insertions(+), 10 deletions(-) create mode 100644 onnxscript/optimizer/evaluator_ir.py diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 1af6223b1..91003ab95 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -50,7 +50,7 @@ "serialize_tensor_into", "serialize_tensor", "serialize_type_into", - "serialize_value_into", + "serialize_typeserialize_value_into", "serialize_value", "SerdeError", ] @@ -1511,6 +1511,12 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc raise TypeError(f"Unsupported type: {from_}") +def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto: + type_proto = onnx.TypeProto() + serialize_type_into(type_proto, from_=type_protocol) + return type_proto + + @_capture_errors(lambda type_proto, from_: repr(from_)) def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None: value_field = type_proto.WhichOneof("value") diff 
--git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py new file mode 100644 index 000000000..4d86dae60 --- /dev/null +++ b/onnxscript/optimizer/evaluator_ir.py @@ -0,0 +1,663 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- + +from __future__ import annotations + +import dataclasses +import logging +import math +from typing import Any, Callable, Protocol, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops + +import onnxscript.ir as ir +import onnxscript.ir._convenience as _convenience +import onnxscript.optimizer.constant_folding as constant_folding +import onnxscript.rewriter.pattern as orp +from onnxscript.utils.utils import ( + get_node_attr_value, +) + +def is_control_flow_op(node: ir.Node) -> bool: + return any(isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values()) + +def is_non_deterministic_op(node: ir.Node) -> bool: + return node.op_type in constant_folding.non_deterministic_ops and constant_folding.is_onnx_domain(node.domain) + +def is_constant_op(node: ir.Node) -> bool: + return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain(node.domain) + +_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT + +logger = logging.getLogger(__name__) + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + return evaluator(*args, **kwargs) + + +_reference_evaluator = ReferenceEvaluator() + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + +# A partial-evaluator function takes an RewriterContext and a node, and returns the ir.Value +# or ir.Values to replace the output values of the node or None (if no replacement is needed). + +ReturnValue = Union[Sequence[ir.Value], ir.Value, None] +PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. 
+ """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) + ] + + def register(self, opname: str, domain: str = "", version=None): + if (domain, opname) not in self.op_evaluators: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + else: + evaluator_list = self.op_evaluators[(domain, opname)] + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + +def get_sym_value(val: ir.Value | None) -> ir.Value | None: + if val is None: + return None + if hasattr(val, "symbolic_value"): + return val.symbolic_value + return None + +def get_numpy_value(val: ir.Value) -> np.ndarray | None: + const_value = val.const_value + if hasattr(const_value, "numpy"): + return const_value.numpy() + return None + +def get_bool_value(val: ir.Value | None) -> bool | None: + if val is None: + return None + val = get_numpy_value(val) + if val is None: + return None + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def getInput(node:ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + +def getOutput(node:ir.Node, index: int) -> ir.Value | None: + if index < len(node.outputs): + return node.outputs[index] + return None + +def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: + # TODO: merge types + value.type = type + +def getInputElementType(node: ir.Node, index: int) -> int: + input = getInput(node, index) + if input is not None and input.type is not None: + return input.type.dtype.value + return ir.DataType.UNDEFINED.value + +# TODO(rama): The following should not be necessary. Generic incremental shape-inference +# should handle this. This essentially implements type/shape-inference for Cast op. 
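+# Note: this evaluator does not rewrite the Cast node itself; it only copies
+# the input value's type information onto the output value and returns None
+# (no replacement), leaving the node in place for later passes.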
+@register("Cast") +def cast(op, node: ir.Node) -> ReturnValue: + input = getInput(node, 0) + output = getOutput(node, 0) + if input is not None and output is not None: + updateType(output, input.type) + return None + + +@register("CastLike") +def cast_like(op, node: ir.Node) -> ReturnValue: + input0 = node.inputs[0] + source_element_type = getInputElementType(node, 0) + target_element_type = getInputElementType(node, 1) + + if target_element_type == ir.DataType.UNDEFINED.value: + return None + if source_element_type == target_element_type: + return op.Identity(input0) + return op.Cast(input0, to=target_element_type) + + +@register("Shape") +def shape(op, node: ir.Node) -> ReturnValue: + input = node.inputs[0] + shape = input.shape + if shape is None: + return None + start = node.attributes.get("start", 0) + end = node.attributes.get("end", None) + shape_slice = shape[start:end] + if all(isinstance(d,int) for d in shape_slice): + return op.Constant(value_ints = [d for d in shape_slice]) + return None + + +@register("Size") +def size(op, node: ir.Node) -> ReturnValue: + shape = node.inputs[0].shape + if shape is None: + return None + size = 1 + for d in shape: + if not isinstance(d, int): + return None + size *= d + return op.Constant(value_int = size) + +@register("If") +def if_op(op, node: ir.Node) -> ReturnValue: + cond = getInput(node, 0) + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = node.attributes.get(branch, None) + if graph is None: + return None + formal_outs = graph.outputs + actual_outs = node.outputs + renamings = { + formal.name: actual.name + for formal, actual in zip(formal_outs, actual_outs) + if actual is not None + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + for sub_node in graph: + # TODO: handle renaming inside subgraphs in nodes + for v in sub_node.outputs: + v.name = rename(v.name) + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! 
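+            # Returning the branch's formal outputs signals the caller
+            # (ConstantFolder.process_node) to treat them as replacements for
+            # the If node's outputs; replace_node then rewires all users of
+            # those outputs.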
+ return formal_outs + return None + + +@register("Identity") +def identity(op, node: ir.Node) -> ReturnValue: + del op + input = node.inputs[0] + output = node.outputs[0] + if input is not None and output is not None: + output.symbolic_value = input + return None + + +@register("SequenceConstruct") +def sequence_construct(op, node: ir.Node) -> ReturnValue: + del op + output = node.outputs[0] + if output is not None: + output.symbolic_value = list(node.inputs) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence(op, node: ir.Node) -> ReturnValue: + input = node.inputs[0] + inputs = input.symbolic_value + if any(x is None for x in inputs): + return None + new_axis = node.attributes.get("new_axis", 0) + axis = node.attributes["axis"] + if input is not None and isinstance(input.symbolic_value, list): + if new_axis == 0: + logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) + return op.Concat(*inputs, axis=axis) + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis_value = op.Constant(value_int=axis) + unsqueezed_inputs = [] + for node_input in inputs: + unsqueezed_input = op.Unsqueeze(node_input, axis_value, output=[f"{node_input.name}_unsqueeze"]) + unsqueezed_inputs.append(unsqueezed_input) + # Send unsqueezed outputs to Concat + logger.debug( + "ConcatFromSequence => Concat %s", + [x.name for x in unsqueezed_inputs] + ) + return op.Concat(*unsqueezed_inputs, axis=axis) + return None + + +@register("SplitToSequence") +def split_to_sequence(op, node: ir.Node) -> ReturnValue: + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = node.inputs[0] + split = node.inputs[1] + output = node.outputs[0] + + if input is None or split is None or output is None: + return None + + axis = node.attributes.get("axis", 0) + shape = input.shape + if shape is None: + return None + rank = len(shape) + if axis < 0: + axis = axis + rank + if axis < 0 or axis >= rank: + return None + split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + + split_value = get_numpy_value(split) + if split_value is None: + return None + assert isinstance(split_value, np.ndarray) + + if split_value.ndim == 0: + # split into chunks all of size 'split' if possible. 
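+            # For example, a split dimension of size 5 with a scalar split value
+            # of 2 gives num_outputs = ceil(5 / 2) = 3; Split with num_outputs=3
+            # then produces chunks of sizes [2, 2, 1].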
+ num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, axis=axis, num_outputs=num_outputs, output=split_outputs) + elif split_value.ndim == 1: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, split, axis=axis, output=split_outputs) + else: + return None + + keepdims = node.attributes.get("keepdims", 1) + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) + squeezed_values = [] + for i in range(num_outputs): + squeezed = op.Squeeze(split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"]) + squeezed_values.append(squeezed) + split_values = squeezed_values + + logger.debug("SplitToSequence => Split + SequenceConstruct") + + return op.SequenceConstruct(*split_values) + + +@register("SequenceAt") +def sequence_at(op, node: ir.Node) -> ReturnValue: + input = node.inputs[0] + position = node.inputs[1] + output = node.outputs[0] + if input is not None and position is not None: + input_vals = input.symbolic_value + position_val = get_numpy_value(position) + if isinstance(input_vals, list) and position_val is not None: + if position_val.size != 1: + return None + position_val = position_val.item() + try: + result = input_vals[position_val] + except IndexError: + return None + output.symbolic_value = result + logger.debug("SequenceAt %s => %s", input.name, result.name) + return op.Identity(result) + return None + +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + +class ConstantFolder: + opset_imports: dict[str, int] + + def __init__( + self, + external_data_folder: str, + do_shape_inference: bool, + ) -> None: + self._external_data_folder = external_data_folder + self._do_shape_inference = do_shape_inference + self._init() + + def _init(self) -> None: + self.counts = {} + self.sizes = {} + self.modified = False + + def _do_inference(self, node: ir.Node) -> None: + output_types = {} + + # TODO: handle optional inputs + def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: + value = get_numpy_value(x) + if isinstance(value, np.ndarray) and value.size < 20: + return onnx.numpy_helper.from_array(value, node.inputs[i].name) + return None + + def get_type(value: ir.Value) -> onnx.TypeProto | None: + if value.type is not None: + type_proto = ir.serde.serialize_type(value.type) + if value.shape is not None: + ir.serde.serialize_shape_into(type_proto, value.shape) + return type_proto + return None + + input_types = {x.name: get_type(x) for x in node.inputs if x is not None} + input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %s due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self.opset_imports[node.domain], node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, node, input_types, input_data + ) + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: 
merge types, check for conflicts + output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) + output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) + except Exception as e: + logger.debug( + "Skipping shape inference for node %s due to exception: %s", + node.name, + e, + ) + + + + def new_constant(self, irvalue: ir.Value, value): + # TODO(rama): Why do we need the conversion below? + if isinstance(value, (int, float, np.ScalarType)): + value = np.array(value) + + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. + logger.info( + "Skip storing constant folded value %s due to unsupported type %s.", + irvalue.name, + type(value), + ) + return None + + irvalue.const_value = _convenience.tensor(value) + + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + logger.info( + "Skip storing constant folded nvalue %s due to large size %s.", + irvalue.name, + value.nbytes, + ) + return None + + tensor = onnx.numpy_helper.from_array(value, irvalue.name) + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + irvalue.name, + value.dtype, + value.shape, + ) + + # TODO(rama) + # irvalue.type = onnx.helper.make_tensor_type_proto( + # onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape + # ) + attributes = _convenience.convert_attributes({"value": tensor}) + node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) + return node + + def process_node(self, node: ir.Node): + for i, value in enumerate(node.inputs): + sym_value = get_sym_value(value) + if isinstance(sym_value, ir.Value): + node.replace_input_with(i, sym_value) + # TODO(rama): consider merging type/other info from both values + + # Do incremental shape inference + if self._do_shape_inference and not is_control_flow_op(node): + self._do_inference(node) + + if node.domain not in self.opset_imports: + return None + version = self.opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = orp.RewriterContext() + output = optimizer(context, node) + if output is not None: + # TODO(rama): return nodes, values + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + + if is_control_flow_op(node) or is_non_deterministic_op(node): + return None + + if any((x is not None and x.const_value is None) for x in node.inputs): + return None + + input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] + # Filter out bfloat16 cases? 
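+        # Tensor-valued attributes are serialized back to TensorProto form for
+        # the reference evaluator; all other attributes are passed through as
+        # plain Python values.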
+ def convert(av): + if isinstance(av, ir.AttrTensor): + return ir.serde.serialize_tensor(av.value) + return av.value + attr_values = { name: convert(attr) for name, attr in node.attributes.items() } + outputs = _reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) + + if outputs is None: + return None + if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node.outputs[0], outputs) + if is_constant_op(node) or replacement is None: + return None + # self.add_count(op, outputs.size) + return Replacement(replacement.outputs, [replacement]) + else: + logger.warning("Skipping constant folding for op %s with multiple outputs.", node.op_type) + return None + + def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): + # TODO: apply delta! what about opset_imports? + old_values = node.outputs + new_values = replacement.new_outputs + for old_value, new_value in zip(old_values, new_values): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps we should merge old and new types. As of now, new + # values don't have type information. Note that this could be a problem + # for semantics-altering rewrite-rules: we should allow users to override + # this for such rules. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + + # Reconnect the users of the deleted node to use the new outputs + _convenience.replace_all_uses_with(old_values, new_values) + # Update graph/function outputs if the node generates output + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(root.outputs): + if graph_or_function_output in replacement_mapping: + root.outputs[idx] = replacement_mapping[graph_or_function_output] + + # insert new nodes after the index node + root.insert_after(node, replacement.new_nodes) + root.remove(node, safe=True) + + # if isinstance(output, list): + # return output + # else: + # # Currently handles single output only + # self.add_count(node.op_type, output.size) + # return self.new_constant(node.output[0], output) + + def visit_attribute(self, attr: ir.Attr) -> None: + if isinstance(attr, ir.AttrGraph): + self.visit_graph(attr.value) + elif isinstance(attr, ir.AttrGraphs): + for graph in attr.value: + self.visit_graph(graph) + + def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): + replacement = self.process_node(node) + # logger.debug( + # "visit_node: %s::%s %s replacement %s", + # node.domain, + # node.op_type, + # node.name, + # "found" if replacement is not None else "missed", + # ) + if replacement is None: + # No change. Process attributes. + for attr in node.attributes.values(): + self.visit_attribute(attr) + return None + + else: + self.replace_node(node, replacement, root) + + def visit_graph(self, graph: ir.Graph) -> None: + for node in graph: + self.visit_node(node, graph) + + def visit_model(self, model: ir.Model) -> None: + self._init() + self.opset_imports = model.opset_imports + self.visit_graph(model.graph) + # TODO(rama): handle functions + +def fold_constants( + model: ir.Model, + external_data_folder: str = "", + *, + onnx_shape_inference: bool = False, +) -> bool: + """ + Applies constant folding optimization to the model. + Returns true iff the model was modified. 
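+
+    Minimal usage sketch (model_proto below stands for any onnx.ModelProto that
+    has been deserialized into the IR via serde.deserialize_model):
+
+        ir_model = serde.deserialize_model(model_proto)
+        modified = fold_constants(ir_model, onnx_shape_inference=True)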
+ """ + folder = ConstantFolder( + external_data_folder, + onnx_shape_inference, + ) + folder.visit_model(model) + for op in folder.counts: + logger.info( + "Constant-folded '%s' %s times, with %s size.", + op, + folder.counts[op], + folder.sizes[op], + ) + return folder.modified \ No newline at end of file diff --git a/onnxscript/optimizer/optimizer_test.py b/onnxscript/optimizer/optimizer_test.py index 54d72fc01..57f6f3a80 100644 --- a/onnxscript/optimizer/optimizer_test.py +++ b/onnxscript/optimizer/optimizer_test.py @@ -4,14 +4,14 @@ import unittest import onnx -import onnxscript.optimizer as optimizer +import onnxscript.optimizer as optimizer class OptimizerTest(unittest.TestCase): - def test_static_split_to_sequence_with_uneven_split(self): - model = onnx.parser.parse_model( - """ + def test_static_split_to_sequence_with_uneven_split(self): + model = onnx.parser.parse_model( + """ < ir_version: 8, opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], @@ -58,11 +58,12 @@ def test_static_split_to_sequence_with_uneven_split(self): return_val = Equal (tmp_0, tmp_1) } """ - ) - optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[0].output), 2) - self.assertEqual(optimized.graph.node[0].op_type, "Split") + ) + optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + if __name__ == "__main__": unittest.main() From dfdf3f0de116d800bf81e31657994b16264a9b10 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 10 Jul 2024 09:53:33 -0700 Subject: [PATCH 14/22] Fix accidental character deletion --- onnxscript/ir/serde.py | 3 +- onnxscript/optimizer/evaluator_ir.py | 88 +++++++++++++++++++--------- 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 91003ab95..a664b59ee 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -50,7 +50,8 @@ "serialize_tensor_into", "serialize_tensor", "serialize_type_into", - "serialize_typeserialize_value_into", + "serialize_type", + "serialize_value_into", "serialize_value", "SerdeError", ] diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py index 4d86dae60..b55f21450 100644 --- a/onnxscript/optimizer/evaluator_ir.py +++ b/onnxscript/optimizer/evaluator_ir.py @@ -8,7 +8,7 @@ import dataclasses import logging import math -from typing import Any, Callable, Protocol, Sequence, Union +from typing import Any, Callable, Sequence, Union import numpy as np import onnx @@ -18,18 +18,26 @@ import onnxscript.ir._convenience as _convenience import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp -from onnxscript.utils.utils import ( - get_node_attr_value, -) + def is_control_flow_op(node: ir.Node) -> bool: - return any(isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values()) + return any( + isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values() + ) + def is_non_deterministic_op(node: ir.Node) -> bool: - return node.op_type in constant_folding.non_deterministic_ops and constant_folding.is_onnx_domain(node.domain) + return ( + node.op_type in constant_folding.non_deterministic_ops + and 
constant_folding.is_onnx_domain(node.domain) + ) + def is_constant_op(node: ir.Node) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain(node.domain) + return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( + node.domain + ) + _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT @@ -68,6 +76,7 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: ReturnValue = Union[Sequence[ir.Value], ir.Value, None] PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] + @dataclasses.dataclass class PartialEvaluator: """A class that represents a partial-evaluator for a particular op. @@ -126,6 +135,7 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: register = registry.register + def get_sym_value(val: ir.Value | None) -> ir.Value | None: if val is None: return None @@ -133,12 +143,14 @@ def get_sym_value(val: ir.Value | None) -> ir.Value | None: return val.symbolic_value return None + def get_numpy_value(val: ir.Value) -> np.ndarray | None: const_value = val.const_value if hasattr(const_value, "numpy"): return const_value.numpy() return None + def get_bool_value(val: ir.Value | None) -> bool | None: if val is None: return None @@ -154,26 +166,30 @@ def get_bool_value(val: ir.Value | None) -> bool | None: return None -def getInput(node:ir.Node, index: int) -> ir.Value | None: +def getInput(node: ir.Node, index: int) -> ir.Value | None: if index < len(node.inputs): return node.inputs[index] return None -def getOutput(node:ir.Node, index: int) -> ir.Value | None: + +def getOutput(node: ir.Node, index: int) -> ir.Value | None: if index < len(node.outputs): return node.outputs[index] return None + def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: # TODO: merge types value.type = type + def getInputElementType(node: ir.Node, index: int) -> int: input = getInput(node, index) if input is not None and input.type is not None: return input.type.dtype.value return ir.DataType.UNDEFINED.value + # TODO(rama): The following should not be necessary. Generic incremental shape-inference # should handle this. This essentially implements type/shape-inference for Cast op. 
@register("Cast") @@ -207,8 +223,8 @@ def shape(op, node: ir.Node) -> ReturnValue: start = node.attributes.get("start", 0) end = node.attributes.get("end", None) shape_slice = shape[start:end] - if all(isinstance(d,int) for d in shape_slice): - return op.Constant(value_ints = [d for d in shape_slice]) + if all(isinstance(d, int) for d in shape_slice): + return op.Constant(value_ints=[d for d in shape_slice]) return None @@ -222,7 +238,8 @@ def size(op, node: ir.Node) -> ReturnValue: if not isinstance(d, int): return None size *= d - return op.Constant(value_int = size) + return op.Constant(value_int=size) + @register("If") def if_op(op, node: ir.Node) -> ReturnValue: @@ -294,12 +311,13 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: axis_value = op.Constant(value_int=axis) unsqueezed_inputs = [] for node_input in inputs: - unsqueezed_input = op.Unsqueeze(node_input, axis_value, output=[f"{node_input.name}_unsqueeze"]) + unsqueezed_input = op.Unsqueeze( + node_input, axis_value, output=[f"{node_input.name}_unsqueeze"] + ) unsqueezed_inputs.append(unsqueezed_input) # Send unsqueezed outputs to Concat logger.debug( - "ConcatFromSequence => Concat %s", - [x.name for x in unsqueezed_inputs] + "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] ) return op.Concat(*unsqueezed_inputs, axis=axis) return None @@ -356,7 +374,9 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: # split into chunks all of size 'split' if possible. num_outputs = math.ceil(split_dimension_size / split_value.item()) split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, axis=axis, num_outputs=num_outputs, output=split_outputs) + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, output=split_outputs + ) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size @@ -371,7 +391,9 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) squeezed_values = [] for i in range(num_outputs): - squeezed = op.Squeeze(split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"]) + squeezed = op.Squeeze( + split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"] + ) squeezed_values.append(squeezed) split_values = squeezed_values @@ -396,17 +418,20 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: result = input_vals[position_val] except IndexError: return None - output.symbolic_value = result + output.symbolic_value = result logger.debug("SequenceAt %s => %s", input.name, result.name) return op.Identity(result) return None + @dataclasses.dataclass class Replacement: """A replacement for a node in the graph.""" + new_outputs: Sequence[ir.Value] new_nodes: Sequence[ir.Node] + class ConstantFolder: opset_imports: dict[str, int] @@ -433,7 +458,7 @@ def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: if isinstance(value, np.ndarray) and value.size < 20: return onnx.numpy_helper.from_array(value, node.inputs[i].name) return None - + def get_type(value: ir.Value) -> onnx.TypeProto | None: if value.type is not None: type_proto = ir.serde.serialize_type(value.type) @@ -462,7 +487,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: for output in node.outputs: if output.name in output_types: inferred_type = output_types[output.name] - # TODO: merge types, check for conflicts + # TODO: merge types, check for conflicts output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) 
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: @@ -472,8 +497,6 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) - - def new_constant(self, irvalue: ir.Value, value): # TODO(rama): Why do we need the conversion below? if isinstance(value, (int, float, np.ScalarType)): @@ -516,7 +539,7 @@ def new_constant(self, irvalue: ir.Value, value): attributes = _convenience.convert_attributes({"value": tensor}) node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) return node - + def process_node(self, node: ir.Node): for i, value in enumerate(node.inputs): sym_value = get_sym_value(value) @@ -530,7 +553,7 @@ def process_node(self, node: ir.Node): if node.domain not in self.opset_imports: return None - version = self.opset_imports[node.domain] + version = self.opset_imports[node.domain] op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) for optimizer in op_optimizers: assert optimizer @@ -549,13 +572,17 @@ def process_node(self, node: ir.Node): return None input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] + # Filter out bfloat16 cases? def convert(av): if isinstance(av, ir.AttrTensor): return ir.serde.serialize_tensor(av.value) return av.value - attr_values = { name: convert(attr) for name, attr in node.attributes.items() } - outputs = _reference_evaluator.evaluate(node.domain, node.op_type, version, *input_values, **attr_values) + + attr_values = {name: convert(attr) for name, attr in node.attributes.items()} + outputs = _reference_evaluator.evaluate( + node.domain, node.op_type, version, *input_values, **attr_values + ) if outputs is None: return None @@ -566,7 +593,9 @@ def convert(av): # self.add_count(op, outputs.size) return Replacement(replacement.outputs, [replacement]) else: - logger.warning("Skipping constant folding for op %s with multiple outputs.", node.op_type) + logger.warning( + "Skipping constant folding for op %s with multiple outputs.", node.op_type + ) return None def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): @@ -624,7 +653,7 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): for attr in node.attributes.values(): self.visit_attribute(attr) return None - + else: self.replace_node(node, replacement, root) @@ -638,6 +667,7 @@ def visit_model(self, model: ir.Model) -> None: self.visit_graph(model.graph) # TODO(rama): handle functions + def fold_constants( model: ir.Model, external_data_folder: str = "", @@ -660,4 +690,4 @@ def fold_constants( folder.counts[op], folder.sizes[op], ) - return folder.modified \ No newline at end of file + return folder.modified From 82ff9607f0d39710030ea3b9d9b0752a87a64c2e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 10 Jul 2024 11:05:13 -0700 Subject: [PATCH 15/22] Remove duplicate file --- onnxscript/optimizer/evaluator_ir.py | 693 --------------------------- 1 file changed, 693 deletions(-) delete mode 100644 onnxscript/optimizer/evaluator_ir.py diff --git a/onnxscript/optimizer/evaluator_ir.py b/onnxscript/optimizer/evaluator_ir.py deleted file mode 100644 index b55f21450..000000000 --- a/onnxscript/optimizer/evaluator_ir.py +++ /dev/null @@ -1,693 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# ------------------------------------------------------------------------- - -from __future__ import annotations - -import dataclasses -import logging -import math -from typing import Any, Callable, Sequence, Union - -import numpy as np -import onnx -import onnx.reference.ops - -import onnxscript.ir as ir -import onnxscript.ir._convenience as _convenience -import onnxscript.optimizer.constant_folding as constant_folding -import onnxscript.rewriter.pattern as orp - - -def is_control_flow_op(node: ir.Node) -> bool: - return any( - isinstance(attr, (ir.AttrGraph, ir.AttrGraphs)) for attr in node.attributes.values() - ) - - -def is_non_deterministic_op(node: ir.Node) -> bool: - return ( - node.op_type in constant_folding.non_deterministic_ops - and constant_folding.is_onnx_domain(node.domain) - ) - - -def is_constant_op(node: ir.Node) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( - node.domain - ) - - -_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT - -logger = logging.getLogger(__name__) - -# "Standard" evaluators are used to perform constant-folding. -# The API below works only for non-control-flow ops (ops without any graph-attributes). -# This currently used ONNX's reference implementation. But we could also -# use ORT's implementation if we want to. - - -class ReferenceEvaluator: - def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: - try: - op_impl_class = onnx.reference.ops.load_op(domain, op, version) - return op_impl_class.eval # noqa: TRY300 - except Exception: - return None - - def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: - logger.debug("Evaluating %s::%s", domain, op) - evaluator = self.get_evaluator(domain, op, version) - if evaluator is None: - return None - return evaluator(*args, **kwargs) - - -_reference_evaluator = ReferenceEvaluator() - -# The "partial evaluators" below are non-standard evaluators. They are used to perform -# partial evaluation and/or static program analysis (abstract interpretation). - -# A partial-evaluator function takes an RewriterContext and a node, and returns the ir.Value -# or ir.Values to replace the output values of the node or None (if no replacement is needed). - -ReturnValue = Union[Sequence[ir.Value], ir.Value, None] -PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] - - -@dataclasses.dataclass -class PartialEvaluator: - """A class that represents a partial-evaluator for a particular op. - - It is applicable for a specific version range (min_version, max_version) of the op. - The min_version and max_version can be None, indicating that there is no version - constraint in that direction. 
- """ - - min_version: int | None - max_version: int | None - function: PartialEvaluatorFunction - - def valid_for(self, version: int) -> bool: - """Returns True if this evaluator is applicable for the given version.""" - return (self.min_version is None or version >= self.min_version) and ( - self.max_version is None or version <= self.max_version - ) - - -class PartialEvaluatorRegistry: - """A class that maintains a registry of evaluators for ops.""" - - def __init__(self): - self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} - - def lookup_evaluators(self, domain: str, opname: str, version: int): - evaluator_list = self.op_evaluators.get((domain, opname), []) - return [ - evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) - ] - - def register(self, opname: str, domain: str = "", version=None): - if (domain, opname) not in self.op_evaluators: - evaluator_list = [] - self.op_evaluators[(domain, opname)] = evaluator_list - else: - evaluator_list = self.op_evaluators[(domain, opname)] - if version is None: - min_version = None - max_version = None - elif isinstance(version, int): - min_version = version - max_version = version - elif isinstance(version, tuple): - min_version, max_version = version - - def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: - evaluator_list.append(PartialEvaluator(min_version, max_version, function)) - return function - - return decorator - - -registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() - -register = registry.register - - -def get_sym_value(val: ir.Value | None) -> ir.Value | None: - if val is None: - return None - if hasattr(val, "symbolic_value"): - return val.symbolic_value - return None - - -def get_numpy_value(val: ir.Value) -> np.ndarray | None: - const_value = val.const_value - if hasattr(const_value, "numpy"): - return const_value.numpy() - return None - - -def get_bool_value(val: ir.Value | None) -> bool | None: - if val is None: - return None - val = get_numpy_value(val) - if val is None: - return None - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) - return None - - -def getInput(node: ir.Node, index: int) -> ir.Value | None: - if index < len(node.inputs): - return node.inputs[index] - return None - - -def getOutput(node: ir.Node, index: int) -> ir.Value | None: - if index < len(node.outputs): - return node.outputs[index] - return None - - -def updateType(value: ir.Value, type: ir.TypeProtocol) -> None: - # TODO: merge types - value.type = type - - -def getInputElementType(node: ir.Node, index: int) -> int: - input = getInput(node, index) - if input is not None and input.type is not None: - return input.type.dtype.value - return ir.DataType.UNDEFINED.value - - -# TODO(rama): The following should not be necessary. Generic incremental shape-inference -# should handle this. This essentially implements type/shape-inference for Cast op. 
-@register("Cast") -def cast(op, node: ir.Node) -> ReturnValue: - input = getInput(node, 0) - output = getOutput(node, 0) - if input is not None and output is not None: - updateType(output, input.type) - return None - - -@register("CastLike") -def cast_like(op, node: ir.Node) -> ReturnValue: - input0 = node.inputs[0] - source_element_type = getInputElementType(node, 0) - target_element_type = getInputElementType(node, 1) - - if target_element_type == ir.DataType.UNDEFINED.value: - return None - if source_element_type == target_element_type: - return op.Identity(input0) - return op.Cast(input0, to=target_element_type) - - -@register("Shape") -def shape(op, node: ir.Node) -> ReturnValue: - input = node.inputs[0] - shape = input.shape - if shape is None: - return None - start = node.attributes.get("start", 0) - end = node.attributes.get("end", None) - shape_slice = shape[start:end] - if all(isinstance(d, int) for d in shape_slice): - return op.Constant(value_ints=[d for d in shape_slice]) - return None - - -@register("Size") -def size(op, node: ir.Node) -> ReturnValue: - shape = node.inputs[0].shape - if shape is None: - return None - size = 1 - for d in shape: - if not isinstance(d, int): - return None - size *= d - return op.Constant(value_int=size) - - -@register("If") -def if_op(op, node: ir.Node) -> ReturnValue: - cond = getInput(node, 0) - cond = get_bool_value(cond) - if cond is not None: - # cond is a constant-value: inline the branch - branch = "then_branch" if cond else "else_branch" - graph = node.attributes.get(branch, None) - if graph is None: - return None - formal_outs = graph.outputs - actual_outs = node.outputs - renamings = { - formal.name: actual.name - for formal, actual in zip(formal_outs, actual_outs) - if actual is not None - } - # TODO: Extend renaming to intermediate values. - - def rename(name): - return renamings.get(name, name) - - for sub_node in graph: - # TODO: handle renaming inside subgraphs in nodes - for v in sub_node.outputs: - v.name = rename(v.name) - # Avoid name collision. - sub_node.name = f"{node.name}_{sub_node.name}" - - # TODO: we should handle initializers as well! 
- return formal_outs - return None - - -@register("Identity") -def identity(op, node: ir.Node) -> ReturnValue: - del op - input = node.inputs[0] - output = node.outputs[0] - if input is not None and output is not None: - output.symbolic_value = input - return None - - -@register("SequenceConstruct") -def sequence_construct(op, node: ir.Node) -> ReturnValue: - del op - output = node.outputs[0] - if output is not None: - output.symbolic_value = list(node.inputs) - return None - - -@register("ConcatFromSequence") -def concat_from_sequence(op, node: ir.Node) -> ReturnValue: - input = node.inputs[0] - inputs = input.symbolic_value - if any(x is None for x in inputs): - return None - new_axis = node.attributes.get("new_axis", 0) - axis = node.attributes["axis"] - if input is not None and isinstance(input.symbolic_value, list): - if new_axis == 0: - logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) - return op.Concat(*inputs, axis=axis) - if new_axis == 1: - # Unsqueeze the inputs with concat axis if new_axis is 1 - axis_value = op.Constant(value_int=axis) - unsqueezed_inputs = [] - for node_input in inputs: - unsqueezed_input = op.Unsqueeze( - node_input, axis_value, output=[f"{node_input.name}_unsqueeze"] - ) - unsqueezed_inputs.append(unsqueezed_input) - # Send unsqueezed outputs to Concat - logger.debug( - "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] - ) - return op.Concat(*unsqueezed_inputs, axis=axis) - return None - - -@register("SplitToSequence") -def split_to_sequence(op, node: ir.Node) -> ReturnValue: - """Rewriting pattern. - - From - - splits = onnx::SplitToSequence(input, split, axis=axis) - - to - - split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) - splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) - - or - - split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) - splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) - - where number of output tensors in `splits` is statically known. - onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. - This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. - """ - input = node.inputs[0] - split = node.inputs[1] - output = node.outputs[0] - - if input is None or split is None or output is None: - return None - - axis = node.attributes.get("axis", 0) - shape = input.shape - if shape is None: - return None - rank = len(shape) - if axis < 0: - axis = axis + rank - if axis < 0 or axis >= rank: - return None - split_dimension_size = shape[axis] - if not isinstance(split_dimension_size, int): - return None - - split_value = get_numpy_value(split) - if split_value is None: - return None - assert isinstance(split_value, np.ndarray) - - if split_value.ndim == 0: - # split into chunks all of size 'split' if possible. 
- num_outputs = math.ceil(split_dimension_size / split_value.item()) - split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split( - input, axis=axis, num_outputs=num_outputs, output=split_outputs - ) - elif split_value.ndim == 1: - # split into 'size(split)' chunks - num_outputs = split_value.size - split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, split, axis=axis, output=split_outputs) - else: - return None - - keepdims = node.attributes.get("keepdims", 1) - if keepdims == 0: - # squeeze the split dimension if keepdims is 0 - axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) - squeezed_values = [] - for i in range(num_outputs): - squeezed = op.Squeeze( - split_values[i], axis_val, output=[f"{split_outputs[i]}_squeeze"] - ) - squeezed_values.append(squeezed) - split_values = squeezed_values - - logger.debug("SplitToSequence => Split + SequenceConstruct") - - return op.SequenceConstruct(*split_values) - - -@register("SequenceAt") -def sequence_at(op, node: ir.Node) -> ReturnValue: - input = node.inputs[0] - position = node.inputs[1] - output = node.outputs[0] - if input is not None and position is not None: - input_vals = input.symbolic_value - position_val = get_numpy_value(position) - if isinstance(input_vals, list) and position_val is not None: - if position_val.size != 1: - return None - position_val = position_val.item() - try: - result = input_vals[position_val] - except IndexError: - return None - output.symbolic_value = result - logger.debug("SequenceAt %s => %s", input.name, result.name) - return op.Identity(result) - return None - - -@dataclasses.dataclass -class Replacement: - """A replacement for a node in the graph.""" - - new_outputs: Sequence[ir.Value] - new_nodes: Sequence[ir.Node] - - -class ConstantFolder: - opset_imports: dict[str, int] - - def __init__( - self, - external_data_folder: str, - do_shape_inference: bool, - ) -> None: - self._external_data_folder = external_data_folder - self._do_shape_inference = do_shape_inference - self._init() - - def _init(self) -> None: - self.counts = {} - self.sizes = {} - self.modified = False - - def _do_inference(self, node: ir.Node) -> None: - output_types = {} - - # TODO: handle optional inputs - def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: - value = get_numpy_value(x) - if isinstance(value, np.ndarray) and value.size < 20: - return onnx.numpy_helper.from_array(value, node.inputs[i].name) - return None - - def get_type(value: ir.Value) -> onnx.TypeProto | None: - if value.type is not None: - type_proto = ir.serde.serialize_type(value.type) - if value.shape is not None: - ir.serde.serialize_shape_into(type_proto, value.shape) - return type_proto - return None - - input_types = {x.name: get_type(x) for x in node.inputs if x is not None} - input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} - input_data = {k: v for k, v in input_data.items() if v is not None} - if any(t is None for t in input_types.values()): - logger.debug( - "Skipping shape inference for node %s due to missing input type.", - node.name, - ) - else: - # TODO: pass in constant values, ir_version - try: - schema = onnx.defs.get_schema( - node.op_type, self.opset_imports[node.domain], node.domain - ) - output_types = onnx.shape_inference.infer_node_outputs( - schema, node, input_types, input_data - ) - for output in node.outputs: - if output.name in output_types: - inferred_type = 
output_types[output.name] - # TODO: merge types, check for conflicts - output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) - output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) - except Exception as e: - logger.debug( - "Skipping shape inference for node %s due to exception: %s", - node.name, - e, - ) - - def new_constant(self, irvalue: ir.Value, value): - # TODO(rama): Why do we need the conversion below? - if isinstance(value, (int, float, np.ScalarType)): - value = np.array(value) - - if not isinstance(value, np.ndarray): - # ONNX does not have a way to represent non-tensor constants, eg. a sequence. - # So, a constant-value of type sequence is not folded, but it can be used - # to optimize subsequent operations when possible. - logger.info( - "Skip storing constant folded value %s due to unsupported type %s.", - irvalue.name, - type(value), - ) - return None - - irvalue.const_value = _convenience.tensor(value) - - if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: - logger.info( - "Skip storing constant folded nvalue %s due to large size %s.", - irvalue.name, - value.nbytes, - ) - return None - - tensor = onnx.numpy_helper.from_array(value, irvalue.name) - - logger.debug( - "New constant for value %s dtype: %s shape: %s", - irvalue.name, - value.dtype, - value.shape, - ) - - # TODO(rama) - # irvalue.type = onnx.helper.make_tensor_type_proto( - # onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape - # ) - attributes = _convenience.convert_attributes({"value": tensor}) - node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) - return node - - def process_node(self, node: ir.Node): - for i, value in enumerate(node.inputs): - sym_value = get_sym_value(value) - if isinstance(sym_value, ir.Value): - node.replace_input_with(i, sym_value) - # TODO(rama): consider merging type/other info from both values - - # Do incremental shape inference - if self._do_shape_inference and not is_control_flow_op(node): - self._do_inference(node) - - if node.domain not in self.opset_imports: - return None - version = self.opset_imports[node.domain] - op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) - for optimizer in op_optimizers: - assert optimizer - context = orp.RewriterContext() - output = optimizer(context, node) - if output is not None: - # TODO(rama): return nodes, values - if isinstance(output, ir.Value): - output = [output] - return Replacement(output, context.nodes) - - if is_control_flow_op(node) or is_non_deterministic_op(node): - return None - - if any((x is not None and x.const_value is None) for x in node.inputs): - return None - - input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] - - # Filter out bfloat16 cases? 
- def convert(av): - if isinstance(av, ir.AttrTensor): - return ir.serde.serialize_tensor(av.value) - return av.value - - attr_values = {name: convert(attr) for name, attr in node.attributes.items()} - outputs = _reference_evaluator.evaluate( - node.domain, node.op_type, version, *input_values, **attr_values - ) - - if outputs is None: - return None - if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): - replacement = self.new_constant(node.outputs[0], outputs) - if is_constant_op(node) or replacement is None: - return None - # self.add_count(op, outputs.size) - return Replacement(replacement.outputs, [replacement]) - else: - logger.warning( - "Skipping constant folding for op %s with multiple outputs.", node.op_type - ) - return None - - def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): - # TODO: apply delta! what about opset_imports? - old_values = node.outputs - new_values = replacement.new_outputs - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(root.outputs): - if graph_or_function_output in replacement_mapping: - root.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - root.insert_after(node, replacement.new_nodes) - root.remove(node, safe=True) - - # if isinstance(output, list): - # return output - # else: - # # Currently handles single output only - # self.add_count(node.op_type, output.size) - # return self.new_constant(node.output[0], output) - - def visit_attribute(self, attr: ir.Attr) -> None: - if isinstance(attr, ir.AttrGraph): - self.visit_graph(attr.value) - elif isinstance(attr, ir.AttrGraphs): - for graph in attr.value: - self.visit_graph(graph) - - def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): - replacement = self.process_node(node) - # logger.debug( - # "visit_node: %s::%s %s replacement %s", - # node.domain, - # node.op_type, - # node.name, - # "found" if replacement is not None else "missed", - # ) - if replacement is None: - # No change. Process attributes. - for attr in node.attributes.values(): - self.visit_attribute(attr) - return None - - else: - self.replace_node(node, replacement, root) - - def visit_graph(self, graph: ir.Graph) -> None: - for node in graph: - self.visit_node(node, graph) - - def visit_model(self, model: ir.Model) -> None: - self._init() - self.opset_imports = model.opset_imports - self.visit_graph(model.graph) - # TODO(rama): handle functions - - -def fold_constants( - model: ir.Model, - external_data_folder: str = "", - *, - onnx_shape_inference: bool = False, -) -> bool: - """ - Applies constant folding optimization to the model. - Returns true iff the model was modified. 
- """ - folder = ConstantFolder( - external_data_folder, - onnx_shape_inference, - ) - folder.visit_model(model) - for op in folder.counts: - logger.info( - "Constant-folded '%s' %s times, with %s size.", - op, - folder.counts[op], - folder.sizes[op], - ) - return folder.modified From 2d821fea2406853322674d450348e8cea865190c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 11 Jul 2024 13:00:53 -0700 Subject: [PATCH 16/22] Use explicit state to track symbolic value --- onnxscript/optimizer/constant_folding_ir.py | 122 +++++++++----------- 1 file changed, 56 insertions(+), 66 deletions(-) diff --git a/onnxscript/optimizer/constant_folding_ir.py b/onnxscript/optimizer/constant_folding_ir.py index 08497085f..fcbdab147 100644 --- a/onnxscript/optimizer/constant_folding_ir.py +++ b/onnxscript/optimizer/constant_folding_ir.py @@ -16,8 +16,6 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience -import onnxscript.ir._enums as _enums -import onnxscript.ir.serde as serde import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp @@ -52,7 +50,7 @@ def is_constant_op(node: ir.Node) -> bool: class ReferenceEvaluator: - def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: + def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: try: op_impl_class = onnx.reference.ops.load_op(domain, op, version) return op_impl_class.eval # noqa: TRY300 @@ -67,7 +65,7 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: return evaluator(*args, **kwargs) -reference_evaluator = ReferenceEvaluator() +_reference_evaluator = ReferenceEvaluator() @dataclasses.dataclass @@ -78,16 +76,31 @@ class Replacement: new_nodes: Sequence[ir.Node] +class OptimizerState: + def __init__(self): + self._sym_value_map: dict[ir.Value, Any] = {} + + def get_sym_value(self, value: ir.Value | None) -> Any: + if value is None: + return None + return self._sym_value_map.get(value, None) + + def set_sym_value(self, value: ir.Value, sym_value: Any) -> None: + self._sym_value_map[value] = sym_value + + # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). -# A partial-evaluator function takes an RewriterContext and a node, and returns a Replacement -# for the node or None (if no replacement is needed). It may also return just the ir.Value -# or ir.Values to replace the output values of the node, when the new nodes can be inferred -# from the RewriterContext used to build the new nodes. +# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns +# a Replacement for the node or None (if no replacement is needed). It may also return just +# the ir.Value or ir.Values to replace the output values of the node, when the new nodes +# can be inferred from the RewriterContext used to build the new nodes. 
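(Editorial illustration, not part of any patch in this series: a partial evaluator written against the calling convention described in the comment above might look like the sketch below. The Transpose rule is hypothetical and only shows the shape of such a function; the @register decorator, the _get_input helper, the op.Identity builder call, and the attribute access mirror names that already appear in this module.)

# Hypothetical sketch only: fold a Transpose whose "perm" attribute is the
# identity permutation into an Identity node.
@register("Transpose")
def transpose_identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    del state  # this rule does not consult symbolic values
    input = _get_input(node, 0)
    if input is None:
        return None
    perm_attr = node.attributes.get("perm", None)
    if perm_attr is None:
        return None
    # Assumes the attribute's .value holds the list of ints, as for "axis" elsewhere here.
    perm = list(perm_attr.value)
    if perm == list(range(len(perm))):
        # No-op permutation: build the replacement Identity via the RewriterContext.
        return op.Identity(input)
    return None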
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] -PartialEvaluatorFunction = Callable[[orp.RewriterContext, ir.Node], ReturnValue] +PartialEvaluatorFunction = Callable[ + [ir.Node, orp.RewriterContext, OptimizerState], ReturnValue +] @dataclasses.dataclass @@ -149,14 +162,6 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: register = registry.register -def _get_sym_value(val: ir.Value | None) -> ir.Value | None: - if val is None: - return None - if hasattr(val, "symbolic_value"): - return val.symbolic_value - return None - - def _get_numpy_value(val: ir.Value) -> np.ndarray | None: const_value = val.const_value if hasattr(const_value, "numpy"): @@ -200,7 +205,7 @@ def _get_input_element_type(node: ir.Node, index: int) -> int: input = _get_input(node, index) if input is not None and input.type is not None: return input.type.dtype.value - return _enums.DataType.UNDEFINED.value + return ir.DataType.UNDEFINED.value def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: @@ -215,7 +220,7 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> # TODO(rama): The following should not be necessary. Generic incremental shape-inference # should handle this. This essentially implements type/shape-inference for Cast op. @register("Cast") -def cast(op, node: ir.Node) -> ReturnValue: +def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) output = _get_output(node, 0) if input is not None and output is not None: @@ -224,12 +229,12 @@ def cast(op, node: ir.Node) -> ReturnValue: @register("CastLike") -def cast_like(op, node: ir.Node) -> ReturnValue: +def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input0 = node.inputs[0] source_element_type = _get_input_element_type(node, 0) target_element_type = _get_input_element_type(node, 1) - if target_element_type is _enums.DataType.UNDEFINED.value: + if target_element_type == ir.DataType.UNDEFINED.value: return None if source_element_type == target_element_type: return op.Identity(input0) @@ -237,7 +242,7 @@ def cast_like(op, node: ir.Node) -> ReturnValue: @register("Shape") -def shape(op, node: ir.Node) -> ReturnValue: +def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] shape = input.shape if shape is None: @@ -246,12 +251,12 @@ def shape(op, node: ir.Node) -> ReturnValue: end = _get_int_attribute(node, "end", None) shape_slice = shape[start:end] if all(isinstance(d, int) for d in shape_slice): - return op.Constant(value_ints=[d for d in shape_slice]) + return op.Constant(value_ints=list(shape_slice)) return None @register("Size") -def size(op, node: ir.Node) -> ReturnValue: +def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: shape = node.inputs[0].shape if shape is None: return None @@ -264,7 +269,7 @@ def size(op, node: ir.Node) -> ReturnValue: @register("If") -def if_op(op, node: ir.Node) -> ReturnValue: +def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: cond = _get_input(node, 0) cond = _get_bool_value(cond) if cond is not None: @@ -301,35 +306,35 @@ def rename(name): @register("Identity") -def identity(op, node: ir.Node) -> ReturnValue: +def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: del op input = node.inputs[0] output = node.outputs[0] if input is not None and output is not None: - output.symbolic_value = input + state.set_sym_value(output, input) return None 
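(A second editorial aside, also not part of the patches: the version bounds carried by each registered evaluator are interpreted by PartialEvaluator.valid_for, with None meaning "unconstrained" on that side. A minimal check of that behavior, assuming the module path that patch 17 below introduces for this file:)

# Illustrative only; the lambdas stand in for real evaluator functions and are never called.
from onnxscript.optimizer import _constant_folding as cf

bounded = cf.PartialEvaluator(min_version=13, max_version=18, function=lambda node, op, state: None)
assert bounded.valid_for(13) and bounded.valid_for(18)
assert not bounded.valid_for(19)

unbounded = cf.PartialEvaluator(min_version=None, max_version=None, function=lambda node, op, state: None)
assert unbounded.valid_for(1) and unbounded.valid_for(100)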
@register("SequenceConstruct") -def sequence_construct(op, node: ir.Node) -> ReturnValue: +def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: del op output = node.outputs[0] if output is not None: - output.symbolic_value = list(node.inputs) + state.set_sym_value(output, list(node.inputs)) return None @register("ConcatFromSequence") -def concat_from_sequence(op, node: ir.Node) -> ReturnValue: +def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] - inputs = input.symbolic_value + inputs = state.get_sym_value(input) if any(x is None for x in inputs): return None new_axis = _get_int_attribute(node, "new_axis", 0) if "axis" not in node.attributes: return None axis = node.attributes["axis"].value - if input is not None and isinstance(input.symbolic_value, list): + if input is not None and isinstance(inputs, list): if new_axis == 0: logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) return op.Concat(*inputs, axis=axis) @@ -351,7 +356,7 @@ def concat_from_sequence(op, node: ir.Node) -> ReturnValue: @register("SplitToSequence") -def split_to_sequence(op, node: ir.Node) -> ReturnValue: +def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Rewriting pattern. From @@ -434,12 +439,12 @@ def split_to_sequence(op, node: ir.Node) -> ReturnValue: @register("SequenceAt") -def sequence_at(op, node: ir.Node) -> ReturnValue: +def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] position = node.inputs[1] output = node.outputs[0] if input is not None and position is not None: - input_vals = input.symbolic_value + input_vals = state.get_sym_value(input) position_val = _get_numpy_value(position) if isinstance(input_vals, list) and position_val is not None: if position_val.size != 1: @@ -449,7 +454,7 @@ def sequence_at(op, node: ir.Node) -> ReturnValue: result = input_vals[position_val] except IndexError: return None - output.symbolic_value = result + state.set_sym_value(output, result) logger.debug("SequenceAt %s => %s", input.name, result.name) return op.Identity(result) return None @@ -471,6 +476,7 @@ def _init(self) -> None: self.counts = {} self.sizes = {} self.modified = False + self._state = OptimizerState() def _do_inference(self, node: ir.Node) -> None: output_types = {} @@ -484,10 +490,9 @@ def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: def get_type(value: ir.Value) -> onnx.TypeProto | None: if value.type is not None: - type_proto = onnx.TypeProto() - serde.serialize_type_into(type_proto, value.type) + type_proto = ir.serde.serialize_type(value.type) if value.shape is not None: - serde.serialize_shape_into(type_proto, value.shape) + ir.serde.serialize_shape_into(type_proto, value.shape) return type_proto return None @@ -506,14 +511,14 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: node.op_type, self.opset_imports[node.domain], node.domain ) output_types = onnx.shape_inference.infer_node_outputs( - schema, serde.serialize_node(node), input_types, input_data + schema, ir.serde.serialize_node(node), input_types, input_data ) for output in node.outputs: if output.name in output_types: inferred_type = output_types[output.name] # TODO: merge types, check for conflicts - output.shape = serde.deserialize_type_proto_for_shape(inferred_type) - output.type = serde.deserialize_type_proto_for_type(inferred_type) + output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) + output.type = 
ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: logger.debug( "Skipping shape inference for node %s due to exception: %s", @@ -556,17 +561,13 @@ def new_constant(self, irvalue: ir.Value, value): value.shape, ) - # TODO(rama) - # irvalue.type = onnx.helper.make_tensor_type_proto( - # onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape - # ) attributes = _convenience.convert_attributes({"value": tensor}) node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) return node def process_node(self, node: ir.Node): for i, value in enumerate(node.inputs): - sym_value = _get_sym_value(value) + sym_value = self._state.get_sym_value(value) if isinstance(sym_value, ir.Value): node.replace_input_with(i, sym_value) # TODO(rama): consider merging type/other info from both values @@ -582,7 +583,7 @@ def process_node(self, node: ir.Node): for optimizer in op_optimizers: assert optimizer context = orp.RewriterContext() - output = optimizer(context, node) + output = optimizer(node, context, self._state) if output is not None: if isinstance(output, Replacement): return output @@ -601,11 +602,11 @@ def process_node(self, node: ir.Node): # Filter out bfloat16 cases? def convert(av): if isinstance(av, ir.AttrTensor): - return serde.serialize_tensor(av.value) + return ir.serde.serialize_tensor(av.value) return av.value attr_values = {name: convert(attr) for name, attr in node.attributes.items()} - outputs = reference_evaluator.evaluate( + outputs = _reference_evaluator.evaluate( node.domain, node.op_type, version, *input_values, **attr_values ) @@ -615,7 +616,6 @@ def convert(av): replacement = self.new_constant(node.outputs[0], outputs) if is_constant_op(node) or replacement is None: return None - # self.add_count(op, outputs.size) return Replacement(replacement.outputs, [replacement]) else: logger.warning( @@ -624,7 +624,9 @@ def convert(av): return None def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): - # TODO: apply delta! what about opset_imports? + logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + + # TODO: what about new opset_imports? old_values = node.outputs new_values = replacement.new_outputs for old_value, new_value in zip(old_values, new_values): @@ -650,12 +652,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) root.insert_after(node, replacement.new_nodes) root.remove(node, safe=True) - # if isinstance(output, list): - # return output - # else: - # # Currently handles single output only - # self.add_count(node.op_type, output.size) - # return self.new_constant(node.output[0], output) + # TODO: track statistics about replaced nodes and sizes of new constants def visit_attribute(self, attr: ir.Attr) -> None: if isinstance(attr, ir.AttrGraph): @@ -666,19 +663,11 @@ def visit_attribute(self, attr: ir.Attr) -> None: def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): replacement = self.process_node(node) - # logger.debug( - # "visit_node: %s::%s %s replacement %s", - # node.domain, - # node.op_type, - # node.name, - # "found" if replacement is not None else "missed", - # ) if replacement is None: # No change. Process attributes. 
for attr in node.attributes.values(): self.visit_attribute(attr) return None - else: self.replace_node(node, replacement, root) @@ -691,6 +680,7 @@ def visit_model(self, model: ir.Model) -> None: self.opset_imports = model.opset_imports self.visit_graph(model.graph) # TODO(rama): handle functions + # Pending decision on whether we want to specialize functions or not. def fold_constants( From ae5bd0281a20da4a1b4424ab9b5baa89957d49ba Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 11 Jul 2024 16:25:36 -0700 Subject: [PATCH 17/22] Address PR feedback --- .../{constant_folding_ir.py => _constant_folding.py} | 12 ++++++++---- onnxscript/optimizer/constant_folding_test.py | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) rename onnxscript/optimizer/{constant_folding_ir.py => _constant_folding.py} (98%) diff --git a/onnxscript/optimizer/constant_folding_ir.py b/onnxscript/optimizer/_constant_folding.py similarity index 98% rename from onnxscript/optimizer/constant_folding_ir.py rename to onnxscript/optimizer/_constant_folding.py index fcbdab147..463d9f4e0 100644 --- a/onnxscript/optimizer/constant_folding_ir.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # ------------------------------------------------------------------------- +# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files. + from __future__ import annotations import dataclasses @@ -210,9 +212,11 @@ def _get_input_element_type(node: ir.Node, index: int) -> int: def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: if name in node.attributes: - attr = node.attributes[name] - if isinstance(attr, ir.AttrInt64): - return attr.value + attr = node.attributes[name].value + if isinstance(attr, int): + return attr + # This is an invalid model. For now, we just return None. + # We could raise an error too. 
return None return default @@ -234,7 +238,7 @@ def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: source_element_type = _get_input_element_type(node, 0) target_element_type = _get_input_element_type(node, 1) - if target_element_type == ir.DataType.UNDEFINED.value: + if target_element_type == ir.DataType.UNDEFINED: return None if source_element_type == target_element_type: return op.Identity(input0) diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py index bd4d5b65d..7629653d4 100644 --- a/onnxscript/optimizer/constant_folding_test.py +++ b/onnxscript/optimizer/constant_folding_test.py @@ -8,7 +8,7 @@ import onnxscript.optimizer as optimizer from onnxscript.ir import serde -from onnxscript.optimizer import constant_folding, constant_folding_ir +from onnxscript.optimizer import _constant_folding, constant_folding @parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) @@ -16,7 +16,7 @@ class FoldConstantsTest(unittest.TestCase): def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False): if self.using_ir: ir_model = serde.deserialize_model(model) - constant_folding_ir.fold_constants( + _constant_folding.fold_constants( ir_model, onnx_shape_inference=onnx_shape_inference ) optimizer.remove_unused_nodes(ir_model) From 60e82b9220db5538e2125801526c741b4fac0f5a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 17 Jul 2024 09:10:03 -0700 Subject: [PATCH 18/22] Address PR comments --- onnxscript/optimizer/_constant_folding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 463d9f4e0..4c91f7799 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# ------------------------------------------------------------------------- # NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files. 
@@ -85,7 +83,7 @@ def __init__(self): def get_sym_value(self, value: ir.Value | None) -> Any: if value is None: return None - return self._sym_value_map.get(value, None) + return self._sym_value_map.get(value) def set_sym_value(self, value: ir.Value, sym_value: Any) -> None: self._sym_value_map[value] = sym_value @@ -137,7 +135,9 @@ def lookup_evaluators(self, domain: str, opname: str, version: int): evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) ] - def register(self, opname: str, domain: str = "", version=None): + def register( + self, opname: str, domain: str = "", version=None + ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]: if (domain, opname) not in self.op_evaluators: evaluator_list = [] self.op_evaluators[(domain, opname)] = evaluator_list @@ -166,7 +166,7 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: def _get_numpy_value(val: ir.Value) -> np.ndarray | None: const_value = val.const_value - if hasattr(const_value, "numpy"): + if const_value is not None: return const_value.numpy() return None From 3a74d9a7e332d2300ba1931b63f0e2e41f314c41 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 17 Jul 2024 09:32:45 -0700 Subject: [PATCH 19/22] Address mypy warnings --- onnxscript/optimizer/_constant_folding.py | 38 +++++++++++++---------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 4c91f7799..d9da94609 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -138,11 +138,11 @@ def lookup_evaluators(self, domain: str, opname: str, version: int): def register( self, opname: str, domain: str = "", version=None ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]: - if (domain, opname) not in self.op_evaluators: + if (domain, opname) in self.op_evaluators: + evaluator_list = self.op_evaluators[(domain, opname)] + else: evaluator_list = [] self.op_evaluators[(domain, opname)] = evaluator_list - else: - evaluator_list = self.op_evaluators[(domain, opname)] if version is None: min_version = None max_version = None @@ -174,15 +174,18 @@ def _get_numpy_value(val: ir.Value) -> np.ndarray | None: def _get_bool_value(val: ir.Value | None) -> bool | None: if val is None: return None - val = _get_numpy_value(val) - if val is None: + value = _get_numpy_value(val) + if value is None: return None - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) + # TODO: cleanup following checks, which seem redundant. But need to also ensure + # the invariant when setting the value (and also use clearly defined representation + # types in evaluators, such a reference-evaluator). + if isinstance(value, bool): + return value + if isinstance(value, np.bool_): + return bool(value) + if isinstance(value, np.ndarray) and value.size == 1 and value.dtype == bool: + return value.item(0) return None @@ -212,11 +215,14 @@ def _get_input_element_type(node: ir.Node, index: int) -> int: def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: if name in node.attributes: - attr = node.attributes[name].value - if isinstance(attr, int): - return attr - # This is an invalid model. For now, we just return None. - # We could raise an error too. 
+ attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, int): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. return None return default From aeadeae4966d456ba27995b774e1b5c2edc44abe Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 17 Jul 2024 10:54:59 -0700 Subject: [PATCH 20/22] Fix mypy issues --- onnxscript/optimizer/_constant_folding.py | 26 +++++++++++++---------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index d9da94609..292ade3af 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -164,7 +164,9 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: register = registry.register -def _get_numpy_value(val: ir.Value) -> np.ndarray | None: +def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: + if val is None: + return None const_value = val.const_value if const_value is not None: return const_value.numpy() @@ -201,9 +203,10 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None: return None -def _update_type(value: ir.Value, type: ir.TypeProtocol) -> None: - # TODO: merge types - value.type = type +def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None: + if type is not None: + # TODO: merge types + value.type = type def _get_input_element_type(node: ir.Node, index: int) -> int: @@ -254,6 +257,8 @@ def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: @register("Shape") def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] + if input is None: + return None shape = input.shape if shape is None: return None @@ -483,8 +488,8 @@ def __init__( self._init() def _init(self) -> None: - self.counts = {} - self.sizes = {} + self.counts: dict[str, int] = {} + self.sizes: dict[str, int] = {} self.modified = False self._state = OptimizerState() @@ -522,7 +527,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: ) output_types = onnx.shape_inference.infer_node_outputs( schema, ir.serde.serialize_node(node), input_types, input_data - ) + ) # type: ignore[arg-type] for output in node.outputs: if output.name in output_types: inferred_type = output_types[output.name] @@ -604,11 +609,10 @@ def process_node(self, node: ir.Node): if is_control_flow_op(node) or is_non_deterministic_op(node): return None - if any((x is not None and x.const_value is None) for x in node.inputs): + input_values = [_get_numpy_value(x) for x in node.inputs] + if any(x is None for x in input_values): return None - input_values = [x.const_value.numpy() if x is not None else None for x in node.inputs] - # Filter out bfloat16 cases? 
def convert(av): if isinstance(av, ir.AttrTensor): @@ -664,7 +668,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) # TODO: track statistics about replaced nodes and sizes of new constants - def visit_attribute(self, attr: ir.Attr) -> None: + def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: if isinstance(attr, ir.AttrGraph): self.visit_graph(attr.value) elif isinstance(attr, ir.AttrGraphs): From 1cb8689f6107f42d6fc0ff9f9355673b0061fcfc Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 17 Jul 2024 11:24:18 -0700 Subject: [PATCH 21/22] More mypy fixes --- onnxscript/optimizer/_constant_folding.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 292ade3af..7cf553275 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -526,8 +526,11 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: node.op_type, self.opset_imports[node.domain], node.domain ) output_types = onnx.shape_inference.infer_node_outputs( - schema, ir.serde.serialize_node(node), input_types, input_data - ) # type: ignore[arg-type] + schema, + ir.serde.serialize_node(node), + input_types, + input_data, # type: ignore[arg-type] + ) for output in node.outputs: if output.name in output_types: inferred_type = output_types[output.name] @@ -669,11 +672,12 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) # TODO: track statistics about replaced nodes and sizes of new constants def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: - if isinstance(attr, ir.AttrGraph): - self.visit_graph(attr.value) - elif isinstance(attr, ir.AttrGraphs): - for graph in attr.value: - self.visit_graph(graph) + if isinstance(attr, ir.Attr): + if attr.type == ir.AttributeType.GRAPH: + self.visit_graph(attr.value) # type: ignore[arg-type] + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + self.visit_graph(graph) # type: ignore[arg-type] def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function): replacement = self.process_node(node) From 61511ef123207834ea830ac3f70b7761476859a8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 17 Jul 2024 12:31:02 -0700 Subject: [PATCH 22/22] More mypy warnings fixed --- onnxscript/optimizer/_constant_folding.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 7cf553275..6140b06f7 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -272,7 +272,10 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: @register("Size") def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: - shape = node.inputs[0].shape + input = _get_input(node, 0) + if input is None: + return None + shape = input.shape if shape is None: return None size = 1 @@ -285,8 +288,8 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: @register("If") def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: - cond = _get_input(node, 0) - cond = _get_bool_value(cond) + cond_input = _get_input(node, 0) + cond = _get_bool_value(cond_input) if cond is not None: # cond is a constant-value: inline the branch branch = "then_branch" if cond else "else_branch" @@ -346,9 +349,9 @@ def concat_from_sequence(node: ir.Node, 
op, state: OptimizerState) -> ReturnValu if any(x is None for x in inputs): return None new_axis = _get_int_attribute(node, "new_axis", 0) - if "axis" not in node.attributes: + axis = _get_int_attribute(node, "axis", None) + if axis is None: return None - axis = node.attributes["axis"].value if input is not None and isinstance(inputs, list): if new_axis == 0: logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) @@ -400,6 +403,8 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None axis = _get_int_attribute(node, "axis", 0) + if axis is None: + return None shape = input.shape if shape is None: return None @@ -466,7 +471,7 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None position_val = position_val.item() try: - result = input_vals[position_val] + result = input_vals[position_val] # type: ignore[index] except IndexError: return None state.set_sym_value(output, result) @@ -528,7 +533,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: output_types = onnx.shape_inference.infer_node_outputs( schema, ir.serde.serialize_node(node), - input_types, + input_types, # type: ignore[arg-type] input_data, # type: ignore[arg-type] ) for output in node.outputs: