Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Test with onnx backend test files #1452

Merged
merged 16 commits into from
May 6, 2024
6 changes: 6 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
"AttrStrings",
"AttrTensor",
"AttrTensors",
"TypeAndShape",
"AttrTypeProto",
"AttrTypeProtos",
"SymbolicDim",
"ExternalTensor",
"StringTensor",
"Function",
"Graph",
"GraphView",
Expand Down Expand Up @@ -80,6 +83,7 @@
AttrTensor,
AttrTensors,
AttrTypeProto,
AttrTypeProtos,
ExternalTensor,
Function,
Graph,
Expand All @@ -92,9 +96,11 @@
SequenceType,
Shape,
SparseTensorType,
StringTensor,
SymbolicDim,
Tensor,
TensorType,
TypeAndShape,
Value,
)
from onnxscript.ir._enums import (
Expand Down
141 changes: 139 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
import contextlib
import dataclasses
import math
import mmap
import os
Expand Down Expand Up @@ -47,6 +48,7 @@
)

if typing.TYPE_CHECKING:
import numpy.typing as npt

Check warning on line 51 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L51

Added line #L51 was not covered by tests
from typing_extensions import TypeGuard

TArrayCompatible = typing.TypeVar(
Expand Down Expand Up @@ -576,6 +578,116 @@
return self._metadata


class StringTensor(TensorBase, _protocols.TensorProtocol):
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

__slots__ = (
"_raw",
"_shape",
"name",
"doc_string",
"_metadata_props",
"_metadata",
)

def __init__(
self,
value: Sequence[bytes] | npt.NDArray[np.bytes_],
*,
shape: Shape | None = None,
name: str = "",
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
"""Initialize a tensor.

Args:
value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
shape: The shape of the tensor. If None, the shape is obtained from the value.
name: The name of the tensor.
doc_string: The documentation string.
metadata_props: The metadata properties.
"""
if shape is None:
if not hasattr(value, "shape"):
raise ValueError(

Check warning on line 613 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L613

Added line #L613 was not covered by tests
f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
"Please specify the shape explicitly."
)
self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009

Check warning on line 617 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L617

Added line #L617 was not covered by tests
else:
self._shape = shape
self._shape._frozen = True
self._raw = value
self.name = name
self.doc_string = doc_string
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props = metadata_props

def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray):
return self._raw
assert isinstance(

Check warning on line 630 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L629-L630

Added lines #L629 - L630 were not covered by tests
self._raw, Sequence
), f"Bug: Expected a sequence, got {type(self._raw)}"
return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

Check warning on line 633 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L633

Added line #L633 was not covered by tests

def __dlpack__(self, *, stream: Any = None) -> Any:
del stream # unused
raise TypeError("StringTensor does not support DLPack")

Check warning on line 637 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L636-L637

Added lines #L636 - L637 were not covered by tests

def __dlpack_device__(self) -> tuple[int, int]:
raise TypeError("StringTensor does not support DLPack")

Check warning on line 640 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L640

Added line #L640 was not covered by tests

def __repr__(self) -> str:
    """Return the base representation followed by the raw data and the tensor name."""
    prefix = self._repr_base()
    return f"{prefix}({self._raw!r}, name={self.name!r})"

Check warning on line 643 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L643

Added line #L643 was not covered by tests

@property
def dtype(self) -> _enums.DataType:
    """Data type of the tensor: always ``DataType.STRING``."""
    string_dtype = _enums.DataType.STRING
    return string_dtype

@property
def shape(self) -> Shape:
"""The shape of the tensor. Immutable."""
return self._shape

@property
def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
"""Backing data of the tensor. Immutable."""
return self._raw # type: ignore[return-value]

Check warning on line 658 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L658

Added line #L658 was not covered by tests

def numpy(self) -> npt.NDArray[np.bytes_]:
"""Return the tensor as a numpy array."""
return self.__array__()

Check warning on line 662 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L662

Added line #L662 was not covered by tests

def tobytes(self) -> bytes:
    """Refuse to produce a bytes blob; the error message points callers to ``string_data``.

    Raises:
        ValueError: Always.
    """
    message = "StringTensor does not support tobytes. Use 'string_data' instead."
    raise ValueError(message)

Check warning on line 665 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L665

Added line #L665 was not covered by tests

def string_data(self) -> Sequence[bytes]:
"""Return the string data of the tensor."""
if isinstance(self._raw, np.ndarray):
return self._raw.flatten().tolist()

Check warning on line 670 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L670

Added line #L670 was not covered by tests
return self._raw

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
    """The metadata store for intermediate analysis, created on first access.

    Write to :attr:`metadata_props` instead if you would like the metadata
    to be serialized to the ONNX proto.
    """
    store = self._metadata
    if store is None:
        store = _metadata.MetadataStore()
        self._metadata = store
    return store

Check warning on line 688 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L687-L688

Added lines #L687 - L688 were not covered by tests


class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
__slots__ = ("_value",)

Expand Down Expand Up @@ -2408,17 +2520,42 @@
)


