Skip to content

Commit

Permalink
[IR] Test with onnx backend test files (#1452)
Browse files Browse the repository at this point in the history
- Run IR round-trip test on all ONNX backend test models.
- Fix support for shape serialization when the value type is Optional or
Sequence.
- Support TypeProto Attributes.
- Support String Tensors

Not implemented: #1430
  • Loading branch information
justinchuby authored May 6, 2024
1 parent e0e96d8 commit bc818e7
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 12 deletions.
6 changes: 6 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
"AttrStrings",
"AttrTensor",
"AttrTensors",
"TypeAndShape",
"AttrTypeProto",
"AttrTypeProtos",
"SymbolicDim",
"ExternalTensor",
"StringTensor",
"Function",
"Graph",
"GraphView",
Expand Down Expand Up @@ -80,6 +83,7 @@
AttrTensor,
AttrTensors,
AttrTypeProto,
AttrTypeProtos,
ExternalTensor,
Function,
Graph,
Expand All @@ -92,9 +96,11 @@
SequenceType,
Shape,
SparseTensorType,
StringTensor,
SymbolicDim,
Tensor,
TensorType,
TypeAndShape,
Value,
)
from onnxscript.ir._enums import (
Expand Down
141 changes: 139 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
import contextlib
import dataclasses
import math
import mmap
import os
Expand Down Expand Up @@ -47,6 +48,7 @@
)

if typing.TYPE_CHECKING:
import numpy.typing as npt
from typing_extensions import TypeGuard

