From 99409df2b79ecceb9aa6a2f94426b5c67cdc4c4f Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Fri, 1 Mar 2024 16:25:27 +0000
Subject: [PATCH] TRT 10 supports int64 so no need to cast

---
 .../tensorrt/tensorrt_execution_provider.cc  | 32 +++++++++++++++----
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 291728ba48545..a0a3962308103 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -861,8 +861,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-    // Cast int64 input to int32 input because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+    // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64
     CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
     // Cast double input to float because TensorRT doesn't support double
     CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
     default: {
@@ -949,8 +953,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-    // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+    // Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64
     CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
     // Allocate float CUDA memory for double output type because TensorRT doesn't support double
     CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
     default: {
@@ -1014,8 +1022,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-    // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
+#if NV_TENSORRT_MAJOR >= 10
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+    // The allocation buffer holds the int32 output data since TRT < 10 doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
     CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t)
+#endif
     // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output.
     CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double)
     default: {
@@ -3431,7 +3443,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
      }

      // Assign TRT output back to ORT output
-     // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+     // (1) Bind TRT DDS output to ORT kernel context output.
      // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
      for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
        char const* output_name = output_binding_names[i];
@@ -3454,12 +3466,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
          }
        } else {
          auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
          if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
            auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
            if (output_tensor_ptr != nullptr) {
              cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
            }
-          } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+          }
+#endif
+          if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
            auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
            if (output_tensor_ptr != nullptr) {
              cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
@@ -3723,7 +3738,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
      }

      // Assign TRT output back to ORT output
-     // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+     // (1) Bind TRT DDS output to ORT kernel context output.
      // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
      for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
        char const* output_name = output_binding_names[i];
@@ -3746,12 +3761,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
          }
        } else {
          auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
          if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
            auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
            if (output_tensor_ptr != nullptr) {
              cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
            }
-          } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+          }
+#endif
+          if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
            auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
            if (output_tensor_ptr != nullptr) {
              cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);