@dataclasses.dataclass
class TypeAndShape:
    """A type paired with a shape.

    Carries the two pieces needed to construct a type proto; it is the value
    held by ``AttrTypeProto`` and the element type of ``AttrTypeProtos``.
    Either member may be ``None``.
    """

    # The type part; None when no type is available.
    type: _protocols.TypeProtocol | None
    # The shape part; None when no shape is available.
    shape: Shape | None


class AttrTypeProto(_SpecializedAttr):
def __init__(
self,
name: str,
value: _protocols.TypeProtocol,
value: TypeAndShape,
doc_string: str | None = None,
):
# TODO(justinchuby): Include shape as well
super().__init__(
name,
_enums.AttributeType.TYPE_PROTO,
value,
doc_string=doc_string,
)


class AttrTypeProtos(_SpecializedAttr):
def __init__(
self,
name: str,
value: Sequence[TypeAndShape],
doc_string: str | None = None,
):
super().__init__(

Check warning on line 2556 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L2556

Added line #L2556 was not covered by tests
name,
_enums.AttributeType.TYPE_PROTOS,
value,
doc_string=doc_string,
)
87 changes: 80 additions & 7 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@
inputs,
outputs,
nodes=nodes,
# TODO(justinchuby): Attach the values associated with the initializers
initializers=initializers,
doc_string=_get_field(proto, "doc_string"),
name=_get_field(proto, "name"),
Expand Down Expand Up @@ -633,6 +634,17 @@
doc_string=proto.doc_string,
metadata_props=deserialize_metadata_props(proto.metadata_props),
)
if proto.data_type == _enums.DataType.STRING:
name = _get_field(proto, "name")
doc_string = _get_field(proto, "doc_string")
metadata_props = deserialize_metadata_props(proto.metadata_props)
return _core.StringTensor(
proto.string_data,
shape=_core.Shape(proto.dims),
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)
return TensorProtoTensor(proto)


Expand Down Expand Up @@ -691,7 +703,25 @@
[_deserialize_graph(g, scoped_values) for g in proto.graphs],
doc_string=doc_string,
)
# TODO: Handle type protos etc.
if type_ == _enums.AttributeType.SPARSE_TENSOR:
raise NotImplementedError("Sparse tensors are not supported yet")

Check warning on line 707 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L707

Added line #L707 was not covered by tests
if type_ == _enums.AttributeType.SPARSE_TENSORS:
raise NotImplementedError("Sparse tensors are not supported yet")

Check warning on line 709 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L709

Added line #L709 was not covered by tests
if type_ == _enums.AttributeType.TYPE_PROTO:
ir_type = deserialize_type_proto_for_type(proto.tp)
shape = deserialize_type_proto_for_shape(proto.tp)
return _core.AttrTypeProto(
name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string
)
if type_ == _enums.AttributeType.TYPE_PROTOS:
type_and_shapes = []

Check warning on line 717 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L717

Added line #L717 was not covered by tests
for type_proto in proto.type_protos:
ir_type = deserialize_type_proto_for_type(type_proto)
shape = deserialize_type_proto_for_shape(type_proto)
type_and_shapes.append(_core.TypeAndShape(ir_type, shape))
return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string)

Check warning on line 722 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L719-L722

Added lines #L719 - L722 were not covered by tests
if type_ == _enums.AttributeType.UNDEFINED:
return _core.Attr(name, type_, None, doc_string=doc_string)

Check warning on line 724 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L724

Added line #L724 was not covered by tests
raise ValueError(f"Unsupported attribute type: '{type_}'")