TArrayCompatible = typing.TypeVar(
Expand Down Expand Up @@ -576,6 +578,116 @@ def meta(self) -> _metadata.MetadataStore:
return self._metadata


class StringTensor(TensorBase, _protocols.TensorProtocol):
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

__slots__ = (
"_raw",
"_shape",
"name",
"doc_string",
"_metadata_props",
"_metadata",
)

def __init__(
self,
value: Sequence[bytes] | npt.NDArray[np.bytes_],
*,
shape: Shape | None = None,
name: str = "",
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
"""Initialize a tensor.
Args:
value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
shape: The shape of the tensor. If None, the shape is obtained from the value.
name: The name of the tensor.
doc_string: The documentation string.
metadata_props: The metadata properties.
"""
if shape is None:
if not hasattr(value, "shape"):
raise ValueError(
f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
"Please specify the shape explicitly."
)
self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
else:
self._shape = shape
self._shape._frozen = True
self._raw = value
self.name = name
self.doc_string = doc_string
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props = metadata_props

def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray):
return self._raw
assert isinstance(
self._raw, Sequence
), f"Bug: Expected a sequence, got {type(self._raw)}"
return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

def __dlpack__(self, *, stream: Any = None) -> Any:
del stream # unused
raise TypeError("StringTensor does not support DLPack")

def __dlpack_device__(self) -> tuple[int, int]:
raise TypeError("StringTensor does not support DLPack")

def __repr__(self) -> str:
return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

@property
def dtype(self) -> _enums.DataType:
"""The data type of the tensor. Immutable."""
return _enums.DataType.STRING

@property
def shape(self) -> Shape:
"""The shape of the tensor. Immutable."""
return self._shape

@property
def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
"""Backing data of the tensor. Immutable."""
return self._raw # type: ignore[return-value]

def numpy(self) -> npt.NDArray[np.bytes_]:
"""Return the tensor as a numpy array."""
return self.__array__()

def tobytes(self) -> bytes:
raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")

def string_data(self) -> Sequence[bytes]:
"""Return the string data of the tensor."""
if isinstance(self._raw, np.ndarray):
return self._raw.flatten().tolist()
return self._raw

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
__slots__ = ("_value",)

Expand Down Expand Up @@ -2408,17 +2520,42 @@ def __init__(
)


@dataclasses.dataclass
class TypeAndShape:
"""Type and shape.
Useful for constructing a type proto.
"""

type: _protocols.TypeProtocol | None
shape: Shape | None


class AttrTypeProto(_SpecializedAttr):
def __init__(
self,
name: str,
value: _protocols.TypeProtocol,
value: TypeAndShape,
doc_string: str | None = None,
):
# TODO(justinchuby): Include shape as well
super().__init__(
name,
_enums.AttributeType.TYPE_PROTO,
value,
doc_string=doc_string,
)


class AttrTypeProtos(_SpecializedAttr):
def __init__(
self,
name: str,
value: Sequence[TypeAndShape],
doc_string: str | None = None,
):
super().__init__(
name,
_enums.AttributeType.TYPE_PROTOS,
value,
doc_string=doc_string,
)
87 changes: 80 additions & 7 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def _deserialize_graph(
inputs,
outputs,
nodes=nodes,
# TODO(justinchuby): Attach the values associated with the initializers
initializers=initializers,
doc_string=_get_field(proto, "doc_string"),
name=_get_field(proto, "name"),
Expand Down Expand Up @@ -633,6 +634,17 @@ def deserialize_tensor(
doc_string=proto.doc_string,
metadata_props=deserialize_metadata_props(proto.metadata_props),
)
if proto.data_type == _enums.DataType.STRING:
name = _get_field(proto, "name")
doc_string = _get_field(proto, "doc_string")
metadata_props = deserialize_metadata_props(proto.metadata_props)
return _core.StringTensor(
proto.string_data,
shape=_core.Shape(proto.dims),
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)
return TensorProtoTensor(proto)


Expand Down Expand Up @@ -691,7 +703,25 @@ def _deserialize_attribute(
[_deserialize_graph(g, scoped_values) for g in proto.graphs],
doc_string=doc_string,
)
# TODO: Handle type protos etc.
if type_ == _enums.AttributeType.SPARSE_TENSOR:
raise NotImplementedError("Sparse tensors are not supported yet")
if type_ == _enums.AttributeType.SPARSE_TENSORS:
raise NotImplementedError("Sparse tensors are not supported yet")
if type_ == _enums.AttributeType.TYPE_PROTO:
ir_type = deserialize_type_proto_for_type(proto.tp)
shape = deserialize_type_proto_for_shape(proto.tp)
return _core.AttrTypeProto(
name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string
)
if type_ == _enums.AttributeType.TYPE_PROTOS:
type_and_shapes = []
for type_proto in proto.type_protos:
ir_type = deserialize_type_proto_for_type(type_proto)
shape = deserialize_type_proto_for_shape(type_proto)
type_and_shapes.append(_core.TypeAndShape(ir_type, shape))
return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string)
if type_ == _enums.AttributeType.UNDEFINED:
return _core.Attr(name, type_, None, doc_string=doc_string)
raise ValueError(f"Unsupported attribute type: '{type_}'")


Expand Down Expand Up @@ -1078,6 +1108,8 @@ def serialize_tensor_into(
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
elif isinstance(from_, _core.StringTensor):
tensor_proto.string_data.extend(from_.string_data())
else:
tensor_proto.raw_data = from_.tobytes()
_serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
Expand All @@ -1102,37 +1134,69 @@ def _fill_in_value_for_attribute(
attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
) -> None:
if type_ == _enums.AttributeType.INT:
# value: int
attribute_proto.i = value
attribute_proto.type = onnx.AttributeProto.INT
elif type_ == _enums.AttributeType.FLOAT:
# value: float
attribute_proto.f = value
attribute_proto.type = onnx.AttributeProto.FLOAT
elif type_ == _enums.AttributeType.STRING:
# value: str
attribute_proto.s = value.encode("utf-8")
attribute_proto.type = onnx.AttributeProto.STRING
elif type_ == _enums.AttributeType.INTS:
# value: Sequence[int]
attribute_proto.ints.extend(value)
attribute_proto.type = onnx.AttributeProto.INTS
elif type_ == _enums.AttributeType.FLOATS:
# value: Sequence[float]
attribute_proto.floats.extend(value)
attribute_proto.type = onnx.AttributeProto.FLOATS
elif type_ == _enums.AttributeType.STRINGS:
# value: Sequence[str]
attribute_proto.strings.extend([s.encode("utf-8") for s in value])
attribute_proto.type = onnx.AttributeProto.STRINGS
elif type_ == _enums.AttributeType.TENSOR:
# value: _protocols.TensorProtocol
serialize_tensor_into(attribute_proto.t, value)
attribute_proto.type = onnx.AttributeProto.TENSOR
elif type_ == _enums.AttributeType.GRAPH:
# value: _protocols.GraphProtocol
serialize_graph_into(attribute_proto.g, value)
attribute_proto.type = onnx.AttributeProto.GRAPH
elif type_ == _enums.AttributeType.TENSORS:
# value: Sequence[_protocols.TensorProtocol]
for tensor in value:
serialize_tensor_into(attribute_proto.tensors.add(), tensor)
attribute_proto.type = onnx.AttributeProto.TENSORS
elif type_ == _enums.AttributeType.GRAPHS:
# value: Sequence[_protocols.GraphProtocol]
for graph in value:
serialize_graph_into(attribute_proto.graphs.add(), graph)
attribute_proto.type = onnx.AttributeProto.GRAPHS
elif type_ == _enums.AttributeType.SPARSE_TENSOR:
raise NotImplementedError("Sparse tensors are not supported yet")
elif type_ == _enums.AttributeType.SPARSE_TENSORS:
raise NotImplementedError("Sparse tensors are not supported yet")
elif type_ == _enums.AttributeType.TYPE_PROTO:
# value: _core.TypeAndShape
if value.type is not None:
serialize_type_into(attribute_proto.tp, value.type)
# Need to create the type _before_ writing the shape
if value.shape is not None:
serialize_shape_into(attribute_proto.tp, value.shape)
attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
elif type_ == _enums.AttributeType.TYPE_PROTOS:
for ir_type in value:
# ir_type: _core.TypeAndShape
type_proto = attribute_proto.type_protos.add()
if ir_type.type is not None:
serialize_type_into(type_proto, ir_type.type)
# Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
if ir_type.shape is not None:
serialize_shape_into(type_proto, ir_type.shape)
attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS
else:
raise TypeError(f"Unsupported attribute type: {type_}")

Expand Down Expand Up @@ -1179,10 +1243,11 @@ def serialize_value_into(
value_info_proto.name = from_.name
if from_.metadata_props:
_serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
if from_.shape is not None:
serialize_shape_into(value_info_proto.type, from_.shape)
if from_.type is not None:
serialize_type_into(value_info_proto.type, from_.type)
# Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
if from_.shape is not None:
serialize_shape_into(value_info_proto.type, from_.shape)


def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
Expand All @@ -1205,12 +1270,18 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc


def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
tensor_type_proto = type_proto.tensor_type
value_field = type_proto.WhichOneof("value")
tensor_type = getattr(type_proto, value_field)
while not isinstance(tensor_type.elem_type, int):
# Find the leaf type that has the shape field
type_proto = tensor_type.elem_type
value_field = type_proto.WhichOneof("value")
tensor_type = getattr(type_proto, value_field)
# When from is empty, we still need to set the shape field to an empty list by touching it
tensor_type_proto.shape.ClearField("dim")
tensor_type.shape.ClearField("dim")
for i, dim in enumerate(from_):
denotation = from_.get_denotation(i)
serialize_dimension_into(tensor_type_proto.shape.dim.add(), dim, denotation)
serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation)


def serialize_dimension_into(
Expand All @@ -1223,4 +1294,6 @@ def serialize_dimension_into(
if isinstance(dim, int):
dim_proto.dim_value = dim
elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
dim_proto.dim_param = str(dim.value)
if dim.value is not None:
# TODO(justinchuby): None is probably not a valid value for dim_param
dim_proto.dim_param = str(dim.value)
Loading

0 comments on commit bc818e7

Please sign in to comment.