Skip to content

Commit

Permalink
[IR] Allow tensor created with numpy unsupported dtypes (#1441)
Browse files Browse the repository at this point in the history
Support all ONNX types in `ir.Tensor` and fix re-serialization of tensor
protos that do not use the `raw_data` field to store data. Previously,
the *_data to raw_data conversion was incorrect when the data type size
did not match the *_data type size.

For numpy unsupported dtypes, the array is represented as the unsigned
int type of the same size as the dtype. For example, uint8 for float8,
uint16 for bfloat16; with an exception for 4bit types.

This PR creates a more performant and robust method to convert tensor
proto to numpy array than `onnx.numpy_helper.to_array`. We will upstream
the implementation in the future.

Additionally
- Consolidate all tensorproto classes to the TensorProtoTensor class


Tested: unit tests
  • Loading branch information
justinchuby authored May 4, 2024
1 parent 84170f7 commit bca6a64
Show file tree
Hide file tree
Showing 10 changed files with 699 additions and 159 deletions.
137 changes: 130 additions & 7 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_metadata,
_name_authority,
_protocols,
_type_casting,
)

if typing.TYPE_CHECKING:
Expand All @@ -54,7 +55,19 @@
)

# System is little endian
IS_LITTLE_ENDIAN = sys.byteorder == "little"
_IS_LITTLE_ENDIAN = sys.byteorder == "little"
# IR data types for which numpy has no native dtype
_NON_NUMPY_NATIVE_TYPES = frozenset(
    {
        _enums.DataType.BFLOAT16,
        _enums.DataType.FLOAT8E4M3FN,
        _enums.DataType.FLOAT8E4M3FNUZ,
        _enums.DataType.FLOAT8E5M2,
        _enums.DataType.FLOAT8E5M2FNUZ,
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
    }
)


def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
Expand Down Expand Up @@ -159,8 +172,89 @@ def display(self, *, page: bool | None = None) -> None:
rich.print(text)


def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None:
    """Validate that the numpy dtype of ``array`` can represent the IR ``dtype``.

    For IR data types that numpy supports natively, the IR data type derived
    from ``array.dtype`` must equal ``dtype``.

    For IR data types that numpy does not support natively, ``array`` must use
    one of the following storage dtypes:

    - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
    - ``uint8`` for uint4.
    - ``uint8`` for 8-bit data types.
    - ``uint16`` for bfloat16.

    Args:
        array: The numpy array whose dtype is checked.
        dtype: The IR data type the array is declared to hold.

    Raises:
        TypeError: If ``array.dtype`` is not an accepted representation of
            ``dtype``, or if ``array.dtype`` cannot be converted to an IR
            data type at all.
    """
    if dtype not in _NON_NUMPY_NATIVE_TYPES:
        # Native path: derive the IR data type from numpy and compare directly.
        try:
            inferred_dtype = _enums.DataType.from_numpy(array.dtype)
        except TypeError as e:
            raise TypeError(
                "Failed to convert the numpy dtype to an IR data type. "
                "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when "
                "creating a Tensor."
            ) from e
        if inferred_dtype != dtype:
            raise TypeError(
                f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}."
            )
        return

    # Non-native path: the array must use the expected storage dtype.
    if dtype.itemsize == 2 and array.dtype != np.uint16:
        # TODO(justinchuby): Support the storage dtypes like uint16 for bfloat16.
        raise TypeError(
            f"The numpy array dtype must be uint16 (not {array.dtype}) for IR data type {dtype}."
        )
    if dtype.itemsize == 1 and array.dtype != np.uint8:
        raise TypeError(
            f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
        )
    if dtype == _enums.DataType.INT4 and array.dtype not in (np.int8, np.uint8):
        raise TypeError(
            f"The numpy array dtype must be int8 or uint8 (not {array.dtype}) for IR data type {dtype}."
        )
    if dtype == _enums.DataType.UINT4 and array.dtype != np.uint8:
        raise TypeError(
            f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
        )


class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
"""An immutable concrete value."""
"""An immutable concrete tensor.
This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object.
The tensor is immutable and the data is not copied at initialization.
To create a tensor from a numpy array::
>>> import numpy as np
>>> array = np.array([1, 2, 3])
>>> tensor = Tensor(array)
>>> # The tensor itself can be treated as a numpy array because it implements the __array__ method
>>> np.allclose(tensor, array)
True
To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor
to a byte string for serialization, call :meth:`tobytes`.
It is recommended to check the size of the tensor first before accessing the
underlying data, because accessing the data may be expensive and incur IO
overhead.
Subclass this class to efficiently handle different types of tensors from different frameworks.
Attributes:
name: The name of the tensor.
shape: The shape of the tensor.
dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
doc_string: Documentation string.
raw: The raw data behind this tensor. It can be anything.
size: The number of elements in the tensor.
nbytes: The number of bytes in the tensor.
metadata_props: Metadata that will be serialized to the ONNX file.
meta: Metadata store for graph transform passes.
"""

