diff --git a/docs/intermediate_representation/getting_started.ipynb b/docs/intermediate_representation/getting_started.ipynb index 4ababa4ea..68e1faaa7 100644 --- a/docs/intermediate_representation/getting_started.ipynb +++ b/docs/intermediate_representation/getting_started.ipynb @@ -8,7 +8,7 @@ "# Getting started with ONNX IR 🌱\n", "The ONNX IR ships with the ONNX Script package and is available as `onnxscript.ir`.\n", "To create an IR object from ONNX file, load it as `ModelProto` and call\n", - "`ir.from_proto()` or `ir.serde.deserialize_model`:" + "`ir.from_proto()`:" ] }, { @@ -65,7 +65,7 @@ "model_proto = onnx.parser.parse_model(MODEL_TEXT)\n", "\n", "# Create an IR object from the model\n", - "model = ir.serde.deserialize_model(model_proto)" + "model = ir.from_proto(model_proto)" ] }, { @@ -347,7 +347,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_proto_back = ir.serde.serialize_model(model)" + "model_proto_back = ir.to_proto(model)" ] }, { diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 079963df7..432af8cf1 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -14,6 +14,7 @@ from __future__ import annotations import functools +import typing __all__ = [ # Tensors @@ -29,6 +30,7 @@ "deserialize_node", "deserialize_opset_import", "deserialize_tensor", + "deserialize_tensor_shape", "deserialize_type_proto_for_shape", "deserialize_type_proto_for_type", "deserialize_value_info_proto", @@ -59,7 +61,6 @@ import collections import logging import os -import typing from typing import Any, Callable, List, Mapping, Sequence import numpy as np @@ -121,16 +122,35 @@ def _unflatten_complex( return array[::2] + 1j * array[1::2] -def from_proto( - proto: onnx.ModelProto - | onnx.GraphProto - | onnx.NodeProto - | onnx.TensorProto - | onnx.AttributeProto - | onnx.ValueInfoProto - | onnx.TypeProto - | onnx.FunctionProto, -) -> Any: +@typing.overload +def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.GraphProto) -> _core.Graph: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.NodeProto) -> _core.Node: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.FunctionProto) -> _core.Function: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto( # type: ignore[overload-overlap] + proto: onnx.TensorShapeProto.Dimension, +) -> tuple[int | _core.SymbolicDim, str | None]: ... +@typing.overload +def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ... # type: ignore[overload-overlap] +@typing.overload +def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ... # type: ignore[overload-overlap] + + +def from_proto(proto: object) -> object: """Deserialize an ONNX proto message to an IR object.""" if isinstance(proto, onnx.ModelProto): return deserialize_model(proto) @@ -151,24 +171,47 @@ def from_proto( ) if isinstance(proto, onnx.FunctionProto): return deserialize_function(proto) + if isinstance(proto, onnx.TensorShapeProto): + return deserialize_tensor_shape(proto) + if isinstance(proto, onnx.TensorShapeProto.Dimension): + return deserialize_dimension(proto) + if isinstance(proto, Sequence) and all( + isinstance(p, onnx.OperatorSetIdProto) for p in proto + ): + return deserialize_opset_import(proto) + if isinstance(proto, Sequence) and all( + isinstance(p, onnx.StringStringEntryProto) for p in proto + ): + return deserialize_metadata_props(proto) raise NotImplementedError( f"Deserialization of {type(proto)} in from_proto is not implemented. " "Use a specific ir.serde.deserialize* function instead." ) -def to_proto( - ir_object: _protocols.ModelProtocol - | _protocols.GraphProtocol - | _protocols.NodeProtocol - | _protocols.ValueProtocol - | _protocols.AttributeProtocol - | _protocols.ReferenceAttributeProtocol - | _protocols.TensorProtocol - | _protocols.TypeProtocol - | _protocols.GraphViewProtocol - | _protocols.FunctionProtocol, -) -> Any: +@typing.overload +def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ... # type: ignore[overload-overlap] +@typing.overload +def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap] + + +def to_proto(ir_object: object) -> object: """Serialize an IR object to a proto.""" if isinstance(ir_object, _protocols.ModelProtocol): return serialize_model(ir_object) @@ -665,29 +708,28 @@ def deserialize_value_info_proto( return value +@_capture_errors(str) +def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape: + # This logic handles when the shape is [] as well + dim_protos = proto.dim + deserialized_dim_denotations = [ + deserialize_dimension(dim_proto) for dim_proto in dim_protos + ] + dims = [dim for dim, _ in deserialized_dim_denotations] + denotations = [denotation for _, denotation in deserialized_dim_denotations] + return _core.Shape(dims, denotations=denotations, frozen=True) + + @_capture_errors(str) def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None: if proto.HasField("tensor_type"): if (shape_proto := _get_field(proto.tensor_type, "shape")) is None: return None - # This logic handles when the shape is [] as well - dim_protos = shape_proto.dim - deserialized_dim_denotations = [ - deserialize_dimension(dim_proto) for dim_proto in dim_protos - ] - dims = [dim for dim, _ in deserialized_dim_denotations] - denotations = [denotation for _, denotation in deserialized_dim_denotations] - return _core.Shape(dims, denotations=denotations, frozen=True) + return deserialize_tensor_shape(shape_proto) if proto.HasField("sparse_tensor_type"): if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None: return None - dim_protos = shape_proto.dim - deserialized_dim_denotations = [ - deserialize_dimension(dim_proto) for dim_proto in dim_protos - ] - dims = [dim for dim, _ in deserialized_dim_denotations] - denotations = [denotation for _, denotation in deserialized_dim_denotations] - return _core.Shape(dims, denotations=denotations, frozen=True) + return deserialize_tensor_shape(shape_proto) if proto.HasField("sequence_type"): if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None: return None