[IR] Improve from_proto/to_proto typing with overloads (#1992)
- Use typing.overload to annotate the from_proto and to_proto functions for accurate
type hinting. With this change we can recommend that users call
`ir.from_proto`/`ir.to_proto` instead of the `ir.serde.(de)serialize*` methods and still
keep mypy happy, which simplifies the serialization APIs for users (see the usage sketch below).
- Create deserialize_tensor_shape to deserialize tensor shapes.
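
For illustration (not part of this commit), a minimal sketch of what the overloads buy a caller; the model path is a hypothetical placeholder:

import onnx
from onnxscript import ir

model_proto = onnx.load("model.onnx")   # hypothetical path; yields an onnx.ModelProto
model = ir.from_proto(model_proto)      # mypy now infers the IR Model type instead of Any
graph_proto = ir.to_proto(model.graph)  # inferred as onnx.GraphProto
model_proto_back = ir.to_proto(model)   # inferred as onnx.ModelProto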
justinchuby authored Dec 31, 2024
1 parent 9f79317 commit 854b5d9
Showing 2 changed files with 83 additions and 41 deletions.
6 changes: 3 additions & 3 deletions docs/intermediate_representation/getting_started.ipynb
@@ -8,7 +8,7 @@
"# Getting started with ONNX IR 🌱\n",
"The ONNX IR ships with the ONNX Script package and is available as `onnxscript.ir`.\n",
"To create an IR object from ONNX file, load it as `ModelProto` and call\n",
"`ir.from_proto()` or `ir.serde.deserialize_model`:"
"`ir.from_proto()`:"
]
},
{
@@ -65,7 +65,7 @@
"model_proto = onnx.parser.parse_model(MODEL_TEXT)\n",
"\n",
"# Create an IR object from the model\n",
"model = ir.serde.deserialize_model(model_proto)"
"model = ir.from_proto(model_proto)"
]
},
{
@@ -347,7 +347,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_proto_back = ir.serde.serialize_model(model)"
"model_proto_back = ir.to_proto(model)"
]
},
{
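
Taken together, the updated notebook cells amount to the following round trip (a consolidated sketch, not notebook content verbatim; the tiny textual model is a placeholder standing in for the notebook's MODEL_TEXT):

import onnx
import onnx.parser
from onnxscript import ir

# Placeholder model in ONNX textual syntax.
MODEL_TEXT = """
<ir_version: 8, opset_import: ["" : 18]>
agraph (float[N] X) => (float[N] Y) {
    Y = Identity(X)
}
"""

model_proto = onnx.parser.parse_model(MODEL_TEXT)
model = ir.from_proto(model_proto)     # onnx.ModelProto -> IR model
# ... inspect or edit the IR model here ...
model_proto_back = ir.to_proto(model)  # IR model -> onnx.ModelProto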
118 changes: 80 additions & 38 deletions onnxscript/ir/serde.py
@@ -14,6 +14,7 @@
from __future__ import annotations

import functools
import typing

__all__ = [
    # Tensors
@@ -29,6 +30,7 @@
    "deserialize_node",
    "deserialize_opset_import",
    "deserialize_tensor",
    "deserialize_tensor_shape",
    "deserialize_type_proto_for_shape",
    "deserialize_type_proto_for_type",
    "deserialize_value_info_proto",
@@ -59,7 +61,6 @@
import collections
import logging
import os
import typing
from typing import Any, Callable, List, Mapping, Sequence

import numpy as np
@@ -121,16 +122,35 @@ def _unflatten_complex(
    return array[::2] + 1j * array[1::2]


def from_proto(
    proto: onnx.ModelProto
    | onnx.GraphProto
    | onnx.NodeProto
    | onnx.TensorProto
    | onnx.AttributeProto
    | onnx.ValueInfoProto
    | onnx.TypeProto
    | onnx.FunctionProto,
) -> Any:
@typing.overload
def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.GraphProto) -> _core.Graph: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.NodeProto) -> _core.Node: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.FunctionProto) -> _core.Function: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto( # type: ignore[overload-overlap]
    proto: onnx.TensorShapeProto.Dimension,
) -> tuple[int | _core.SymbolicDim, str | None]: ...
@typing.overload
def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ... # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ... # type: ignore[overload-overlap]


def from_proto(proto: object) -> object:
"""Deserialize an ONNX proto message to an IR object."""
if isinstance(proto, onnx.ModelProto):
return deserialize_model(proto)
@@ -151,24 +171,47 @@ def from_proto(
        )
    if isinstance(proto, onnx.FunctionProto):
        return deserialize_function(proto)
    if isinstance(proto, onnx.TensorShapeProto):
        return deserialize_tensor_shape(proto)
    if isinstance(proto, onnx.TensorShapeProto.Dimension):
        return deserialize_dimension(proto)
    if isinstance(proto, Sequence) and all(
        isinstance(p, onnx.OperatorSetIdProto) for p in proto
    ):
        return deserialize_opset_import(proto)
    if isinstance(proto, Sequence) and all(
        isinstance(p, onnx.StringStringEntryProto) for p in proto
    ):
        return deserialize_metadata_props(proto)
    raise NotImplementedError(
        f"Deserialization of {type(proto)} in from_proto is not implemented. "
        "Use a specific ir.serde.deserialize* function instead."
    )


def to_proto(
    ir_object: _protocols.ModelProtocol
    | _protocols.GraphProtocol
    | _protocols.NodeProtocol
    | _protocols.ValueProtocol
    | _protocols.AttributeProtocol
    | _protocols.ReferenceAttributeProtocol
    | _protocols.TensorProtocol
    | _protocols.TypeProtocol
    | _protocols.GraphViewProtocol
    | _protocols.FunctionProtocol,
) -> Any:
@typing.overload
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ... # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]


def to_proto(ir_object: object) -> object:
"""Serialize an IR object to a proto."""
if isinstance(ir_object, _protocols.ModelProtocol):
return serialize_model(ir_object)
@@ -665,29 +708,28 @@ def deserialize_value_info_proto(
    return value


@_capture_errors(str)
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
    # This logic handles when the shape is [] as well
    dim_protos = proto.dim
    deserialized_dim_denotations = [
        deserialize_dimension(dim_proto) for dim_proto in dim_protos
    ]
    dims = [dim for dim, _ in deserialized_dim_denotations]
    denotations = [denotation for _, denotation in deserialized_dim_denotations]
    return _core.Shape(dims, denotations=denotations, frozen=True)


@_capture_errors(str)
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
if proto.HasField("tensor_type"):
if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
return None
# This logic handles when the shape is [] as well
dim_protos = shape_proto.dim
deserialized_dim_denotations = [
deserialize_dimension(dim_proto) for dim_proto in dim_protos
]
dims = [dim for dim, _ in deserialized_dim_denotations]
denotations = [denotation for _, denotation in deserialized_dim_denotations]
return _core.Shape(dims, denotations=denotations, frozen=True)
return deserialize_tensor_shape(shape_proto)
if proto.HasField("sparse_tensor_type"):
if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None:
return None
dim_protos = shape_proto.dim
deserialized_dim_denotations = [
deserialize_dimension(dim_proto) for dim_proto in dim_protos
]
dims = [dim for dim, _ in deserialized_dim_denotations]
denotations = [denotation for _, denotation in deserialized_dim_denotations]
return _core.Shape(dims, denotations=denotations, frozen=True)
return deserialize_tensor_shape(shape_proto)
if proto.HasField("sequence_type"):
if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
return None
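
As a usage note (not part of the diff), the new deserialize_tensor_shape helper turns an onnx.TensorShapeProto into an IR Shape, and from_proto now dispatches to it for that proto type; a minimal sketch:

import onnx
from onnxscript.ir import serde

# Build a shape proto with one symbolic and one static dimension.
shape_proto = onnx.TensorShapeProto()
shape_proto.dim.add().dim_param = "batch"
shape_proto.dim.add().dim_value = 3

shape = serde.deserialize_tensor_shape(shape_proto)  # an IR Shape, roughly [batch, 3]
same = serde.from_proto(shape_proto)                  # dispatches to the same helper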
