[IR] INT4 support in external tensors (#1510)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1508
* __->__ #1510

- INT4 support in external tensors
- dlpack support in all tensors, except for external tensors because
they are memory mapped
- Update tensor documentation
- Test

Signed-off-by: Justin Chu <[email protected]>

#1499
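
The INT4 handling described above can be sketched in plain Python. This is a minimal illustration, not the code from this commit: it assumes the ONNX packing convention of two 4-bit elements per byte with the first element in the low nibble, and the helper names (`packed_byte_count`, `unpack_uint4_flat`, `unpack_int4_flat`) are hypothetical. It works on flat byte strings only and leaves out the reshape step that the real `_type_casting` helpers perform.

```python
def packed_byte_count(num_elements: int) -> int:
    """Bytes needed when two 4-bit elements share one byte (rounds up)."""
    return num_elements // 2 + num_elements % 2


def unpack_uint4_flat(packed: bytes, num_elements: int) -> list[int]:
    """Unpack 4-bit unsigned values, assuming the low nibble comes first."""
    values = []
    for byte in packed:
        values.append(byte & 0x0F)  # first element: low 4 bits
        values.append(byte >> 4)    # second element: high 4 bits
    # An odd element count leaves one padding nibble at the end; drop it.
    return values[:num_elements]


def unpack_int4_flat(packed: bytes, num_elements: int) -> list[int]:
    """Unpack 4-bit signed values by sign-extending each nibble."""
    return [v - 16 if v >= 8 else v for v in unpack_uint4_flat(packed, num_elements)]


# Five elements occupy three bytes; the last high nibble is padding.
print(packed_byte_count(5))
print(unpack_uint4_flat(bytes([0x21, 0x43, 0x05]), 5))
print(unpack_int4_flat(bytes([0xF8]), 2))
```

The rounding in `packed_byte_count` mirrors the `self.size // 2 + self.size % 2` expression in the `_load` change further down, which is why the number of bytes read differs from the element count for `INT4` and `UINT4`.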
justinchuby authored May 8, 2024
1 parent c3e393d commit af6afd1
Showing 9 changed files with 306 additions and 78 deletions.
8 changes: 3 additions & 5 deletions docs/intermediate_representation/tensors.md
@@ -1,6 +1,6 @@
# Tensor Representation in the IR

The ONNX IR offers the {py:class}`ir.TensorProtocol <onnxscript.ir.TensorProtocol>` interface for usings different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can also use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows for them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies at initialization.
The ONNX IR offers the {py:class}`ir.TensorProtocol <onnxscript.ir.TensorProtocol>` interface for using different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies during initialization.

## The `TensorProtocol`

@@ -14,8 +14,6 @@ When interacting with initializers, constant values and tensor attributes, it is

### ir.TensorProtoTensor

The ONNX spec defines [different ways](https://github.com/onnx/onnx/blob/d6f87121ba256ac6cc4d1da0463c300c278339d2/onnx/onnx.proto#L567-L654) for storing tensor data as an {py:class}`onnx.TensorProto <onnx.ir.TensorProtocol>` protocol buffer message. The IR has corresponding classes for each of these data storage methods.

We use the {py:class}`ir.TensorProtoTensor <onnxscript.ir.TensorProtoTensor>` as a wrapper around the proto to implement the `ir.TensorProtocol` interface. You can access `shape`, `dtype` etc. as usual. A copy is incurred only when `numpy()` is called.

:::{note}
@@ -196,7 +194,7 @@ The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its

## Advanced Usage

### Subclass ir.Tensor for More Efficient Access and Broader dtype Support
### Subclass `ir.Tensor` for More Efficient Access and Broader `dtype` Support

{py:class}`ir.Tensor` internally converts any array compatible objects into NumPy arrays to produce the byte representation in `tobytes()`. This can be inefficient due to the additional conversion. It also limits support for dtypes not supported by NumPy like bfloat16, because the `__array__` method would fail.

@@ -256,7 +254,7 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
def tobytes(self) -> bytes:
# Implement tobytes to support native PyTorch types so we can use types like bloat16
# Reading from memory directly is also more efficient because
# it avoids the copy to NumPy array
# it avoids copying to a NumPy array
tensor = self.raw.detach().cpu().contiguous()
return bytes(
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
5 changes: 3 additions & 2 deletions noxfile.py
@@ -13,8 +13,8 @@
"beartype==0.17.2",
"expecttest==0.1.6",
"hypothesis",
'numpy==1.24.4; python_version<"3.12"',
'numpy>1.26.0; python_version>="3.12"',
'numpy==1.24.4; python_version<"3.9"',
'numpy==1.26.0; python_version>="3.9"',
"packaging",
"parameterized",
"pyinstrument",
@@ -26,6 +26,7 @@
"pyyaml",
"types-PyYAML",
"typing_extensions",
"ml_dtypes",
)
ONNX = "onnx==1.16"
ONNX_RUNTIME = "onnxruntime==1.17.1"
10 changes: 10 additions & 0 deletions onnxscript/_internal/version_utils.py
@@ -31,3 +31,13 @@ def onnxruntime_older_than(version: str) -> bool:
packaging.version.parse(onnxruntime.__version__).release
< packaging.version.parse(version).release
)


def numpy_older_than(version: str) -> bool:
"""Returns True if the numpy version is older than the given version."""
import numpy # pylint: disable=import-outside-toplevel

return (
packaging.version.parse(numpy.__version__).release
< packaging.version.parse(version).release
)
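
The helper above compares `release` tuples rather than raw version strings. A simplified sketch of why that matters follows. It assumes plain dotted numeric versions only, and `release_tuple` and `older_than` are hypothetical stand-ins for `packaging.version.parse(...).release` and the helper in this diff, not part of the commit.

```python
def release_tuple(version: str) -> tuple[int, ...]:
    """Simplified stand-in for packaging's release tuple.

    Handles only plain dotted numeric versions such as "1.26.0".
    """
    return tuple(int(part) for part in version.split("."))


def older_than(installed: str, required: str) -> bool:
    """Return True if `installed` sorts before `required` numerically."""
    return release_tuple(installed) < release_tuple(required)


# Numeric comparison: 9 < 26, so 1.9.0 is older than 1.26.0.
print(older_than("1.9.0", "1.26.0"))
# String comparison gets this wrong because "9" sorts after "2".
print("1.9.0" < "1.26.0")
```

Comparing integer tuples avoids the lexicographic pitfall shown in the last line, which is the usual reason to go through a version parser instead of comparing `__version__` strings directly.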
67 changes: 46 additions & 21 deletions onnxscript/ir/_core.py
@@ -221,7 +221,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
)


class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor.
This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
@@ -411,15 +411,15 @@ def metadata_props(self) -> dict[str, str]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class ExternalTensor(TensorBase, _protocols.TensorProtocol):
class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor with its data store on disk.
This class uses memory mapping to avoid loading the tensor into memory,
@@ -432,7 +432,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):
To obtain an array, call :meth:`numpy`. To obtain the bytes,
call :meth:`tobytes`.
The :attribute:`path` can be a relative path or an absolute path.
The :attr:`path` can be a relative path or an absolute path.
Serializers should handle the path correctly to conform with the ONNX spec.
Attributes:
@@ -512,6 +512,10 @@ def shape(self) -> Shape:

def _load(self):
assert self._array is None, "Bug: The array should be loaded only once."
if self.size == 0:
# When the size is 0, mmap is impossible and meaningless
self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
@@ -522,9 +526,19 @@ def _load(self):
)
# Handle the byte order correctly by always using little endian
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
self._array = np.frombuffer(
self.raw, dtype=dt, offset=self.offset or 0, count=self.size
).reshape(self.shape.numpy())
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
count = self.size // 2 + self.size % 2
else:
count = self.size
self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
shape = self.shape.numpy()
if self.dtype == _enums.DataType.INT4:
# Unpack the int4 arrays
self._array = _type_casting.unpack_int4(self._array, shape)
elif self.dtype == _enums.DataType.UINT4:
self._array = _type_casting.unpack_uint4(self._array, shape)
else:
self._array = self._array.reshape(shape)

def __array__(self, dtype: Any = None) -> np.ndarray:
if self._array is None:
@@ -533,7 +547,16 @@ def __array__(self, dtype: Any = None) -> np.ndarray:
return self._array.__array__(dtype)

def __dlpack__(self, *, stream: Any = None) -> Any:
return self.numpy().__dlpack__(stream=stream)
raise NotImplementedError(
"ExternalTensor does not support DLPack because it uses memory mapping. "
"Call numpy() to get a numpy array instead."
)

def __dlpack_device__(self) -> tuple[int, int]:
raise NotImplementedError(
"ExternalTensor does not support DLPack because it uses memory mapping. "
"Call numpy() to get a numpy array instead."
)

def __repr__(self) -> str:
return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
@@ -570,15 +593,15 @@ def metadata_props(self) -> dict[str, str]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class StringTensor(TensorBase, _protocols.TensorProtocol):
class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

__slots__ = (
@@ -680,7 +703,7 @@ def metadata_props(self) -> dict[str, str]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
Expand Down Expand Up @@ -1168,7 +1191,7 @@ def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -1423,7 +1446,7 @@ def type(self) -> _protocols.TypeProtocol | None:
Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``.
To obtain the data type of the tensor, use ``type.dtype`` or conveniently
:attribute:`dtype`.
:attr:`dtype`.
"""
return self._type

@@ -1444,7 +1467,7 @@ def dtype(self, value: _enums.DataType) -> None:
If the type is not set, it will be initialized to a new TensorType. To
set the type as other types like ``SequenceType``, initialize the type
then set :attribute:`type` instead.
then set :attr:`type` instead.
"""
if self._type is None:
self._type = TensorType(value)
@@ -1487,7 +1510,7 @@ def const_value(
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -1728,8 +1751,9 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
Args:
nodes: The node to remove.
safe: If True, performs the following actions before removal:
1. It checks to make sure there are no users of the node that are not
to be removed before removing it.
to be removed before removing it.
2. It checks the node does not contribute to any graph outputs.
3. It removes references to all inputs so it is no longer a user of other nodes.
@@ -1798,7 +1822,7 @@ def sort(self) -> None:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -1963,7 +1987,7 @@ def __reversed__(self) -> Iterator[Node]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -2048,7 +2072,7 @@ def opset_imports(self) -> dict[str, int]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -2210,7 +2234,7 @@ def opset_imports(self) -> dict[str, int]:
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
@@ -2241,8 +2265,9 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
Args:
nodes: The node to remove.
safe: If True, performs the following actions before removal:
1. It checks to make sure there are no users of the node that are not
to be removed before removing it.
to be removed before removing it.
2. It checks the node does not contribute to any graph outputs.
3. It removes references to all inputs so it is no longer a user of other nodes.