Skip to content

Commit

Permalink
Support string tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Apr 26, 2024
1 parent 4efc5a5 commit db08661
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
111 changes: 111 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)

if typing.TYPE_CHECKING:
import numpy.typing as npt
from typing_extensions import TypeGuard

TArrayCompatible = typing.TypeVar(
Expand Down Expand Up @@ -454,6 +455,116 @@ def meta(self) -> _metadata.MetadataStore:
return self._metadata


class StringTensor(TensorBase, _protocols.TensorProtocol):
    """Multidimensional array of strings (as binary data to match the string_data field in TensorProto).

    The elements are stored as ``bytes``, either in a numpy array or in a flat
    Python sequence accompanied by an explicit shape. The tensor's dtype is
    always ``DataType.STRING``.
    """

    __slots__ = (
        "_raw",
        "_shape",
        "name",
        "doc_string",
        "_metadata_props",
        "_metadata",
    )

    def __init__(
        self,
        value: Sequence[bytes] | npt.NDArray[np.bytes_],
        *,
        shape: Shape | None = None,
        name: str = "",
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array or a
                sequence of bytes objects.
            shape: The shape of the tensor. If None, the shape is obtained from
                the ``shape`` attribute of ``value``.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            ValueError: If ``shape`` is None and ``value`` has no ``shape`` attribute.
        """
        if shape is None:
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            # No copy is made: the Shape object passed in by the caller is the
            # one that gets marked as frozen.
            self._shape = shape
            self._shape._frozen = True
        self._raw = value
        self.name = name
        self.doc_string = doc_string
        # Both metadata containers are created lazily on first access.
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props = metadata_props

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the backing data is already a numpy array it is returned as-is
        (``dtype`` is not applied in that case). Otherwise the backing sequence
        is converted with ``np.array(..., dtype=dtype)`` and reshaped to the
        tensor's shape.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw
        assert isinstance(
            self._raw, Sequence
        ), f"Bug: Expected a sequence, got {type(self._raw)}"
        return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Unsupported for string tensors; always raises ``TypeError``."""
        del stream  # unused
        raise TypeError("StringTensor does not support DLPack")

    def __dlpack_device__(self) -> tuple[int, int]:
        """Unsupported for string tensors; always raises ``TypeError``."""
        raise TypeError("StringTensor does not support DLPack")

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return _enums.DataType.STRING

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
        """Backing data of the tensor, exactly as passed to the constructor. Immutable."""
        return self._raw

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array."""
        return self.__array__()

    def tobytes(self) -> bytes:
        """Unsupported for string tensors; always raises ``ValueError``.

        Use :meth:`string_data` to obtain the elements instead.
        """
        raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")

    def string_data(self) -> Sequence[bytes]:
        """Return the string data of the tensor.

        A numpy-backed tensor is flattened and converted to a list; a
        sequence-backed tensor returns its backing sequence unchanged.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw.flatten().tolist()
        return self._raw

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties, created as an empty dict on first access if unset."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata


class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
__slots__ = ("_value",)

Expand Down
13 changes: 13 additions & 0 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,17 @@ def deserialize_tensor(
return DoubleDataTensor(proto)
if proto.data_type in UInt64DataTensor.compatible_types:
return UInt64DataTensor(proto)
if proto.data_type == _enums.DataType.STRING:
name = _get_field(proto, "name")
doc_string = _get_field(proto, "doc_string")
metadata_props = deserialize_metadata_props(proto.metadata_props)
return _core.StringTensor(
proto.string_data,
shape=_core.Shape(proto.dims),
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)
raise ValueError(
f"TensorProto(name={proto.name}) does not have any data fields set and is not an external tensor."
)
Expand Down Expand Up @@ -1086,6 +1097,8 @@ def serialize_tensor_into(
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
elif isinstance(from_, _core.StringTensor):
tensor_proto.string_data.extend(from_.string_data())
else:
tensor_proto.raw_data = from_.tobytes()
_serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
Expand Down

0 comments on commit db08661

Please sign in to comment.