TRT 10 supports int64 so no need to cast
chilo-ms committed Mar 1, 2024
1 parent e962201 commit 99409df
Showing 1 changed file with 25 additions and 7 deletions.
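For context, the change gates int64 handling on the TensorRT major version at compile time: on TensorRT 10 and later, int64 tensors are bound to the engine directly, while older versions keep the existing cast-to-int32 path. The sketch below is a simplified, self-contained illustration of that pattern, not the actual BindContextInput/BindContextOutput code from this file; the helper names (BindBuffer, BindInt64Input) are hypothetical, and only the NV_TENSORRT_MAJOR guard and the int64-to-int32 narrowing idea are taken from the diff.

// Simplified, self-contained sketch of the compile-time gate this commit adds
// (helper names are hypothetical; NV_TENSORRT_MAJOR normally comes from
// TensorRT's NvInferVersion.h).
#include <cstdint>
#include <cstdio>
#include <vector>

#ifndef NV_TENSORRT_MAJOR
#define NV_TENSORRT_MAJOR 10  // stand-in so the sketch compiles without TRT headers
#endif

// Hypothetical stand-in for binding a typed buffer to a TRT execution context.
template <typename T>
void BindBuffer(const char* name, const T* data, size_t count) {
  std::printf("bind %s: %zu elements of %zu bytes\n", name, count, sizeof(T));
}

void BindInt64Input(const char* name, const int64_t* data, size_t count) {
#if NV_TENSORRT_MAJOR >= 10
  // TensorRT 10+ supports INT64 tensors natively, so bind the data as-is.
  BindBuffer(name, data, count);
#else
  // Older TensorRT has no INT64 tensor type: narrow to INT32 into a scratch
  // buffer first and bind that instead (assumes the values fit in 32 bits).
  std::vector<int32_t> scratch(count);
  for (size_t i = 0; i < count; ++i) scratch[i] = static_cast<int32_t>(data[i]);
  BindBuffer(name, scratch.data(), count);
#endif
}

int main() {
  std::vector<int64_t> ids = {1, 2, 3};
  BindInt64Input("input_ids", ids.data(), ids.size());
}

Gating at compile time keeps builds against older TensorRT releases working unchanged while letting TensorRT 10 builds skip the extra copy and conversion entirely.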
32 changes: 25 additions & 7 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -861,8 +861,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
 CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
 CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
 CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-// Cast int64 input to int32 input because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+// Cast int64 input to int32 input because TensorRT < 10 doesn't support int64
 CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
 // Cast double input to float because TensorRT doesn't support double
 CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
 default: {
@@ -949,8 +953,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
 CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
 CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
 CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-// Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+// Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64
 CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
 // Allocate float CUDA memory for double output type because TensorRT doesn't support double
 CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
 default: {
@@ -1014,8 +1022,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
 CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
 CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
 CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-// The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
+#if NV_TENSORRT_MAJOR >= 10
+CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+// The allocation buffer holds the int32 output data since TRT < 10 doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
 CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t)
+#endif
 // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output.
 CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double)
 default: {
@@ -3431,7 +3443,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
 }

 // Assign TRT output back to ORT output
-// (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+// (1) Bind TRT DDS output to ORT kernel context output.
 // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
 for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
 char const* output_name = output_binding_names[i];
@@ -3454,12 +3466,15 @@
 }
 } else {
 auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
 if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
 auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
 if (output_tensor_ptr != nullptr) {
 cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
 }
-} else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+}
+#endif
+if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
 auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
 if (output_tensor_ptr != nullptr) {
 cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
@@ -3723,7 +3738,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
 }

 // Assign TRT output back to ORT output
-// (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+// (1) Bind TRT DDS output to ORT kernel context output.
 // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
 for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
 char const* output_name = output_binding_names[i];
@@ -3746,12 +3761,15 @@
 }
 } else {
 auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
 if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
 auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
 if (output_tensor_ptr != nullptr) {
 cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
 }
-} else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+}
+#endif
+if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
 auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
 if (output_tensor_ptr != nullptr) {
 cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
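On the output side the same gate removes the extra device buffer on TensorRT 10: the ORT int64 output tensor is bound to the engine directly, whereas the pre-10 path still lets the engine write int32 and then widens the result with cuda::Impl_Cast<int32_t, int64_t>. Below is a minimal CPU-side analogue of that widening step; it is illustrative only, and the function name is hypothetical, standing in for the CUDA kernel the EP actually launches.

// CPU-side analogue of the int32 -> int64 widening done on the GPU by
// cuda::Impl_Cast<int32_t, int64_t> on the pre-TRT-10 path (illustrative only).
#include <cstdint>
#include <cstdio>
#include <vector>

void WidenInt32ToInt64(const int32_t* src, int64_t* dst, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<int64_t>(src[i]);  // sign-extend each element
  }
}

int main() {
  std::vector<int32_t> engine_output = {-1, 0, 42};
  std::vector<int64_t> ort_output(engine_output.size());
  WidenInt32ToInt64(engine_output.data(), ort_output.data(), engine_output.size());
  for (int64_t v : ort_output) std::printf("%lld\n", static_cast<long long>(v));
}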