diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 632d521dc21a8..6a410c6b10cdf 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -395,43 +395,13 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
 /*
  * Get the shape of "shape tensor" input
  */
+template <typename T>
 Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
-                             std::vector<int32_t>& shape_values,
-                             nvinfer1::ICudaEngine* trt_engine,
-                             const char* input_name,
+                             void* shape_values,
+                             int shape_size,
                              cudaStream_t stream) {
-  auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
-  const auto tensor_shapes = tensor_info.GetShape();
-  const auto tensor_type = tensor_info.GetElementType();
-  nvinfer1::Dims dims = trt_engine->getTensorShape(input_name);
-  int nb_dims = dims.nbDims;
-  int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
-  shape_values.resize(shape_size, 1);
-
-  switch (tensor_type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-      auto input = std::make_unique<int32_t[]>(shape_size);
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
-      for (int j = 0; j < shape_size; ++j) {
-        shape_values[j] = input[j];
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-      auto input = std::make_unique<int64_t[]>(shape_size);
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
-      for (int j = 0; j < shape_size; ++j) {
-        shape_values[j] = static_cast<int32_t>(input[j]);
-      }
-      break;
-    }
-    default: {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                             "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
-    }
-  }
+  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values, input_tensor.GetTensorData<T>(), shape_size * sizeof(T), cudaMemcpyDeviceToHost, stream));
+  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
 
   return Status::OK();
 }
@@ -558,7 +528,8 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
                                               nvinfer1::ITensor* input,
                                               ShapeRangesMap& shape_ranges,
                                               const std::unordered_map<std::string, size_t>& input_indexes,
-                                              std::unordered_map<std::string, std::vector<int32_t>>& tensor_shape_values,
+                                              std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
+                                              std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,  // same as above but for int64 shape tensor input
                                               cudaStream_t stream,
                                               bool* engine_update) {
   for (size_t i = 0; i < trt_profiles.size(); i++) {
@@ -611,26 +582,33 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
     if (input->isShapeTensor()) {
       // Get shape values for shape tensor input
       const auto tensor_type = tensor_info.GetElementType();
-      int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
-      tensor_shape_values[input_name].resize(shape_size);
+      int shape_size = dims.nbDims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
+      std::vector<int32_t> values(shape_size);  // For setting TRT optimization profile. Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10.
+
       switch (tensor_type) {
         case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-          auto input_shape = std::make_unique<int32_t[]>(shape_size);
-          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData<int32_t>(),
-                                               shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
-          CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+          auto input = std::make_unique<int32_t[]>(shape_size);
+          auto status = GetShapeOfShapeTensor<int32_t>(input_tensor, input.get(), shape_size, stream);
+          if (status != Status::OK()) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+          }
+          shape_tensor_values[input_name].resize(shape_size);
           for (int j = 0; j < shape_size; ++j) {
-            tensor_shape_values[input_name][j] = input_shape[j];
+            shape_tensor_values[input_name][j] = input[j];
+            values[j] = input[j];
           }
           break;
         }
         case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-          auto input_shape = std::make_unique<int64_t[]>(shape_size);
-          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData<int64_t>(),
-                                               shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
-          CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+          auto input = std::make_unique<int64_t[]>(shape_size);
+          auto status = GetShapeOfShapeTensor<int64_t>(input_tensor, input.get(), shape_size, stream);
+          if (status != Status::OK()) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+          }
+          shape_tensor_values_int64[input_name].resize(shape_size);
           for (int j = 0; j < shape_size; ++j) {
-            tensor_shape_values[input_name][j] = static_cast<int32_t>(input_shape[j]);
+            shape_tensor_values_int64[input_name][j] = input[j];
+            values[j] = static_cast<int32_t>(input[j]);
           }
           break;
         }
@@ -651,7 +629,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
           shapes_min[j] = static_cast<int32_t>(shape_range[0]);
           shapes_max[j] = static_cast<int32_t>(shape_range[1]);
           shapes_opt[j] = static_cast<int32_t>(shape_range[2]);
-          const auto& tensor_shape_value = tensor_shape_values[input_name][j];
+          const auto& tensor_shape_value = values[j];
           // Update shape range lower bound
           if (tensor_shape_value < shape_range[0]) {
             shape_range[0] = tensor_shape_value;
@@ -671,7 +649,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
-          const auto& tensor_shape_value = tensor_shape_values[input_name][j];
+          const auto& tensor_shape_value = values[j];
           std::vector<std::vector<int64_t>> profile_vector;
           std::vector<int64_t> shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value};
           profile_vector.push_back(shape_vector);  // only one profile needed
@@ -804,7 +782,8 @@ Status BindContextInput(Ort::KernelContext& ctx,
                         nvinfer1::IExecutionContext* trt_context,
                         const char* input_name,
                         size_t input_index,
-                        std::vector<int32_t>& shape_values,  // only for "shape tensor"
+                        std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,        // only use for int32 shape tensor
+                        std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,  // only use for int64 shape tensor
                         std::vector<IAllocatorUniquePtr<void>>& scratch_buffers,
                         OrtAllocator* alloc,
                         cudaStream_t stream) {
@@ -825,19 +804,56 @@ Status BindContextInput(Ort::KernelContext& ctx,
   const auto elem_cnt = tensor_info.GetElementCount();
 
   if (trt_engine->isShapeInferenceIO(input_name)) {
-    // Get the shape value of "shape tensor"
-    if (shape_values.empty()) {
-      auto status = GetShapeOfShapeTensor(input_tensor, shape_values, trt_engine, input_name, stream);
-      if (status != Status::OK()) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+    // Bind "shape tensor" input buffer
+    int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
+    switch (tensor_type) {
+      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+        // get shape tensor value if not present
+        if (shape_tensor_values.find(input_name) == shape_tensor_values.end()) {
+          auto input = std::make_unique<int32_t[]>(shape_size);
+          auto status = GetShapeOfShapeTensor<int32_t>(input_tensor, input.get(), shape_size, stream);
+          if (status != Status::OK()) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+          }
+          shape_tensor_values[input_name].resize(shape_size);
+          for (size_t i = 0; i < shape_size; ++i) {
+            shape_tensor_values[input_name][i] = input[i];
+          }
+        }
+
+        if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) {
+          std::string error_input_name = input_name;
+          ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                             "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"));
+        }
+        break;
       }
-    }
+      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+        // get shape tensor value if not present
+        if (shape_tensor_values_int64.find(input_name) == shape_tensor_values_int64.end()) {
+          auto input = std::make_unique<int64_t[]>(shape_size);
+          auto status = GetShapeOfShapeTensor<int64_t>(input_tensor, input.get(), shape_size, stream);
+          if (status != Status::OK()) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+          }
+          shape_tensor_values_int64[input_name].resize(shape_size);
+          for (size_t i = 0; i < shape_size; ++i) {
+            shape_tensor_values_int64[input_name][i] = input[i];
+          }
+        }
 
-    // Bind "shape tensor" input buffer
-    if (!trt_context->setTensorAddress(input_name, &shape_values[0])) {
-      std::string error_input_name = input_name;
-      ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                                         "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"));
+        if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) {
+          std::string error_input_name = input_name;
+          ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                             "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"));
+        }
+        break;
+      }
+      default: {
+        std::string error_input_name = input_name;
+        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                               "The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name);
+      }
     }
   } else {
     // Set shape for input tensor which is execution tensor
@@ -865,8 +881,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-    // Cast int64 input to int32 input because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+    // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64
     CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
     // Cast double input to float because TensorRT doesn't support double
     CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
     default: {
@@ -953,8 +973,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
-    // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64
+#if NV_TENSORRT_MAJOR >= 10
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+#else
+    // Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64
     CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+#endif
     // Allocate float CUDA memory for double output type because TensorRT doesn't support double
     CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
     default: {
@@ -3032,7 +3056,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
     const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
     const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
     auto fused_node_name = trt_state->fused_node_name;
+    // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles.
+    // The info is used for both shape tensor and execution tensor:
+    // tensor name->(dimension->[min, max, opt])
     auto& shape_ranges = trt_state->input_shape_ranges;
+    std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
+    std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
     auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
     auto trt_builder = trt_state->builder;
     auto trt_engine = trt_state->engine->get();
@@ -3044,7 +3073,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
     bool engine_update = false;
     bool context_update = false;
     std::unordered_set<std::string> input_names;
-    std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
 
     OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
     OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
@@ -3142,7 +3170,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
       // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
      // TRT EP will help determine the min/max/opt profile values based on current input tensor value.
       if (shape_ranges.find(input_name) != shape_ranges.end()) {
-        auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update);
+        auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update);
         if (status != Status::OK()) {
           return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
         }
@@ -3355,13 +3383,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
       auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
       const auto tensor_shapes = tensor_info.GetShape();
 
-      // Only use for "shape tensor" input
-      std::vector<int32_t> shape_values;
-      if (tensor_shape_values.find(input_name) != tensor_shape_values.end()) {
-        shape_values = tensor_shape_values[input_name];
-      }
-
-      auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_values, scratch_buffers, alloc, stream);
+      auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream);
       if (status != Status::OK()) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
       }
@@ -3465,12 +3487,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
         }
       } else {
         auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
         if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
           auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
           if (output_tensor_ptr != nullptr) {
             cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
           }
-        } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+        }
+#endif
+        if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
           auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
           if (output_tensor_ptr != nullptr) {
             cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
@@ -3611,8 +3636,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
     auto trt_engine = trt_state->engine->get();
     auto trt_context = trt_state->context->get();
     auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
-    // int num_inputs = static_cast<int>(input_indexes.size());
     int num_outputs = static_cast<int>(output_indexes.size());
+    std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
+    std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
 
     OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
     OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
@@ -3651,10 +3677,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
         input_index = iter->second;
       }
 
-      // Only use for "shape tensor" input
-      std::vector<int32_t> shape_values;
-
-      Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_values, scratch_buffers, alloc, stream);
+      Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream);
      if (status != Status::OK()) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
       }
@@ -3758,12 +3781,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
         }
       } else {
         auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
         if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
           auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
           if (output_tensor_ptr != nullptr) {
             cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
           }
-        } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+        }
+#endif
+        if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
           auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
           if (output_tensor_ptr != nullptr) {
             cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index f73031eaefceb..615e61e69a83f 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -134,6 +134,10 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
   std::vector<int64_t> output_shapes;
 };
 
+/*
+ * This map saves the dimension range of the shape of the shape tensor or execution tensor:
+ * tensor name -> ( dimension -> [min, max, opt] )
+ */
 using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;
 
 // Information to construct kernel function state.
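
For reference, below is a minimal, self-contained sketch (not part of this patch) of the caching pattern the new BindContextInput code uses for shape-tensor inputs, written against only the CUDA runtime API. The names ShapeTensorCache and ReadShapeTensor are hypothetical stand-ins invented for this illustration; in the patch the corresponding pieces are the shape_tensor_values / shape_tensor_values_int64 maps and the templated GetShapeOfShapeTensor<T>() helper.

// Hypothetical sketch (assumed names, not ONNX Runtime code): a per-run cache of
// shape-tensor values, mirroring shape_tensor_values / shape_tensor_values_int64.
#include <cuda_runtime.h>

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct ShapeTensorCache {
  std::unordered_map<std::string, std::vector<int32_t>> int32_values;
  std::unordered_map<std::string, std::vector<int64_t>> int64_values;
};

// Analogous to the templated GetShapeOfShapeTensor<T>() added by the patch: copy
// shape_size elements of type T from device memory to host and wait for completion.
template <typename T>
cudaError_t ReadShapeTensor(const void* device_data, void* host_values, int shape_size, cudaStream_t stream) {
  cudaError_t err = cudaMemcpyAsync(host_values, device_data, shape_size * sizeof(T), cudaMemcpyDeviceToHost, stream);
  if (err != cudaSuccess) return err;
  return cudaStreamSynchronize(stream);
}

int main() {
  // Stand-in for an INT64 "shape tensor" input that lives on the GPU and holds {2, 3, 4}.
  const std::vector<int64_t> host_src{2, 3, 4};
  void* device_data = nullptr;
  cudaMalloc(&device_data, host_src.size() * sizeof(int64_t));
  cudaMemcpy(device_data, host_src.data(), host_src.size() * sizeof(int64_t), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  ShapeTensorCache cache;
  const std::string input_name = "shape_input";

  // Read the value only if it is not cached yet, as the patched BindContextInput does.
  if (cache.int64_values.find(input_name) == cache.int64_values.end()) {
    std::vector<int64_t> values(host_src.size());
    if (ReadShapeTensor<int64_t>(device_data, values.data(), static_cast<int>(values.size()), stream) == cudaSuccess) {
      cache.int64_values[input_name] = std::move(values);
    }
  }

  // &cache.int64_values[input_name][0] now stays valid for the rest of the run, which is
  // what setTensorAddress() needs for a shape-tensor input address.
  for (int64_t v : cache.int64_values[input_name]) {
    std::cout << v << " ";
  }
  std::cout << std::endl;

  cudaStreamDestroy(stream);
  cudaFree(device_data);
  return 0;
}

The per-run maps keep the host buffers handed to nvinfer1::IExecutionContext::setTensorAddress() alive until the inference run finishes; with TensorRT 10 the INT64 shape values can be bound directly, while older versions keep the INT32 cast path, which is why the patch carries one map per element type.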