From 4e1f9dd84f604060f6cbd24d7a38e8fa3c37850d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Nov 2024 18:51:18 +0000
Subject: [PATCH 1/2] [IR] Add torch tensor support for ir.Tensor

---
 onnxscript/ir/_convenience.py      | 38 ++++++++++++++++++------------
 onnxscript/ir/_convenience_test.py | 22 +++++++++++++++++
 onnxscript/ir/tensor_adapters.py   | 11 ++++++---
 3 files changed, 53 insertions(+), 18 deletions(-)
 create mode 100644 onnxscript/ir/_convenience_test.py

diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py
index 7e60ec74d..52ed5fb88 100644
--- a/onnxscript/ir/_convenience.py
+++ b/onnxscript/ir/_convenience.py
@@ -20,7 +20,7 @@
 import numpy as np
 import onnx
 
-from onnxscript.ir import _core, _enums, _protocols, serde
+from onnxscript.ir import _core, _enums, _protocols, serde, tensor_adapters
 
 if typing.TYPE_CHECKING:
     import numpy.typing as npt
@@ -321,6 +321,9 @@ def tensor(
         >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
         >>> tp_tensor.numpy()
         array(0.5, dtype=float32)
+        >>> import torch
+        >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
+        TorchTensor(tensor([1., 2.]), name='torch_tensor')
 
     Args:
         value: The numpy array to create the tensor from.
@@ -353,22 +356,27 @@ def tensor(
                 f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
                 "You do not have to specify the dtype when value is a TensorProto."
             )
+        return tensor_
+    elif str(type(value)) == "<class 'torch.Tensor'>":
+        # NOTE: We use str(type(...)) and do not import torch for type checking
+        # as it creates overhead during import
+        return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)
     elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
-        tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
+        return _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
+
+    # Plain Python object
+    if dtype is not None:
+        numpy_dtype = dtype.numpy()
     else:
-        if dtype is not None:
-            numpy_dtype = dtype.numpy()
-        else:
-            numpy_dtype = None
-        array = np.array(value, dtype=numpy_dtype)
-        tensor_ = _core.Tensor(
-            array,
-            dtype=dtype,
-            shape=_core.Shape(array.shape),
-            name=name,
-            doc_string=name,
-        )
-    return tensor_
+        numpy_dtype = None
+    array = np.array(value, dtype=numpy_dtype)
+    return _core.Tensor(
+        array,
+        dtype=dtype,
+        shape=_core.Shape(array.shape),
+        name=name,
+        doc_string=name,
+    )
 
 
 def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
diff --git a/onnxscript/ir/_convenience_test.py b/onnxscript/ir/_convenience_test.py
new file mode 100644
index 000000000..37b08608f
--- /dev/null
+++ b/onnxscript/ir/_convenience_test.py
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Unit tests for the _convenience module."""
+
+import unittest
+
+import numpy as np
+
+from onnxscript.ir import _convenience
+
+
+class ConvenienceTest(unittest.TestCase):
+    def test_tensor_accepts_torch_tensor(self):
+        import torch as some_random_name
+
+        torch_tensor = some_random_name.tensor([1, 2, 3])
+        tensor = _convenience.tensor(torch_tensor)
+        np.testing.assert_array_equal(tensor, torch_tensor.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py
index 10e181152..e24bce026 100644
--- a/onnxscript/ir/tensor_adapters.py
+++ b/onnxscript/ir/tensor_adapters.py
@@ -38,13 +38,16 @@
 import numpy.typing as npt
 
 from onnxscript import ir
+from onnxscript.ir import _core
 
 if TYPE_CHECKING:
     import torch
 
 
-class TorchTensor(ir.Tensor):
-    def __init__(self, tensor: torch.Tensor, name: str | None = None):
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
         # Pass the tensor as the raw data to ir.Tensor's constructor
         import torch
 
@@ -69,7 +72,9 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
-        super().__init__(tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name)
+        super().__init__(
+            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+        )
 
     def numpy(self) -> npt.NDArray:
         import torch

From b346e2cbc731771906ae51b1082b045fdd8c8ab5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 31 Dec 2024 10:20:27 -0800
Subject: [PATCH 2/2] Add type ignore and pylint disable comments

---
 onnxscript/ir/_convenience.py      | 2 +-
 onnxscript/ir/_convenience_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py
index 52ed5fb88..d59bfe479 100644
--- a/onnxscript/ir/_convenience.py
+++ b/onnxscript/ir/_convenience.py
@@ -360,7 +360,7 @@ def tensor(
     elif str(type(value)) == "<class 'torch.Tensor'>":
         # NOTE: We use str(type(...)) and do not import torch for type checking
         # as it creates overhead during import
-        return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)
+        return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)  # type: ignore[arg-type]
     elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
         return _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
 
diff --git a/onnxscript/ir/_convenience_test.py b/onnxscript/ir/_convenience_test.py
index 37b08608f..c293a0097 100644
--- a/onnxscript/ir/_convenience_test.py
+++ b/onnxscript/ir/_convenience_test.py
@@ -11,7 +11,7 @@
 
 class ConvenienceTest(unittest.TestCase):
     def test_tensor_accepts_torch_tensor(self):
-        import torch as some_random_name
+        import torch as some_random_name  # pylint: disable=import-outside-toplevel
 
         torch_tensor = some_random_name.tensor([1, 2, 3])
         tensor = _convenience.tensor(torch_tensor)
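
Note for reviewers (not part of the patches above): a minimal usage sketch of the code path this series adds, assuming the series is applied and torch is installed. The tensor values and names are illustrative only.

import torch

from onnxscript import ir

# ir.tensor() recognizes torch.Tensor via str(type(...)), so onnxscript itself
# never imports torch; the value is wrapped in tensor_adapters.TorchTensor.
t = ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
assert t.dtype == ir.DataType.FLOAT  # dtype mapped from torch.float32
print(t.numpy())                     # numpy view of the torch data: [1. 2.]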