__slots__ = (
"_raw",
Expand All @@ -185,17 +279,28 @@ def __init__(
"""Initialize a tensor.
Args:
value: The backing data of the tensor. It can be a numpy array or a DLPack compatible object.
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
When the value is a numpy array and the dtype is not one of the numpy native dtypes,
the array needs to be ``uint8`` for uint4 and 8-bit data types, ``int8`` or ``uint8``
for int4, and ``uint16`` for bfloat16; ``dtype`` must be specified in this case.
dtype: The data type of the tensor. It can be None only when value is a numpy array.
Users are responsible for making sure the dtype matches the value when value is not a numpy array.
shape: The shape of the tensor. If None, the shape is obtained from the value.
name: The name of the tensor.
doc_string: The documentation string.
metadata_props: The metadata properties.
Raises:
TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array.
ValueError: If the shape is not specified and the value does not have a shape attribute.
ValueError: If the dtype is not specified and the value is not a numpy array.
"""
# NOTE: We should not do any copying here for performance reasons
if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
raise TypeError(f"Expected an array compatible object, got {type(value)}")
if shape is None:
# Obtain the shape from the value
if not hasattr(value, "shape"):
raise ValueError(
f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
Expand All @@ -213,6 +318,11 @@ def __init__(
"The dtype must be specified when the value is not a numpy array."
)
else:
if isinstance(value, np.ndarray):
# Make sure the dtype matches the value
_check_numpy_representation_type(value, dtype)
# Users are responsible for making sure the dtype matches the value
# when value is not a numpy array
self._dtype = dtype
self._raw = value
self.name = name
Expand All @@ -221,7 +331,6 @@ def __init__(
self._metadata_props = metadata_props

def __array__(self, dtype: Any = None) -> np.ndarray:
# TODO(justinchuby): Support numpy unsupported types
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
return self._raw.__array__(dtype)
assert _compatible_with_dlpack(
Expand Down Expand Up @@ -258,7 +367,16 @@ def raw(self) -> TArrayCompatible:
return self._raw # type: ignore[return-value]

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array."""
"""Return the tensor as a numpy array.
When the data type is not supported by numpy, the value is the bit representation
of the dtype:
- ``int8`` for int4, with the sign bit extended to 8 bits.
- ``uint8`` for uint4.
- ``uint8`` for 8-bit data types like float8.
- ``uint16`` for bfloat16.
"""
if isinstance(self._raw, np.ndarray):
return self._raw
# We do not cache the value to save memory
Expand All @@ -272,8 +390,13 @@ def tobytes(self) -> bytes:
"""
# TODO(justinchuby): Support DLPack
array = self.numpy()
if not IS_LITTLE_ENDIAN:
return array.view(array.dtype.newbyteorder("<")).tobytes()
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
# Pack the array into int4
array = _type_casting.pack_int4(array)
else:
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
if not _IS_LITTLE_ENDIAN:
array = array.view(array.dtype.newbyteorder("<"))
return array.tobytes()

@property
Expand Down
52 changes: 47 additions & 5 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,44 @@ def test_init_requires_type_when_value_is_not_np_array(self):
with self.assertRaises(ValueError):
_core.Tensor(torch_tensor)

def test_init_respects_dtype_when_it_is_provided(self):
array = np.random.rand(1, 2).astype(np.int8)
tensor = _core.Tensor(array, dtype=_enums.DataType.UINT4)
self.assertEqual(tensor.dtype, _enums.DataType.UINT4)
@parameterized.parameterized.expand(
    [
        ("bfloat16", np.uint16, _enums.DataType.BFLOAT16),
        (
            "float8e4m3fn",
            np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})),
            _enums.DataType.FLOAT8E4M3FN,
        ),
        ("float8e4m3fnuz", np.uint8, _enums.DataType.FLOAT8E4M3FNUZ),
        ("float8e5m2", np.uint8, _enums.DataType.FLOAT8E5M2),
        ("float8e5m2fnuz", np.uint8, _enums.DataType.FLOAT8E5M2FNUZ),
        ("int4", np.int8, _enums.DataType.INT4),
        ("int4_uint8", np.uint8, _enums.DataType.INT4),
        ("uint4", np.uint8, _enums.DataType.UINT4),
    ]
)
def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: _enums.DataType):
    """A Tensor keeps the given IR dtype and data when built from a storage-dtype array."""
    storage = np.array([0b1, 0b11], dtype=np_dtype)
    wrapped = _core.Tensor(storage, dtype=dtype)
    # The explicitly supplied IR dtype must be reported unchanged.
    self.assertEqual(wrapped.dtype, dtype)
    # The tensor must expose the same values as the backing array.
    np.testing.assert_array_equal(wrapped, storage)

def test_initialize_with_just_np_array(self):
    """A Tensor can be created from a numpy array alone, without a dtype argument."""
    values = np.random.rand(1, 2)
    wrapped = _core.Tensor(values)
    np.testing.assert_array_equal(wrapped, values)

def test_initialize_raises_when_numpy_dtype_doesnt_match(self):
    """Passing an IR dtype that disagrees with the float32 array raises TypeError."""
    float_values = np.random.rand(1, 2).astype(np.float32)
    with self.assertRaises(TypeError):
        _core.Tensor(float_values, dtype=_enums.DataType.INT64)

def test_initialize_raises_when_numpy_dtype_doesnt_match_custom_dtype(self):
    """A uint8-based custom numpy dtype is rejected for the BFLOAT16 IR dtype."""
    uint8_based_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)}))
    values = np.random.rand(1, 2).astype(uint8_based_dtype)
    with self.assertRaises(TypeError):
        _core.Tensor(values, dtype=_enums.DataType.BFLOAT16)