Expand Down Expand Up @@ -1078,6 +1108,8 @@
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
elif isinstance(from_, _core.StringTensor):
tensor_proto.string_data.extend(from_.string_data())
else:
tensor_proto.raw_data = from_.tobytes()
_serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
Expand All @@ -1102,37 +1134,69 @@
attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
) -> None:
if type_ == _enums.AttributeType.INT:
# value: int
attribute_proto.i = value
attribute_proto.type = onnx.AttributeProto.INT
elif type_ == _enums.AttributeType.FLOAT:
# value: float
attribute_proto.f = value
attribute_proto.type = onnx.AttributeProto.FLOAT
elif type_ == _enums.AttributeType.STRING:
# value: str
attribute_proto.s = value.encode("utf-8")
attribute_proto.type = onnx.AttributeProto.STRING
elif type_ == _enums.AttributeType.INTS:
# value: Sequence[int]
attribute_proto.ints.extend(value)
attribute_proto.type = onnx.AttributeProto.INTS
elif type_ == _enums.AttributeType.FLOATS:
# value: Sequence[float]
attribute_proto.floats.extend(value)
attribute_proto.type = onnx.AttributeProto.FLOATS
elif type_ == _enums.AttributeType.STRINGS:
# value: Sequence[str]
attribute_proto.strings.extend([s.encode("utf-8") for s in value])
attribute_proto.type = onnx.AttributeProto.STRINGS
elif type_ == _enums.AttributeType.TENSOR:
# value: _protocols.TensorProtocol
serialize_tensor_into(attribute_proto.t, value)
attribute_proto.type = onnx.AttributeProto.TENSOR
elif type_ == _enums.AttributeType.GRAPH:
# value: _protocols.GraphProtocol
serialize_graph_into(attribute_proto.g, value)
attribute_proto.type = onnx.AttributeProto.GRAPH
elif type_ == _enums.AttributeType.TENSORS:
# value: Sequence[_protocols.TensorProtocol]
for tensor in value:
serialize_tensor_into(attribute_proto.tensors.add(), tensor)
attribute_proto.type = onnx.AttributeProto.TENSORS
elif type_ == _enums.AttributeType.GRAPHS:
# value: Sequence[_protocols.GraphProtocol]
for graph in value:
serialize_graph_into(attribute_proto.graphs.add(), graph)
attribute_proto.type = onnx.AttributeProto.GRAPHS
elif type_ == _enums.AttributeType.SPARSE_TENSOR:
raise NotImplementedError("Sparse tensors are not supported yet")

Check warning on line 1179 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1179

Added line #L1179 was not covered by tests
elif type_ == _enums.AttributeType.SPARSE_TENSORS:
raise NotImplementedError("Sparse tensors are not supported yet")

Check warning on line 1181 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1181

Added line #L1181 was not covered by tests
elif type_ == _enums.AttributeType.TYPE_PROTO:
# value: _core.TypeAndShape
if value.type is not None:
serialize_type_into(attribute_proto.tp, value.type)
# Need to create the type _before_ writing the shape
if value.shape is not None:
serialize_shape_into(attribute_proto.tp, value.shape)
attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
elif type_ == _enums.AttributeType.TYPE_PROTOS:
for ir_type in value:
# ir_type: _core.TypeAndShape
type_proto = attribute_proto.type_protos.add()

Check warning on line 1193 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1193

Added line #L1193 was not covered by tests
if ir_type.type is not None:
serialize_type_into(type_proto, ir_type.type)

Check warning on line 1195 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1195

Added line #L1195 was not covered by tests
# Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
if ir_type.shape is not None:
serialize_shape_into(type_proto, ir_type.shape)
attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS

Check warning on line 1199 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1198-L1199

Added lines #L1198 - L1199 were not covered by tests
else:
raise TypeError(f"Unsupported attribute type: {type_}")

Expand Down Expand Up @@ -1179,10 +1243,11 @@
value_info_proto.name = from_.name
if from_.metadata_props:
_serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
if from_.shape is not None:
serialize_shape_into(value_info_proto.type, from_.shape)
if from_.type is not None:
serialize_type_into(value_info_proto.type, from_.type)
# Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this dependency explicit in the code, e.g. by nesting the `from_.shape` check under the `from_.type is not None` branch?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean line 1238-1239?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you are saying. Currently, if the type is None, we just assume a tensor type with an unknown dtype and a known shape. I think that's OK?

if from_.shape is not None:
serialize_shape_into(value_info_proto.type, from_.shape)


def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
Expand All @@ -1205,12 +1270,18 @@


def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
tensor_type_proto = type_proto.tensor_type
value_field = type_proto.WhichOneof("value")
tensor_type = getattr(type_proto, value_field)
while not isinstance(tensor_type.elem_type, int):
# Find the leaf type that has the shape field
type_proto = tensor_type.elem_type
value_field = type_proto.WhichOneof("value")
tensor_type = getattr(type_proto, value_field)
# When from is empty, we still need to set the shape field to an empty list by touching it
tensor_type_proto.shape.ClearField("dim")
tensor_type.shape.ClearField("dim")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check naming

for i, dim in enumerate(from_):
denotation = from_.get_denotation(i)
serialize_dimension_into(tensor_type_proto.shape.dim.add(), dim, denotation)
serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation)


def serialize_dimension_into(
Expand All @@ -1223,4 +1294,6 @@
if isinstance(dim, int):
dim_proto.dim_value = dim
elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
dim_proto.dim_param = str(dim.value)
if dim.value is not None:
# TODO(justinchuby): None is probably not a valid value for dim_param
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do something when it's None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we are OK doing nothing. The None came from a test case. We should add a pass to check that no dims are None in rewriter-produced values, etc.

dim_proto.dim_param = str(dim.value)
Loading
Loading