Merge branch 'main' into justinchu/fix-ci
justinchuby authored Jan 4, 2025
2 parents 82eecf3 + fa191bb commit 135d5fd
Showing 6 changed files with 137 additions and 60 deletions.
6 changes: 3 additions & 3 deletions docs/intermediate_representation/getting_started.ipynb
@@ -8,7 +8,7 @@
"# Getting started with ONNX IR 🌱\n",
"The ONNX IR ships with the ONNX Script package and is available as `onnxscript.ir`.\n",
"To create an IR object from ONNX file, load it as `ModelProto` and call\n",
"`ir.from_proto()` or `ir.serde.deserialize_model`:"
"`ir.from_proto()`:"
]
},
{
@@ -65,7 +65,7 @@
"model_proto = onnx.parser.parse_model(MODEL_TEXT)\n",
"\n",
"# Create an IR object from the model\n",
"model = ir.serde.deserialize_model(model_proto)"
"model = ir.from_proto(model_proto)"
]
},
{
@@ -347,7 +347,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_proto_back = ir.serde.serialize_model(model)"
"model_proto_back = ir.to_proto(model)"
]
},
{
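For reference, the renamed convenience entry points round-trip a model like this (a minimal sketch; the model path is a placeholder):

import onnx

from onnxscript import ir

# Load any ONNX model; "model.onnx" stands in for a real path.
model_proto = onnx.load("model.onnx")

# ir.from_proto() is the spelling the docs now use instead of
# ir.serde.deserialize_model().
model = ir.from_proto(model_proto)

# ir.to_proto() is the inverse, replacing ir.serde.serialize_model().
model_proto_back = ir.to_proto(model)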
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -8085,7 +8085,7 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
 @torch_op("aten::sym_size.int", trace_only=True)
 def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
     """sym_size.int(Tensor self, int dim) -> SymInt"""
-    return op.Shape(self, end=dim + 1, start=dim)
+    return op.Squeeze(op.Shape(self, end=dim + 1, start=dim))


def aten_symeig(
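The added Squeeze matters because op.Shape called with start/end bounds yields a rank-1 tensor holding one element, while sym_size.int must produce a scalar SymInt. A numpy sketch of the shapes involved (values are illustrative):

import numpy as np

x = np.zeros((2, 3, 4))
dim = 1

# op.Shape(self, start=dim, end=dim + 1) corresponds to a rank-1 slice: array([3])
shape_slice = np.array(x.shape)[dim : dim + 1]

# op.Squeeze collapses it to the rank-0 scalar the caller expects: array(3)
sym_size = shape_slice.squeeze()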
38 changes: 23 additions & 15 deletions onnxscript/ir/_convenience.py
@@ -20,7 +20,7 @@
import numpy as np
import onnx

-from onnxscript.ir import _core, _enums, _protocols, serde
+from onnxscript.ir import _core, _enums, _protocols, serde, tensor_adapters

if typing.TYPE_CHECKING:
import numpy.typing as npt
@@ -321,6 +321,9 @@ def tensor(
>>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
>>> tp_tensor.numpy()
array(0.5, dtype=float32)
+>>> import torch
+>>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
+TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')
Args:
value: The numpy array to create the tensor from.
@@ -353,22 +356,27 @@
f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
"You do not have to specify the dtype when value is a TensorProto."
)
return tensor_
elif str(type(value)) == "<class 'torch.Tensor'>":
# NOTE: We use str(type(...)) and do not import torch for type checking
# as it creates overhead during import
return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type]
elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
return _core.Tensor(value, dtype=dtype, name=name, doc_string=name)

# Plain Python object
if dtype is not None:
numpy_dtype = dtype.numpy()
else:
if dtype is not None:
numpy_dtype = dtype.numpy()
else:
numpy_dtype = None
array = np.array(value, dtype=numpy_dtype)
tensor_ = _core.Tensor(
array,
dtype=dtype,
shape=_core.Shape(array.shape),
name=name,
doc_string=name,
)
return tensor_
numpy_dtype = None
array = np.array(value, dtype=numpy_dtype)
return _core.Tensor(
array,
dtype=dtype,
shape=_core.Shape(array.shape),
name=name,
doc_string=name,
)


def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
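The str(type(...)) comparison is the load-bearing trick in this hunk: it recognizes torch.Tensor without importing torch at module load time, deferring the expensive import until a torch tensor actually appears. A standalone sketch of the pattern (the helper name is hypothetical, not part of the module):

def _is_torch_tensor(value: object) -> bool:
    # Comparing the fully qualified type name avoids isinstance(),
    # which would require `import torch` just to obtain the class object.
    return str(type(value)) == "<class 'torch.Tensor'>"

One trade-off: subclasses of torch.Tensor produce a different type string, so they fall through to the DLPack/array-compatible branches instead.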
22 changes: 22 additions & 0 deletions onnxscript/ir/_convenience_test.py
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Unit tests for the _convenience module."""
+
+import unittest
+
+import numpy as np
+
+from onnxscript.ir import _convenience
+
+
+class ConvenienceTest(unittest.TestCase):
+    def test_tensor_accepts_torch_tensor(self):
+        import torch as some_random_name  # pylint: disable=import-outside-toplevel
+
+        torch_tensor = some_random_name.tensor([1, 2, 3])
+        tensor = _convenience.tensor(torch_tensor)
+        np.testing.assert_array_equal(tensor, torch_tensor.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
118 changes: 80 additions & 38 deletions onnxscript/ir/serde.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 import functools
+import typing

 __all__ = [
     # Tensors
@@ -29,6 +30,7 @@
     "deserialize_node",
     "deserialize_opset_import",
     "deserialize_tensor",
+    "deserialize_tensor_shape",
     "deserialize_type_proto_for_shape",
     "deserialize_type_proto_for_type",
     "deserialize_value_info_proto",
@@ -59,7 +61,6 @@
 import collections
 import logging
 import os
-import typing
 from typing import Any, Callable, List, Mapping, Sequence

import numpy as np
@@ -121,16 +122,35 @@ def _unflatten_complex(
     return array[::2] + 1j * array[1::2]
 
 
-def from_proto(
-    proto: onnx.ModelProto
-    | onnx.GraphProto
-    | onnx.NodeProto
-    | onnx.TensorProto
-    | onnx.AttributeProto
-    | onnx.ValueInfoProto
-    | onnx.TypeProto
-    | onnx.FunctionProto,
-) -> Any:
+@typing.overload
+def from_proto(proto: onnx.ModelProto) -> _core.Model: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.GraphProto) -> _core.Graph: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.NodeProto) -> _core.Node: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.FunctionProto) -> _core.Function: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(  # type: ignore[overload-overlap]
+    proto: onnx.TensorShapeProto.Dimension,
+) -> tuple[int | _core.SymbolicDim, str | None]: ...
+@typing.overload
+def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ...  # type: ignore[overload-overlap]
+@typing.overload
+def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ...  # type: ignore[overload-overlap]
+
+
+def from_proto(proto: object) -> object:
     """Deserialize an ONNX proto message to an IR object."""
     if isinstance(proto, onnx.ModelProto):
         return deserialize_model(proto)
@@ -151,24 +171,47 @@ def from_proto(
     )
     if isinstance(proto, onnx.FunctionProto):
         return deserialize_function(proto)
+    if isinstance(proto, onnx.TensorShapeProto):
+        return deserialize_tensor_shape(proto)
+    if isinstance(proto, onnx.TensorShapeProto.Dimension):
+        return deserialize_dimension(proto)
+    if isinstance(proto, Sequence) and all(
+        isinstance(p, onnx.OperatorSetIdProto) for p in proto
+    ):
+        return deserialize_opset_import(proto)
+    if isinstance(proto, Sequence) and all(
+        isinstance(p, onnx.StringStringEntryProto) for p in proto
+    ):
+        return deserialize_metadata_props(proto)
     raise NotImplementedError(
         f"Deserialization of {type(proto)} in from_proto is not implemented. "
         "Use a specific ir.serde.deserialize* function instead."
     )
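The overload stubs exist purely for type checkers; at runtime the single implementation above dispatches on isinstance. A sketch of how this reads at a call site (the model text is illustrative):

import onnx
import onnx.parser

from onnxscript.ir import serde

model_proto = onnx.parser.parse_model(
    """
    <ir_version: 8, opset_import: ["" : 18]>
    agraph (float[N] x) => (float[N] y)
    { y = Identity(x) }
    """
)

# A type checker narrows this to the IR Model type rather than the plain
# `object` of the runtime signature.
model = serde.from_proto(model_proto)

# Sequences dispatch on element type: opset-import protos become a
# {domain: version} dict, e.g. {"": 18}.
opset_imports = serde.from_proto([onnx.helper.make_opsetid("", 18)])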


-def to_proto(
-    ir_object: _protocols.ModelProtocol
-    | _protocols.GraphProtocol
-    | _protocols.NodeProtocol
-    | _protocols.ValueProtocol
-    | _protocols.AttributeProtocol
-    | _protocols.ReferenceAttributeProtocol
-    | _protocols.TensorProtocol
-    | _protocols.TypeProtocol
-    | _protocols.GraphViewProtocol
-    | _protocols.FunctionProtocol,
-) -> Any:
+@typing.overload
+def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ...  # type: ignore[overload-overlap]
+@typing.overload
+def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ...  # type: ignore[overload-overlap]
+
+
+def to_proto(ir_object: object) -> object:
     """Serialize an IR object to a proto."""
     if isinstance(ir_object, _protocols.ModelProtocol):
         return serialize_model(ir_object)
@@ -665,29 +708,28 @@ def deserialize_value_info_proto(
     return value
 
 
+@_capture_errors(str)
+def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
+    # This logic handles when the shape is [] as well
+    dim_protos = proto.dim
+    deserialized_dim_denotations = [
+        deserialize_dimension(dim_proto) for dim_proto in dim_protos
+    ]
+    dims = [dim for dim, _ in deserialized_dim_denotations]
+    denotations = [denotation for _, denotation in deserialized_dim_denotations]
+    return _core.Shape(dims, denotations=denotations, frozen=True)
+
+
 @_capture_errors(str)
 def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
     if proto.HasField("tensor_type"):
         if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
             return None
-        # This logic handles when the shape is [] as well
-        dim_protos = shape_proto.dim
-        deserialized_dim_denotations = [
-            deserialize_dimension(dim_proto) for dim_proto in dim_protos
-        ]
-        dims = [dim for dim, _ in deserialized_dim_denotations]
-        denotations = [denotation for _, denotation in deserialized_dim_denotations]
-        return _core.Shape(dims, denotations=denotations, frozen=True)
+        return deserialize_tensor_shape(shape_proto)
     if proto.HasField("sparse_tensor_type"):
         if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None:
             return None
-        dim_protos = shape_proto.dim
-        deserialized_dim_denotations = [
-            deserialize_dimension(dim_proto) for dim_proto in dim_protos
-        ]
-        dims = [dim for dim, _ in deserialized_dim_denotations]
-        denotations = [denotation for _, denotation in deserialized_dim_denotations]
-        return _core.Shape(dims, denotations=denotations, frozen=True)
+        return deserialize_tensor_shape(shape_proto)
     if proto.HasField("sequence_type"):
         if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
             return None
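The refactor also makes the shape logic reusable on its own: deserialize_tensor_shape is now public (note the __all__ addition above). A minimal sketch on a hand-built proto (dimensions are illustrative):

import onnx

from onnxscript.ir import serde

# One symbolic and one static dimension, i.e. [N, 3].
shape_proto = onnx.TensorShapeProto()
shape_proto.dim.add().dim_param = "N"
shape_proto.dim.add().dim_value = 3

# Returns a frozen Shape whose dims can no longer be mutated in place.
shape = serde.deserialize_tensor_shape(shape_proto)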
11 changes: 8 additions & 3 deletions onnxscript/ir/tensor_adapters.py
@@ -38,13 +38,16 @@
 import numpy.typing as npt
 
 from onnxscript import ir
+from onnxscript.ir import _core
 
 if TYPE_CHECKING:
     import torch
 
 
-class TorchTensor(ir.Tensor):
-    def __init__(self, tensor: torch.Tensor, name: str | None = None):
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
         # Pass the tensor as the raw data to ir.Tensor's constructor
         import torch
 
@@ -69,7 +72,9 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
-        super().__init__(tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name)
+        super().__init__(
+            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+        )
 
     def numpy(self) -> npt.NDArray:
         import torch
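With the widened constructor, a TorchTensor can now carry the same metadata as any other IR tensor. A quick sketch (names and values are illustrative):

import torch

from onnxscript.ir import tensor_adapters

t = tensor_adapters.TorchTensor(
    torch.tensor([1.0, 2.0]), name="w", doc_string="example weight"
)
print(t.numpy())  # array([1., 2.], dtype=float32)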