def test_initialize_with_torch_tensor(self):
array = np.random.rand(1, 2).astype(np.int64)
np_tensor = _core.Tensor(array)
Expand Down Expand Up @@ -87,7 +115,7 @@ def test_numpy_returns_np_array(self):
np.testing.assert_equal(tensor.numpy(), array)

def test_numpy_returns_data_when_dtype_is_not_supported(self):
array = np.array([1], dtype=np.int8)
array = np.array([1], dtype=np.uint8)
tensor = _core.Tensor(array, dtype=_enums.DataType.INT4)
np.testing.assert_equal(tensor.numpy(), array)

Expand All @@ -97,6 +125,20 @@ def test_tobytes(self):
tensor = _core.Tensor(torch_tensor, dtype=_enums.DataType.FLOAT)
self.assertEqual(tensor.tobytes(), array.tobytes())

def test_tobtyes_returns_packed_data_for_int4(self):
    """tobytes() on an INT4 tensor returns packed bytes, here for an odd-length input."""
    unpacked = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
    # Test odd sized array
    assert len(unpacked) % 2 == 1
    int4_tensor = _core.Tensor(unpacked, dtype=_enums.DataType.INT4)
    packed = int4_tensor.tobytes()
    self.assertEqual(packed, b"\xf8\x10r\x01")

def test_tobtyes_returns_packed_data_for_uint4(self):
    """tobytes() on a UINT4 tensor returns packed bytes, here for an odd-length input."""
    unpacked = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
    # Test odd sized array
    assert len(unpacked) % 2 == 1
    uint4_tensor = _core.Tensor(unpacked, dtype=_enums.DataType.UINT4)
    packed = uint4_tensor.tobytes()
    self.assertEqual(packed, b"\x10r\x0f")

def test_metadata(self):
array = np.random.rand(1, 2).astype(np.float32)
tensor = _core.Tensor(array)
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/ir/_display_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class DisplayTest(unittest.TestCase):
def test_tensor_display_does_not_raise_on_nan_values(self):
array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10])
array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10], dtype=np.float32)
tensor = ir.Tensor(array_with_nan, dtype=ir.DataType.FLOAT)
with contextlib.redirect_stdout(None):
tensor.display()
Expand Down
11 changes: 11 additions & 0 deletions onnxscript/ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,14 @@ def __str__(self) -> str:
# ONNX DataType to Numpy dtype. The inverted mapping covers the data types
# numpy supports natively.
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
# ONNX data types that numpy does not support natively are mapped to integer
# dtypes: uint8 for the float8 variants and uint4, int8 for int4, and uint16
# for bfloat16.
_DATA_TYPE_TO_NP_TYPE.update(
{
DataType.FLOAT8E4M3FN: np.dtype("uint8"),
DataType.FLOAT8E4M3FNUZ: np.dtype("uint8"),
DataType.FLOAT8E5M2: np.dtype("uint8"),
DataType.FLOAT8E5M2FNUZ: np.dtype("uint8"),
DataType.UINT4: np.dtype("uint8"),
DataType.INT4: np.dtype("int8"),
DataType.BFLOAT16: np.dtype("uint16"),
}
)
22 changes: 21 additions & 1 deletion onnxscript/ir/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
tools.
"""

# 👀
# NOTE: Why are we using protocols, instead of abstract base classes?
#
# Protocols are more flexible than abstract base classes. Users can define their
# own classes that implement the protocols without having to inherit from a
# specific base class. For example, a user can define a custom tensor class that
# implements the TensorProtocol without explicitly inheriting, and the IR can
# work with that class without any changes.
#
# `isinstance` checks can be slower with protocols. Avoid using `isinstance`
# checks when you can. Always check for concrete classes first.
#
# NOTE: Why are we using protocols, instead of using concrete classes directly?
#
# Protocols define the interface that is typically more stable. If you find yourself
# updating the protocols, pause 🛑, and carefully make sure it is absolutely needed
# and will improve the design. If you are adding new methods, consider if the method
# should be part of the protocol or if it should be a higher level convenience function
# defined outside the protocol.

from __future__ import annotations

import typing
Expand Down Expand Up @@ -41,7 +61,7 @@
class ArrayCompatible(Protocol):
"""Protocol for array-like objects.
An example of an array-like object is a numpy array or a PyTorch array.
An example of an array-like object is a numpy ndarray or a PyTorch Tensor.
Read more at https://numpy.org/devdocs/user/basics.interoperability.html
"""

Expand Down
Loading

0 comments on commit bca6a64

Please sign in to comment.