diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index 7e60ec74d..d59bfe479 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -20,7 +20,7 @@ import numpy as np import onnx -from onnxscript.ir import _core, _enums, _protocols, serde +from onnxscript.ir import _core, _enums, _protocols, serde, tensor_adapters if typing.TYPE_CHECKING: import numpy.typing as npt @@ -321,6 +321,9 @@ def tensor( >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5])) >>> tp_tensor.numpy() array(0.5, dtype=float32) + >>> import torch + >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor") + TorchTensor(tensor([1., 2.]), name='torch_tensor') Args: value: The numpy array to create the tensor from. @@ -353,22 +356,27 @@ def tensor( f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}" "You do not have to specify the dtype when value is a TensorProto." ) + return tensor_ + elif str(type(value)) == "": + # NOTE: We use str(type(...)) and do not import torch for type checking + # as it creates overhead during import + return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type] elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)): - tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name) + return _core.Tensor(value, dtype=dtype, name=name, doc_string=name) + + # Plain Python object + if dtype is not None: + numpy_dtype = dtype.numpy() else: - if dtype is not None: - numpy_dtype = dtype.numpy() - else: - numpy_dtype = None - array = np.array(value, dtype=numpy_dtype) - tensor_ = _core.Tensor( - array, - dtype=dtype, - shape=_core.Shape(array.shape), - name=name, - doc_string=name, - ) - return tensor_ + numpy_dtype = None + array = np.array(value, dtype=numpy_dtype) + return _core.Tensor( + array, + dtype=dtype, + shape=_core.Shape(array.shape), + name=name, + doc_string=name, + ) def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: diff --git a/onnxscript/ir/_convenience_test.py b/onnxscript/ir/_convenience_test.py new file mode 100644 index 000000000..c293a0097 --- /dev/null +++ b/onnxscript/ir/_convenience_test.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Unit tests for the _convenience module.""" + +import unittest + +import numpy as np + +from onnxscript.ir import _convenience + + +class ConvenienceTest(unittest.TestCase): + def test_tensor_accepts_torch_tensor(self): + import torch as some_random_name # pylint: disable=import-outside-toplevel + + torch_tensor = some_random_name.tensor([1, 2, 3]) + tensor = _convenience.tensor(torch_tensor) + np.testing.assert_array_equal(tensor, torch_tensor.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py index 10e181152..e24bce026 100644 --- a/onnxscript/ir/tensor_adapters.py +++ b/onnxscript/ir/tensor_adapters.py @@ -38,13 +38,16 @@ import numpy.typing as npt from onnxscript import ir +from onnxscript.ir import _core if TYPE_CHECKING: import torch -class TorchTensor(ir.Tensor): - def __init__(self, tensor: torch.Tensor, name: str | None = None): +class TorchTensor(_core.Tensor): + def __init__( + self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None + ): # Pass the tensor as the raw data to ir.Tensor's constructor import torch @@ -69,7 +72,9 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None): torch.uint32: ir.DataType.UINT32, torch.uint64: ir.DataType.UINT64, } - super().__init__(tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name) + super().__init__( + tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string + ) def numpy(self) -> npt.NDArray: import torch