-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Create documentation for tensors #1481
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
75d9d34
[IR] Create documentation for tensors
justinchuby 5e568c2
save
justinchuby 441385b
update
justinchuby f7ed9a6
index
justinchuby 3a9a337
Unified TensorProtoTensor
justinchuby b6d7340
snap
justinchuby 7a5418d
Fix
justinchuby da5e0cc
numpy
justinchuby 888ff50
changes
justinchuby a02a9f9
code
justinchuby e936844
Merge branch 'main' into justinchu/docs-tensors-2
justinchuby e351cfb
sphinx_exec_code
justinchuby b268c40
Improve
justinchuby 2224cb9
text
justinchuby 13877d2
int4
justinchuby 5bcb366
format
justinchuby 7c2da92
typo
justinchuby 638f66f
cap
justinchuby 53cc52c
title
justinchuby 37436ec
note
justinchuby f863764
re
justinchuby 1aed7bd
jax[cpu]
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,6 @@ | |
```{toctree} | ||
:maxdepth: 1 | ||
|
||
tensors | ||
ir_api | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
# Tensor Representation in the IR | ||
|
||
The ONNX IR offers the {py:class}`ir.TensorProtocol <onnxscript.ir.TensorProtocol>` interface for usings different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can also use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows for them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies at initialization. | ||
|
||
## The `TensorProtocol` | ||
|
||
{py:class}`ir.TensorProtocol <onnxscript.ir.TensorProtocol>` defines a read-only interface for representing tensors. A tensor class implementing the interface has attributes like `name`, `shape`, `dtype`, `size`, `nbytes` and `metadata_props` to describe basic properties of the tensor. Additionally, it should implement two methods {py:meth}`numpy <onnxscript.ir.TensorProtocol.numpy>` and {py:meth}`__array__ <onnxscript.ir.TensorProtocol.__array__>` which will produce equivalent NumPy arrays from the backing data. | ||
|
||
:::{note} | ||
When interacting with initializers, constant values and tensor attributes, it is best to assume `TensorProtocol` and only use `isinstance` to check for concrete classes when there is a need. | ||
::: | ||
|
||
## Tensor Classes | ||
|
||
### ir.TensorProtoTensor | ||
|
||
The ONNX spec defines [different ways](https://github.com/onnx/onnx/blob/d6f87121ba256ac6cc4d1da0463c300c278339d2/onnx/onnx.proto#L567-L654) for storing tensor data as an {py:class}`onnx.TensorProto <onnx.ir.TensorProtocol>` protocol buffer message. The IR has corresponding classes for each of these data storage methods. | ||
|
||
We use the {py:class}`ir.TensorProtoTensor <onnxscript.ir.TensorProtoTensor>` as a wrapper around the proto to implement the `ir.TensorProtocol` interface. You can access `shape`, `dtype` etc. as usual. A copy is incurred only when `numpy()` is called. | ||
|
||
:::{note} | ||
Directly initializing an `ir.TensorProtoTensor`, as below, is possible. However, it is usually recommended to use `ir.serde.deserialize_tensor` because it handles all types of `TensorProto`s (`ir.TensorProtoTensor` doesn't handle external tensors, for example). Please refer to [From `TensorProto`s and back](#from-tensorprotos-and-back) for an example. | ||
::: | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
import onnx | ||
from onnxscript import ir | ||
|
||
tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3]) | ||
tensor = ir.TensorProtoTensor(tensor_proto) | ||
print("tensor: ", tensor) # TensorProtoTensor<INT16,[3]>(name='tensor') | ||
print("shape: ", tensor.shape) # ir.Shape([3]) | ||
print("dtype: ", tensor.dtype) # ir.DataType.INT16 | ||
print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization | ||
print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00' | ||
print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16) | ||
``` | ||
|
||
### ir.ExternalTensor | ||
|
||
Tensor data stored externally in the disk are typically large and will take up memory when loaded. The {py:class}`ir.ExternalTensor <onnxscript.ir.ExternalTensor>` class uses memory mapping to avoid loading the tensor into memory. You are able to use the tensor as a normal NumPy array with minimal memory usage. | ||
|
||
Refer to {py:func}`ir.serde.deserialize_tensor <onnxscript.ir.serde.deserialize_tensor>` to find an example on converting an `onnx.TensorProto` to an {py:class}`ir.ExternalTensor <onnxscript.ir.ExternalTensor>`. | ||
|
||
### ir.Tensor | ||
|
||
{py:class}`ir.Tensor <onnxscript.ir.Tensor>` is a wrapper around NumPy array compatible array objects like {py:class}`np.ndarray` and {py:class}`torch.Tensor`. It is best for creating in-memory tensors without converting it to a `TensorProto` to reduce the conversion overhead. | ||
|
||
:::{tip} | ||
An array object is compatible if it defines the `__array__` method. | ||
::: | ||
|
||
To create a tensor from an array, simply initialize it with an NumPy array | ||
|
||
```python | ||
tensor = ir.Tensor(np.random.rand(1, 2)) | ||
``` | ||
|
||
The initializer will obtain dtype and shape information from the array. | ||
|
||
To create a tensor from objects other than NumPy array, you need to specify the dtype: | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
import torch | ||
from onnxscript import ir | ||
|
||
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16) | ||
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16) | ||
print(tensor.numpy()) # array([1., 2., 3.], dtype=float16) | ||
``` | ||
|
||
### String Tensor | ||
|
||
Use {py:class}`ir.StringTensor <onnxscript.ir.StringTensor>` to create a string tensor. | ||
|
||
<!-- TODO(justinchuby): Document make tensor helper --> | ||
|
||
### Sparse Tensor | ||
|
||
Sparse tensors are not yet supported, but they are on our roadmap. | ||
|
||
## From `TensorProto`s and back | ||
|
||
In the following scenario, we show how to go from a `TensorProto` to an `ir.Tensor`, run some computation, then turn it back to an `ir.Tensor` and finally `TensorProto` | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
from onnxscript import ir | ||
import onnx | ||
import numpy as np | ||
|
||
# 1. Create the TensorProto | ||
proto = onnx.helper.make_tensor( | ||
"tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6] | ||
) | ||
|
||
# 2. Create an IR Tensor from the Protobuf message | ||
tensor = ir.serde.deserialize_tensor(proto) | ||
# Note that we get a TensorProtoTensor that implements the TensorProtocol | ||
print("tensor:", tensor) # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor') | ||
print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.] | ||
# [4. 5. 6.]] | ||
print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F' | ||
|
||
# 3. Do computation using numpy | ||
mean = tensor.numpy().mean(axis=0) | ||
print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16) | ||
|
||
# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor | ||
tensor_mean = ir.Tensor(mean) | ||
print("tensor_mean:", tensor_mean) # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='') | ||
|
||
# 5. Obtain the TensorProto from ir.Tensor | ||
mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean) | ||
print("mean_tensor_proto:", mean_tensor_proto) | ||
print( | ||
"onnx.numpy_helper.to_array(mean_tensor_proto):", | ||
onnx.numpy_helper.to_array(mean_tensor_proto) | ||
# array([2.5, 3.5, 4.5], dtype=float16) | ||
) | ||
|
||
# You can obtain the bytes data as well | ||
print("tensor_mean.tobytes():", tensor_mean.tobytes()) | ||
print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes()) | ||
|
||
# Explore other methods defined by TensorProtocol: | ||
print("\n# Explore other methods defined by TensorProtocol:") | ||
print("tensor_mean.shape:", tensor_mean.shape) | ||
print("tensor_mean.dtype:", tensor_mean.dtype) | ||
print("tensor_mean.name:", tensor_mean.name) | ||
print("tensor_mean.doc_string:", tensor_mean.doc_string) | ||
print("tensor_mean.raw:", tensor_mean.raw) | ||
print("tensor_mean.metadata_props:", tensor_mean.metadata_props) | ||
print("tensor_mean.size:", tensor_mean.size) | ||
print("tensor_mean.nbytes:", tensor_mean.nbytes) | ||
print("tensor_mean.raw:", tensor_mean.raw) | ||
print("\nUse the display() method to view the tensor") | ||
tensor_mean.display() | ||
``` | ||
|
||
## Working with non-native NumPy dtypes: bfloat16, float8, int4 | ||
|
||
`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, the value is the bit representation for the dtype: | ||
|
||
- `int8` for (unpacked) int4, with the sign bit extended to 8 bits. | ||
- `uint8` for (unpacked) uint4. | ||
- `uint8` for 8-bit data types like float8. | ||
- `uint16` for bfloat16. | ||
|
||
uint4/int4 is always unpacked; `tobyte()` produces a packed representation as expected. | ||
|
||
Initialization of `ir.Tensor` requires the NumPy array to follow these typing constraints as well. | ||
|
||
:::{tip} | ||
You can use the [ml_dtypes package](https://github.com/jax-ml/ml_dtypes) to extend NumPy and work with these values. | ||
|
||
```bash | ||
pip install --upgrade ml_dtypes | ||
``` | ||
|
||
::: | ||
|
||
The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values. | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
from onnxscript import ir | ||
import numpy as np | ||
|
||
array = np.array([0b1, 0b11], dtype=np.uint8) | ||
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN) | ||
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([1, 3], dtype=uint8), name='') | ||
print("tensor.numpy():", tensor.numpy()) # array([1, 3], dtype=uint8) | ||
|
||
# You can use the ml_dtypes package to work with these values in NumPy | ||
import ml_dtypes | ||
float8_array = tensor.numpy().view(ml_dtypes.float8_e4m3fn) | ||
print("float8_array:", float8_array) # array([0.00195312, 0.00585938], dtype='float8_e4m3fn') | ||
|
||
# Compute | ||
times_100 = float8_array * 100 | ||
print("times_100:", times_100) | ||
|
||
# Create a new tensor out of the new value; dtype must be specified | ||
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN) | ||
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([36, 49], dtype=uint8), name='') | ||
print("new_tensor == times_100", new_tensor.numpy().view(ml_dtypes.float8_e4m3fn) == times_100) # array([ True, True]) | ||
|
||
``` | ||
|
||
## Advanced Usage | ||
|
||
### Subclass ir.Tensor for More Efficient Access and Broader dtype Support | ||
|
||
{py:class}`ir.Tensor` internally converts any array compatible objects into NumPy arrays to produce the byte representation in `tobytes()`. This can be inefficient due to the additional conversion. It also limits support for dtypes not supported by NumPy like bfloat16, because the `__array__` method would fail. | ||
|
||
To fully support arrays from other frameworks, it is usually a good idea to create specialized classes to handle them. The `TorchTensor` class below demonstrates how you can subclass `ir.Tensor` to handle PyTorch tensors: | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
import ctypes | ||
from typing import Any | ||
|
||
import torch | ||
from onnxscript import ir | ||
|
||
# Define utilities to convert PyTorch data types so users do not need to specify manually | ||
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { | ||
torch.bfloat16: ir.DataType.BFLOAT16, | ||
torch.bool: ir.DataType.BOOL, | ||
torch.complex128: ir.DataType.COMPLEX128, | ||
torch.complex64: ir.DataType.COMPLEX64, | ||
torch.float16: ir.DataType.FLOAT16, | ||
torch.float32: ir.DataType.FLOAT, | ||
torch.float64: ir.DataType.DOUBLE, | ||
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, | ||
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, | ||
torch.float8_e5m2: ir.DataType.FLOAT8E5M2, | ||
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, | ||
torch.int16: ir.DataType.INT16, | ||
torch.int32: ir.DataType.INT32, | ||
torch.int64: ir.DataType.INT64, | ||
torch.int8: ir.DataType.INT8, | ||
torch.uint8: ir.DataType.UINT8, | ||
} | ||
|
||
|
||
def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: | ||
return _TORCH_DTYPE_TO_ONNX[dtype] | ||
|
||
class TorchTensor(ir.Tensor): | ||
def __init__(self, tensor: torch.Tensor): | ||
# Pass the tensor as the raw data to ir.Tensor's constructor | ||
super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype)) | ||
|
||
def __array__(self, dtype: Any = None) -> "np.ndarray": | ||
# numpy() calls __array__ in ir.Tensor | ||
if self.dtype == ir.DataType.BFLOAT16: | ||
return self.raw.view(torch.uint16).__array__(dtype) | ||
if self.dtype in { | ||
ir.DataType.FLOAT8E4M3FN, | ||
ir.DataType.FLOAT8E4M3FNUZ, | ||
ir.DataType.FLOAT8E5M2, | ||
ir.DataType.FLOAT8E5M2FNUZ | ||
}: | ||
return self.raw.view(torch.uint8).__array__(dtype) | ||
return self.raw.__array__(dtype) | ||
|
||
def tobytes(self) -> bytes: | ||
# Implement tobytes to support native PyTorch types so we can use types like bloat16 | ||
# Reading from memory directly is also more efficient because | ||
# it avoids the copy to NumPy array | ||
tensor = self.raw.detach().cpu().contiguous() | ||
return bytes( | ||
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( | ||
tensor.data_ptr() | ||
) | ||
) | ||
|
||
# Test the implementation | ||
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16) | ||
tensor = TorchTensor(torch_tensor) | ||
print("tensor: ", tensor) | ||
print("numpy: ", tensor.numpy()) | ||
print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@' | ||
print("nbytes: ", tensor.nbytes) # 6 | ||
``` | ||
|
||
The `TorchTensor` class above implements `tobytes()` to produce the correct bytes representation for the tensor when it is serialized into an ONNX file / TensorProto. The class also implements the `__array__()` method to return the bit representation for types NumPy does not support. This way analysis passes can still perform computation on these values. | ||
|
||
### Computation with different Frameworks | ||
|
||
Since `ir.Tensor` implements the `__array__` method and `__dlpack__` methods, its content can be shared with computation frameworks without copying. For example: | ||
|
||
```{eval-rst} | ||
.. exec_code:: | ||
|
||
from onnxscript import ir | ||
|
||
# We can call numpy methods directly on ir.Tensor | ||
import numpy as np | ||
print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.]) | ||
|
||
# We can transfer arrays to different frameworks | ||
import jax.numpy as jnp | ||
import jax | ||
import torch | ||
|
||
# Create ir.Tensor | ||
jax_array = jnp.array([10., 20.]) | ||
ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT) | ||
torch_tensor = torch.tensor([30., 40.]) | ||
ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) | ||
|
||
# Use numpy for computation | ||
print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32) | ||
|
||
# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same | ||
jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch) | ||
print(jax_array_from_ir + jax_array) # [40. 60.] | ||
|
||
# Use PyTorch for computation | ||
torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax) | ||
print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.]) | ||
|
||
# They can all be serialized into TensorProto | ||
proto = ir.serde.serialize_tensor(ir_tensor_jax) | ||
print(type(proto)) # <class 'onnx.onnx_ml_pb2.TensorProto'> | ||
print(proto) | ||
|
||
# The value is exactly the same as jax_array | ||
print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.] | ||
``` | ||
|
||
This is particularly useful if you are creating passes on the graph that requires doing computation on concrete values. You are free to use your favorite frameworks to create the passes. The transformed graph that contains newly created `ir.Tensor`s will be compatible with downstream passes even if they leverage other computation frameworks. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the API allow users to create this tensor without using
onnx.helper.make_tensor
? How would they do that?When users create the IR representation from a pre-existing tensor-proto, ideally they should call a single API method (without having to handle the different cases of a TensorProto themselves).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point - I am going to combine all tensor proto classes and simplify as part of #1441
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done: 0408278
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes - all users need is a tensor proto as input
I expect to create a similar (make_tensor) api that will produce ir tensor for us