[IR] Create ir.tensor() as a convenience Tensor initializer; use ml_dtypes to support int4/bfloat16 (#1549)

It is now possible to do:

```python
tensor1 = ir.tensor(tensor_proto)
tensor2 = ir.tensor(np_array)
tensor3 = ir.tensor([1,2], dtype=ir.DataType.FLOAT)
```

supporting all ONNX dtypes.
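
For dtypes NumPy does not support natively, the returned tensor is backed by an `ml_dtypes` array. A minimal sketch, taken from the doctests added in this change (printed output is approximate):

```python
import numpy as np
from onnxscript import ir

# BFLOAT16 has no native NumPy dtype; the value is held as an ml_dtypes.bfloat16 array.
bf16 = ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
print(bf16)  # Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)

# Native NumPy dtypes pass through unchanged.
i16 = ir.tensor(np.array([1, 2, 3], dtype=np.int16))
print(i16)   # Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
```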


- Added ml_dtypes as a new dependency and used it to support int4/bfloat16.
- Removed the unused float32->float16 helper function.

Tested: unit tests and doctests

Fixes #1439
justinchuby authored May 24, 2024
1 parent d316706 commit a6843da
Showing 12 changed files with 245 additions and 123 deletions.
38 changes: 13 additions & 25 deletions docs/intermediate_representation/tensors.md
@@ -141,26 +141,17 @@ In the following scenario, we show how to go from a `TensorProto` to an `ir.Tens

## Working with non-native NumPy dtypes: bfloat16, float8, int4

`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, the value is the bit representation for the dtype:
`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, we use dtypes from the `ml_dtypes` package.

`uint4`/`int4` is always unpacked; **`tobytes()` produces a packed representation** as expected.

Initialization of `ir.Tensor` requires the NumPy array to satisfy the following typing constraints, or to have an `ml_dtypes` dtype.

- `int8` for (unpacked) int4, with the sign bit extended to 8 bits.
- `uint8` for (unpacked) uint4.
- `uint8` for 8-bit data types like float8.
- `uint16` for bfloat16.

uint4/int4 is always unpacked; `tobytes()` produces a packed representation as expected.

Initialization of `ir.Tensor` requires the NumPy array to follow these typing constraints as well.

:::{tip}
You can use the [ml_dtypes package](https://github.com/jax-ml/ml_dtypes) to extend NumPy and work with these values.

```bash
pip install --upgrade ml_dtypes
```

:::
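
For example, a short sketch of the unpacked in-memory layout versus the packed `tobytes()` output for `INT4`, adapted from the unit tests in this change:

```python
import numpy as np
import ml_dtypes
from onnxscript import ir

# Seven int4 values; ml_dtypes.int4 stores one (sign-extended) element per value.
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4)
tensor = ir.Tensor(array, dtype=ir.DataType.INT4)
print(tensor.numpy().shape)  # (7,) -- values stay unpacked
print(tensor.tobytes())      # b'\xf8\x10r\x01' -- two nibbles per byte, odd tail padded
```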

The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values.

```{eval-rst}
@@ -170,24 +161,21 @@ The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its
import numpy as np
array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([1, 3], dtype=uint8), name='')
print("tensor.numpy():", tensor.numpy()) # array([1, 3], dtype=uint8)
# You can use the ml_dtypes package to work with these values in NumPy
import ml_dtypes
float8_array = tensor.numpy().view(ml_dtypes.float8_e4m3fn)
print("float8_array:", float8_array) # array([0.00195312, 0.00585938], dtype='float8_e4m3fn')
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938]
# Compute
times_100 = float8_array * 100
times_100 = tensor.numpy() * 100
print("times_100:", times_100)
# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([36, 49], dtype=uint8), name='')
print("new_tensor == times_100", new_tensor.numpy().view(ml_dtypes.float8_e4m3fn) == times_100) # array([ True, True])
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True])
```

## Advanced Usage
3 changes: 3 additions & 0 deletions onnxscript/ir/__init__.py
@@ -68,11 +68,14 @@
# Conversion functions
"from_proto",
"to_proto",
# IR Tensor initializer
"tensor",
# Pass infrastructure
"passes",
]

from onnxscript.ir import passes, serde
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
AttrFloat32,
85 changes: 85 additions & 0 deletions onnxscript/ir/_convenience.py
@@ -16,12 +16,17 @@
"replace_all_uses_with",
]

import typing
from typing import Mapping, Sequence, Union

import numpy as np
import onnx

from onnxscript.ir import _core, _enums, _protocols, serde

if typing.TYPE_CHECKING:
import numpy.typing as npt

SupportedAttrTypes = Union[
str,
int,
@@ -285,3 +290,83 @@ def replace_all_uses_with(
for value, replacement in zip(values, replacements):
for user_node, index in tuple(value.uses()):
user_node.replace_input_with(index, replacement)


def tensor(
value: npt.ArrayLike
| onnx.TensorProto
| _protocols.DLPackCompatible
| _protocols.ArrayCompatible,
dtype: _enums.DataType | None = None,
name: str | None = None,
doc_string: str | None = None,
) -> _protocols.TensorProtocol:
"""Create a tensor value from an ArrayLike object or a TensorProto.
The dtype must match the value. Reinterpretation of the value is
not supported, unless if the value is a plain Python object, in which case
it is converted to a numpy array with the given dtype.
:param:`value` can be a numpy array, a plain Python object, or a TensorProto.
Example::
>>> from onnxscript import ir
>>> import numpy as np
>>> import ml_dtypes
>>> import onnx
>>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
>>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
>>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
>>> tp_tensor.numpy()
array(0.5, dtype=float32)
Args:
value: The numpy array to create the tensor from.
dtype: The data type of the tensor.
name: The name of the tensor.
doc_string: The documentation string of the tensor.
Returns:
A tensor value.
Raises:
ValueError: If the dtype does not match the value when value is not a plain Python
object like ``list[int]``.
"""
if isinstance(value, _protocols.TensorProtocol):
if dtype is not None and dtype != value.dtype:
raise ValueError(
f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
"You do not have to specify the dtype when value is a Tensor."
)
return value
if isinstance(value, onnx.TensorProto):
tensor_ = serde.deserialize_tensor(value)
if name is not None:
tensor_.name = name
if doc_string is not None:
tensor_.doc_string = doc_string
if dtype is not None and dtype != tensor_.dtype:
raise ValueError(
f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
"You do not have to specify the dtype when value is a TensorProto."
)
elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
else:
if dtype is not None:
numpy_dtype = dtype.numpy()
else:
numpy_dtype = None
array = np.array(value, dtype=numpy_dtype)
tensor_ = _core.Tensor(
array,
dtype=dtype,
shape=_core.Shape(array.shape),
name=name,
doc_string=doc_string,
)
return tensor_
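
A quick usage sketch of the dtype rules above (a hypothetical session; it assumes only the behavior visible in this diff):

```python
import onnx
from onnxscript import ir

# Plain Python objects are converted with the requested dtype.
t = ir.tensor([1, 2], dtype=ir.DataType.FLOAT)

# For TensorProtos (and arrays) the dtype must already match; mismatches are
# rejected rather than silently reinterpreted.
proto = onnx.helper.make_tensor("t", onnx.TensorProto.FLOAT, dims=[2], vals=[1.0, 2.0])
try:
    ir.tensor(proto, dtype=ir.DataType.DOUBLE)
except ValueError as err:
    print("rejected:", err)
```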
72 changes: 56 additions & 16 deletions onnxscript/ir/_core.py
@@ -35,6 +35,7 @@
Union,
)

import ml_dtypes
import numpy as np

import onnxscript
@@ -184,26 +185,33 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
- ``uint8`` for uint4.
- ``uint8`` for 8-bit data types.
- ``uint16`` for bfloat16
or corresponding dtypes from the ``ml_dtypes`` package.
"""
if dtype in _NON_NUMPY_NATIVE_TYPES:
if dtype.itemsize == 2 and array.dtype != np.uint16:
# TODO(justinchuby): Support the storage dtypes like uint16 for bfloat16.
if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
raise TypeError(
f"The numpy array dtype must be uint16 (not {array.dtype}) for IR data type {dtype}."
f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
)
if dtype.itemsize == 1 and array.dtype != np.uint8:
if dtype.itemsize == 1 and array.dtype not in (
np.uint8,
ml_dtypes.float8_e4m3b11fnuz,
ml_dtypes.float8_e4m3fn,
ml_dtypes.float8_e5m2fnuz,
ml_dtypes.float8_e5m2,
):
raise TypeError(
f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
)
if dtype == _enums.DataType.INT4:
if array.dtype not in (np.int8, np.uint8):
if array.dtype not in (np.int8, np.uint8, ml_dtypes.int4):
raise TypeError(
f"The numpy array dtype must be int8 or uint8 (not {array.dtype}) for IR data type {dtype}."
f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int4 (not {array.dtype}) for IR data type {dtype}."
)
if dtype == _enums.DataType.UINT4:
if array.dtype != np.uint8:
if array.dtype not in (np.uint8, ml_dtypes.uint4):
raise TypeError(
f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
)
return

@@ -222,6 +230,35 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
)


def _maybe_view_np_array_with_ml_dtypes(
array: np.ndarray, dtype: _enums.DataType
) -> np.ndarray:
"""Reinterpret the array when it is a bit representation of a dtype not supported by numpy.
Args:
array: The numpy array to reinterpret.
dtype: The data type to reinterpret the array as.
Returns:
The array reinterpreted as the dtype.
"""
if dtype == _enums.DataType.BFLOAT16:
return array.view(ml_dtypes.bfloat16)
if dtype == _enums.DataType.FLOAT8E4M3FN:
return array.view(ml_dtypes.float8_e4m3fn)
if dtype == _enums.DataType.FLOAT8E4M3FNUZ:
return array.view(ml_dtypes.float8_e4m3fnuz)
if dtype == _enums.DataType.FLOAT8E5M2:
return array.view(ml_dtypes.float8_e5m2)
if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
return array.view(ml_dtypes.float8_e5m2fnuz)
if dtype == _enums.DataType.INT4:
return array.view(ml_dtypes.int4)
if dtype == _enums.DataType.UINT4:
return array.view(ml_dtypes.uint4)
return array
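
For context, a standalone sketch (not part of this commit) of the zero-copy bit reinterpretation this helper relies on:

```python
import numpy as np
import ml_dtypes

bits = np.array([0x3F80, 0x4000], dtype=np.uint16)   # bfloat16 bit patterns for 1.0 and 2.0
floats = bits.view(ml_dtypes.bfloat16)               # same buffer, reinterpreted dtype
print(float(floats[0]), float(floats[1]))            # 1.0 2.0
```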


class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor.
@@ -327,6 +364,11 @@ def __init__(
# Users are responsible for making sure the dtype matches the value
# when value is not a numpy array
self._dtype = dtype

# View the bfloat16, float8 and int4 types using ml_dtypes
if isinstance(value, np.ndarray):
value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]

self._raw = value
self.name = name
self.doc_string = doc_string
@@ -372,13 +414,9 @@ def raw(self) -> TArrayCompatible:
def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
When the data type is not supported by numpy, the value is the bit representation
of the dtype:
- ``int8`` for int4, with the sign bit extended to 8 bits.
- ``uint8`` for uint4.
- ``uint8`` for 8-bit data types like float8.
- ``uint16`` for bfloat16.
When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
package are used. The values can be reinterpreted as bit representations
using the ``.view()`` method.
"""
if isinstance(self._raw, np.ndarray):
return self._raw
@@ -528,6 +566,8 @@ def _load(self):
# Handle the byte order correctly by always using little endian
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
# Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
dt = np.dtype(np.uint8).newbyteorder("<")
count = self.size // 2 + self.size % 2
else:
count = self.size
25 changes: 22 additions & 3 deletions onnxscript/ir/_core_test.py
@@ -62,7 +62,7 @@ def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.Data
array = np.array([0b1, 0b11], dtype=np_dtype)
tensor = _core.Tensor(array, dtype=dtype)
self.assertEqual(tensor.dtype, dtype)
np.testing.assert_array_equal(tensor, array)
np.testing.assert_array_equal(tensor, array.view(dtype.numpy()))

def test_initialize_with_just_np_array(self):
array = np.random.rand(1, 2)
@@ -74,6 +74,11 @@ def test_initialize_raises_when_numpy_dtype_doesnt_match(self):
with self.assertRaises(TypeError):
_core.Tensor(array, dtype=ir.DataType.INT64)

def test_initialize_supports_custom_dtype(self):
custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)}))
array = np.random.rand(1, 2).astype(custom_dtype)
_core.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)

def test_initialize_raises_when_numpy_dtype_doesnt_match_custom_dtype(self):
custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)}))
array = np.random.rand(1, 2).astype(custom_dtype)
@@ -134,13 +139,27 @@ def test_tobtyes_returns_packed_data_for_int4(self):
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")

def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self):
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")

def test_tobtyes_returns_packed_data_for_uint4(self):
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")

def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self):
array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")

def test_metadata(self):
array = np.random.rand(1, 2).astype(np.float32)
tensor = _core.Tensor(array)
@@ -339,7 +358,7 @@ def test_external_tensor_float8(self, _: str, dtype: ir.DataType, np_dtype):
]
)
def test_external_tensor_int(self, _: str, dtype: ir.DataType):
expected_array = np.array([[-1, 0, 1, 7]]).astype(dtype.numpy())
expected_array = np.array([[-8, 0, 1, 7]]).astype(dtype.numpy())
tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype))
with tempfile.TemporaryDirectory() as temp_dir:
_to_external_tensor(tensor_proto, temp_dir, "tensor.bin")
@@ -359,7 +378,7 @@ def test_external_tensor_int(self, _: str, dtype: ir.DataType):
]
)
def test_external_tensor_uint(self, _: str, dtype: ir.DataType):
expected_array = np.array([[0, 1, 8]]).astype(dtype.numpy())
expected_array = np.array([[0, 1, 15]]).astype(dtype.numpy())
tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype))
with tempfile.TemporaryDirectory() as temp_dir:
_to_external_tensor(tensor_proto, temp_dir, "tensor.bin")