diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index d0304160dc68d..f4f10dc4b4b97 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -602,7 +602,7 @@ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_p :param name: input name :param device_type: e.g. cpu, cuda, cann :param device_id: device id, e.g. 0 - :param element_type: input element type + :param element_type: input element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16) :param shape: input shape :param buffer_ptr: memory pointer to input data """ @@ -641,7 +641,7 @@ def bind_output( :param name: output name :param device_type: e.g. cpu, cuda, cann, cpu by default :param device_id: device id, e.g. 0 - :param element_type: output element type + :param element_type: output element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16) :param shape: output shape :param buffer_ptr: memory pointer to output data """ @@ -758,31 +758,43 @@ def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0): ) @staticmethod - def ortvalue_from_numpy_with_onnxtype(data: Sequence[int], onnx_element_type: int): + def ortvalue_from_numpy_with_onnx_type(data, onnx_element_type: int): """ - This method creates an instance of OrtValue on top of the numpy array + This method creates an instance of OrtValue on top of the numpy array. No data copy is made and the lifespan of the resulting OrtValue should never exceed the lifespan of bytes object. The API attempts to reinterpret the data type which is expected to be the same size. This is useful when we want to use an ONNX data type that is not supported by numpy. - :param data: numpy array. + :param data: numpy.ndarray. :param onnx_elemenet_type: a valid onnx TensorProto::DataType enum value """ - return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnxtype(data, onnx_element_type), data) + return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data) @staticmethod - def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type="cpu", device_id=0): + def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu", device_id: int = 0): """ Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type :param shape: List of integers indicating the shape of the OrtValue - :param element_type: The data type of the elements in the OrtValue (numpy type) + :param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16). :param device_type: e.g. cpu, cuda, cann, cpu by default :param device_id: device id, e.g. 0 """ - if shape is None or element_type is None: - raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided") + # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html). + # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy. + if isinstance(element_type, int): + return OrtValue( + C.OrtValue.ortvalue_from_shape_and_onnx_type( + shape, + element_type, + C.OrtDevice( + get_ort_device_type(device_type, device_id), + C.OrtDevice.default_memory(), + device_id, + ), + ) + ) return OrtValue( C.OrtValue.ortvalue_from_shape_and_type( diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index 37081cd0ff2b4..f26e188187412 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -20,6 +20,39 @@ namespace python { namespace py = pybind11; +namespace { +void BindOutput(SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, + MLDataType element_type, const std::vector& shape, int64_t data_ptr) { + ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid"); + InferenceSession* sess = io_binding->GetInferenceSession(); + auto px = sess->GetModelOutputs(); + if (!px.first.IsOK() || !px.second) { + throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); + } + + // For now, limit binding support to only non-string Tensors + const auto& def_list = *px.second; + onnx::TypeProto type_proto; + if (!CheckIfTensor(def_list, name, type_proto)) { + throw std::runtime_error("Only binding Tensors is currently supported"); + } + + ORT_ENFORCE(utils::HasTensorType(type_proto) && utils::HasElemType(type_proto.tensor_type())); + if (type_proto.tensor_type().elem_type() == onnx::TensorProto::STRING) { + throw std::runtime_error("Only binding non-string Tensors is currently supported"); + } + + OrtValue ml_value; + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + Tensor::InitOrtValue(element_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); + + auto status = io_binding->Get()->BindOutput(name, ml_value); + if (!status.IsOK()) { + throw std::runtime_error("Error when binding output: " + status.ErrorMessage()); + } +} +} // namespace + void addIoBindingMethods(pybind11::module& m) { py::class_ session_io_binding(m, "SessionIOBinding"); session_io_binding @@ -58,6 +91,18 @@ void addIoBindingMethods(pybind11::module& m) { } }) // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo + .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector& shape, int64_t data_ptr) -> void { + auto ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type); + OrtValue ml_value; + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); + + auto status = io_binding->Get()->BindInput(name, ml_value); + if (!status.IsOK()) { + throw std::runtime_error("Error when binding input: " + status.ErrorMessage()); + } + }) + // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector& shape, int64_t data_ptr) -> void { PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { @@ -90,28 +135,14 @@ void addIoBindingMethods(pybind11::module& m) { throw std::runtime_error("Error when synchronizing bound inputs: " + status.ErrorMessage()); } }) + // This binds output to a pre-allocated memory as a Tensor. + // The element type is onnx type , or key in onnx.mapping.TENSOR_TYPE_MAP (https://onnx.ai/onnx/api/mapping.html) + .def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector& shape, int64_t data_ptr) -> void { + MLDataType ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type); + BindOutput(io_binding, name, device, ml_type, shape, data_ptr); + }) // This binds output to a pre-allocated memory as a Tensor .def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector& shape, int64_t data_ptr) -> void { - ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid"); - - InferenceSession* sess = io_binding->GetInferenceSession(); - auto px = sess->GetModelOutputs(); - if (!px.first.IsOK() || !px.second) { - throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); - } - - // For now, limit binding support to only non-string Tensors - const auto& def_list = *px.second; - onnx::TypeProto type_proto; - if (!CheckIfTensor(def_list, name, type_proto)) { - throw std::runtime_error("Only binding Tensors is currently supported"); - } - - ORT_ENFORCE(utils::HasTensorType(type_proto) && utils::HasElemType(type_proto.tensor_type())); - if (type_proto.tensor_type().elem_type() == onnx::TensorProto::STRING) { - throw std::runtime_error("Only binding non-string Tensors is currently supported"); - } - PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { throw std::runtime_error("Not a valid numpy type"); @@ -119,15 +150,8 @@ void addIoBindingMethods(pybind11::module& m) { int type_num = dtype->type_num; Py_DECREF(dtype); - OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num); - OrtValue ml_value; - Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); - - auto status = io_binding->Get()->BindOutput(name, ml_value); - if (!status.IsOK()) { - throw std::runtime_error("Error when binding output: " + status.ErrorMessage()); - } + BindOutput(io_binding, name, device, ml_type, shape, data_ptr); }) // This binds output to a device. Meaning that the output OrtValue must be allocated on a specific device. .def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device) -> void { diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 6ed4c42bd4304..084ee6bc50698 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -467,6 +467,10 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type) { } } +MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type) { + return DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_type)->GetElementType(); +} + // This is a one time use, ad-hoc allocator that allows Tensors to take ownership of // python array objects and use the underlying memory directly and // properly deallocated them when they are done. diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index c76292040b61b..78a5ea4368ae9 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -40,6 +40,8 @@ int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type); MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type); +MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type); + using MemCpyFunc = void (*)(void*, const void*, size_t); using DataTransferAlternative = std::variant; diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index e338634d73bd3..18785cd607eaa 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -21,6 +21,42 @@ namespace python { namespace py = pybind11; +namespace { +std::unique_ptr OrtValueFromShapeAndType(const std::vector& shape, + MLDataType element_type, + const OrtDevice& device) { + AllocatorPtr allocator; + if (strcmp(GetDeviceName(device), CPU) == 0) { + allocator = GetAllocator(); + } else if (strcmp(GetDeviceName(device), CUDA) == 0) { +#ifdef USE_CUDA + if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } + allocator = GetCudaAllocator(device.Id()); +#else + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (strcmp(GetDeviceName(device), DML) == 0) { +#if USE_DML + allocator = GetDmlAllocator(device.Id()); +#else + throw std::runtime_error( + "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " + "Please use the DirectML package of OnnxRuntime to use this feature."); +#endif + } else { + throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); + } + + auto ml_value = std::make_unique(); + Tensor::InitOrtValue(element_type, gsl::make_span(shape), std::move(allocator), *ml_value); + return ml_value; +} +} // namespace + void addOrtValueMethods(pybind11::module& m) { py::class_ ortvalue_binding(m, "OrtValue"); ortvalue_binding @@ -144,13 +180,12 @@ void addOrtValueMethods(pybind11::module& m) { }) // Create an ortvalue value on top of the numpy array, but interpret the data // as a different type with the same element size. - .def_static("ortvalue_from_numpy_with_onnxtype", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr { + .def_static("ortvalue_from_numpy_with_onnx_type", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr { if (!ONNX_NAMESPACE::TensorProto_DataType_IsValid(onnx_element_type)) { ORT_THROW("Not a valid ONNX Tensor data type: ", onnx_element_type); } - const auto element_type = DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_type) - ->GetElementType(); + const auto element_type = OnnxTypeToOnnxRuntimeTensorType(onnx_element_type); const auto element_size = element_type->Size(); if (narrow(data.itemsize()) != element_size) { @@ -164,11 +199,11 @@ void addOrtValueMethods(pybind11::module& m) { const_cast(data.data()), cpu_allocator->Info(), *ort_value); return ort_value; }) - // Factory method to create an OrtValue (Tensor) from the given shape and element type with memory on the specified device + // Factory method to create an OrtValue from the given shape and numpy element type on the specified device. // The memory is left uninitialized - .def_static("ortvalue_from_shape_and_type", [](const std::vector& shape, py::object& element_type, const OrtDevice& device) { + .def_static("ortvalue_from_shape_and_type", [](const std::vector& shape, py::object& numpy_element_type, const OrtDevice& device) -> std::unique_ptr { PyArray_Descr* dtype; - if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { + if (!PyArray_DescrConverter(numpy_element_type.ptr(), &dtype)) { throw std::runtime_error("Not a valid numpy type"); } @@ -179,36 +214,18 @@ void addOrtValueMethods(pybind11::module& m) { throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays"); } - AllocatorPtr allocator; - if (strcmp(GetDeviceName(device), CPU) == 0) { - allocator = GetAllocator(); - } else if (strcmp(GetDeviceName(device), CUDA) == 0) { -#ifdef USE_CUDA - if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - allocator = GetCudaAllocator(device.Id()); -#else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); -#endif - } else if (strcmp(GetDeviceName(device), DML) == 0) { -#if USE_DML - allocator = GetDmlAllocator(device.Id()); -#else - throw std::runtime_error( - "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " - "Please use the DirectML package of OnnxRuntime to use this feature."); -#endif - } else { - throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); + auto element_type = NumpyTypeToOnnxRuntimeTensorType(type_num); + return OrtValueFromShapeAndType(shape, element_type, device); + }) + // Factory method to create an OrtValue from the given shape and onnx element type on the specified device. + // The memory is left uninitialized + .def_static("ortvalue_from_shape_and_onnx_type", [](const std::vector& shape, int32_t onnx_element_type, const OrtDevice& device) -> std::unique_ptr { + if (onnx_element_type == onnx::TensorProto_DataType::TensorProto_DataType_STRING) { + throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays"); } - auto ml_value = std::make_unique(); - auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num); - Tensor::InitOrtValue(ml_type, gsl::make_span(shape), std::move(allocator), *ml_value); - return ml_value; + auto element_type = OnnxTypeToOnnxRuntimeTensorType(onnx_element_type); + return OrtValueFromShapeAndType(shape, element_type, device); }) #if !defined(DISABLE_SPARSE_TENSORS) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 9419761340517..9b26944629aa6 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1391,7 +1391,9 @@ def test_session_with_ortvalue_input(ortvalue): # test ort_value creation on top of the bytes float_tensor_data_type = 1 # TensorProto_DataType_FLOAT - ort_value_with_type = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(numpy_arr_input, float_tensor_data_type) + ort_value_with_type = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type( + numpy_arr_input, float_tensor_data_type + ) self.assertTrue(ort_value_with_type.is_tensor()) self.assertEqual(float_tensor_data_type, ort_value_with_type.element_type()) self.assertEqual([3, 2], ort_value_with_type.shape()) @@ -1843,8 +1845,8 @@ def test_adater_export_read(self): param_1 = np.array(val).astype(np.float32).reshape(5, 2) param_2 = np.array(val).astype(np.int64).reshape(2, 5) - ort_val_1 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_1, float_data_type) - ort_val_2 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_2, int64_data_type) + ort_val_1 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type(param_1, float_data_type) + ort_val_2 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type(param_2, int64_data_type) params = {"param_1": ort_val_1, "param_2": ort_val_2} diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py index 56417f13fbea4..01269bc02d77c 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py +++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py @@ -7,9 +7,9 @@ import numpy as np from helper import get_name from numpy.testing import assert_almost_equal -from onnx import helper +from onnx import TensorProto, helper from onnx.defs import onnx_opset_version -from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE +from onnx.mapping import TENSOR_TYPE_MAP import onnxruntime as onnxrt from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice # pylint: disable=E0611 @@ -104,7 +104,7 @@ def test_bind_input_types(self): ]: with self.subTest(dtype=dtype, inner_device=str(inner_device)): x = np.arange(8).reshape((-1, 2)).astype(dtype) - proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] + proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 @@ -143,6 +143,116 @@ def test_bind_input_types(self): y = ortvalue.numpy() assert_almost_equal(x, y) + def test_bind_onnx_types_supported_by_numpy(self): + opset = onnx_opset_version() + devices = [ + ( + C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), + ["CPUExecutionProvider"], + ), + ] + + for inner_device, provider in devices: + for onnx_dtype in [ + TensorProto.FLOAT, + TensorProto.UINT8, + TensorProto.INT8, + TensorProto.UINT16, + TensorProto.INT16, + TensorProto.INT32, + TensorProto.INT64, + TensorProto.BOOL, + TensorProto.FLOAT16, + TensorProto.DOUBLE, + TensorProto.UINT32, + TensorProto.UINT64, + ]: + with self.subTest(onnx_dtype=onnx_dtype, inner_device=str(inner_device)): + assert onnx_dtype in TENSOR_TYPE_MAP + np_dtype = TENSOR_TYPE_MAP[onnx_dtype].np_dtype + x = np.arange(8).reshape((-1, 2)).astype(np_dtype) + + # create onnx graph + X = helper.make_tensor_value_info("X", onnx_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", onnx_dtype, [None, x.shape[1]]) # noqa: N806 + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + ort_value_x = C_OrtValue.ortvalue_from_numpy(x, inner_device) + ort_value_y = onnxrt.OrtValue.ortvalue_from_shape_and_type(x.shape, onnx_dtype) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", inner_device, onnx_dtype, x.shape, ort_value_x.data_ptr()) + bind.bind_output("Y", inner_device, onnx_dtype, x.shape, ort_value_y.data_ptr()) + sess._sess.run_with_iobinding(bind, None) + assert_almost_equal(x, ort_value_y.numpy()) + + # Test I/O binding with onnx types like bfloat16 and float8, which are not supported in numpy. + def test_bind_onnx_types_not_supported_by_numpy(self): + try: + import torch + except ImportError: + self.skipTest("Skipping since PyTorch is not installed.") + + opset = onnx_opset_version() + devices = [ + ( + C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), + ["CPUExecutionProvider"], + ), + ] + + onnx_to_torch_type_map = {TensorProto.BFLOAT16: torch.bfloat16} + # Float8 support requires torch >= 2.1.0 + if hasattr(torch, "float8_e4m3fn") and hasattr(torch, "float8_e5m2"): + onnx_to_torch_type_map.update( + { + TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, + TensorProto.FLOAT8E5M2: torch.float8_e5m2, + } + ) + + for inner_device, provider in devices: + for onnx_dtype in onnx_to_torch_type_map: + with self.subTest(onnx_dtype=onnx_dtype, inner_device=str(inner_device)): + + # Create onnx graph with dynamic axes + X = helper.make_tensor_value_info("X", onnx_dtype, [None]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", onnx_dtype, [None]) # noqa: N806 + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=10, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + x = torch.arange(8).to(torch_dtype) + y = torch.empty(8, dtype=torch_dtype) + + bind = sess.io_binding() + bind.bind_input("X", x.device.type, 0, onnx_dtype, x.shape, x.data_ptr()) + bind.bind_output("Y", y.device.type, 0, onnx_dtype, y.shape, y.data_ptr()) + sess.run_with_iobinding(bind) + if onnx_dtype != TensorProto.BFLOAT16: + # torch has no cpu equal implementation of float8, so we compare them after casting to float. + self.assertTrue(torch.equal(x.to(torch.float), y.to(torch.float))) + else: + self.assertTrue(torch.equal(x, y)) + def test_bind_input_only(self): for device, execution_provider, _ in test_params: with self.subTest(execution_provider):