diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index a1fc67ff60b6f..7f5ab3a772305 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -365,6 +365,46 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
   return std::unique_lock<OrtMutex>(singleton);
 }
 
+Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
+                             std::vector<int32_t>& shape_values,
+                             nvinfer1::ICudaEngine* trt_engine,
+                             int binding_index,
+                             cudaStream_t stream) {
+  auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+  const auto tensor_shapes = tensor_info.GetShape();
+  const auto tensor_type = tensor_info.GetElementType();
+  nvinfer1::Dims dims = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
+  int nb_dims = dims.nbDims;
+  int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
+  shape_values.resize(shape_size, 1);
+
+  switch (tensor_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+      auto input = std::make_unique<int32_t[]>(shape_size);
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
+      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+      for (int j = 0; j < shape_size; ++j) {
+        shape_values[j] = input[j];
+      }
+      break;
+    }
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+      auto input = std::make_unique<int64_t[]>(shape_size);
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
+      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+      for (int j = 0; j < shape_size; ++j) {
+        shape_values[j] = static_cast<int32_t>(input[j]);
+      }
+      break;
+    }
+    default: {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
+    }
+  }
+  return Status::OK();
+}
+
 /*
  * Apply TensorRT optimization profile shapes from provider options.
  *
@@ -404,7 +444,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
       if (input->isShapeTensor()) {
-        auto shape_size = nb_dims;
+        int shape_size = nb_dims == 0 ? 1 : static_cast<int>(profile_min_shapes[input_name][i].size());
         std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
         LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size;
@@ -2758,7 +2798,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           if (trt_engine->isShapeBinding(binding_index)) {
-            trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
+            // Get shape of the shape tensor
+            std::vector<int32_t> shape_values;
+            if (!tensor_shape_values[input_name].empty()) {
+              shape_values = tensor_shape_values[input_name];
+            } else {
+              auto status = GetShapeOfShapeTensor(input_tensor, shape_values, trt_engine, binding_index, stream);
+              if (status != Status::OK()) {
+                return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+              }
+            }
+            trt_context->setInputShapeBinding(binding_index, &shape_values[0]);
           } else {
             for (int j = 0, end = nb_dims; j < end; ++j) {
               dimensions.d[j] = static_cast<int32_t>(tensor_shapes[j]);
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 33d50f90333cf..7dee0bc41a6f3 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -2832,6 +2832,58 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) {
 #endif
 
 #ifdef USE_TENSORRT
+TEST(TensorrtExecutionProviderTest, ShapeTensorTest) {
+  const auto& api = Ort::GetApi();
+
+  // Test input tensor which is shape tensor with explicit trt profile shapes
+  Ort::SessionOptions session_options;
+  OrtTensorRTProviderOptionsV2* trt_options;
+  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
+      rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);
+
+  const char* trt_profile_min_shapes = "data:2x2,shape:4x1";
+  const char* trt_profile_max_shapes = "data:2x2,shape:4x1";
+  const char* trt_profile_opt_shapes = "data:2x2,shape:4x1";
+  std::vector<const char*> keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"};
+  std::vector<const char*> values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes};
+  ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr);
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
+                  static_cast<OrtSessionOptions*>(session_options),
+                  rel_trt_options.get()) == nullptr);
+
+  auto model_path = ORT_TSTR("testdata/trt_reshape.onnx");
+
+  std::vector<float> input_value_0{1.1f, 1.2f, 1.3f, 1.4f};
+  std::vector<int64_t> input_shape_0{2, 2};
+  std::vector<int64_t> input_value_1{4, 1};
+  std::vector<int64_t> input_shape_1{2};
+
+  std::vector<const char*> input_names{"data", "shape"};
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, input_value_0.data(), input_value_0.size(), input_shape_0.data(), input_shape_0.size()));
+  ort_inputs.emplace_back(Ort::Value::CreateTensor<int64_t>(info, input_value_1.data(), input_value_1.size(), input_shape_1.data(), input_shape_1.size()));
+
+  const char* output_names[] = {"reshaped"};
+
+  Ort::Session session(*ort_env, model_path, session_options);
+  session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+  // Test input tensor which is shape tensor with implicit trt profile shapes
+  Ort::SessionOptions session_options_2;
+  OrtTensorRTProviderOptionsV2* trt_options_2;
+  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
+      rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions);
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
+                  static_cast<OrtSessionOptions*>(session_options_2),
+                  rel_trt_options_2.get()) == nullptr);
+  Ort::Session session_2(*ort_env, model_path, session_options_2);
+  session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+}
+
 TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) {
   const auto& api = Ort::GetApi();
   Ort::SessionOptions session_options;
diff --git a/onnxruntime/test/testdata/trt_reshape.onnx b/onnxruntime/test/testdata/trt_reshape.onnx
new file mode 100644
index 0000000000000..7d195af2ae204
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_reshape.onnx
@@ -0,0 +1,16 @@
+[binary ONNX protobuf payload: a single Reshape node in graph "trt_engine_wrapper" with inputs "data" (float, Nx2) and "shape" (int64, [2]) and output "reshaped"; generated by trt_reshape_test.py below]
diff --git a/onnxruntime/test/testdata/trt_reshape_test.py b/onnxruntime/test/testdata/trt_reshape_test.py
new file mode 100644
index 0000000000000..42777bd3d50c7
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_reshape_test.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import TensorProto, helper
+
+
+def generate_model(model_name):
+    nodes = [
+        helper.make_node(
+            "Reshape",
+            ["data", "shape"],
+            ["reshaped"],
+            "Reshape",
+        ),
+    ]
+
+    graph = helper.make_graph(
+        nodes,
+        "trt_engine_wrapper",
+        [  # input
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, ["N", 2]),
+            helper.make_tensor_value_info(
+                "shape",
+                TensorProto.INT64,
+                [
+                    2,
+                ],
+            ),
+        ],
+        [  # output
+            helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, [4, 1]),
+        ],
+    )
+
+    model = helper.make_model(graph)
+    onnx.save(model, model_name)
+
+
+if __name__ == "__main__":
+    generate_model("trt_reshape.onnx")
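
For context, below is a minimal sketch (not part of the patch) of exercising the same shape-tensor path from Python, assuming an onnxruntime build with the TensorRT EP and a trt_reshape.onnx produced by the generator script above. The trt_profile_* values mirror the provider options set through UpdateTensorRTProviderOptions in the C++ test; passing them via the Python providers list is an assumption about the provider-options pass-through.

# Minimal usage sketch (assumptions: onnxruntime built with the TensorRT EP is
# installed, and trt_reshape.onnx was generated by trt_reshape_test.py above).
import numpy as np
import onnxruntime as ort

# Explicit optimization-profile shapes for the "data" input and the "shape" shape
# tensor, mirroring TensorrtExecutionProviderTest.ShapeTensorTest.
trt_options = {
    "trt_profile_min_shapes": "data:2x2,shape:4x1",
    "trt_profile_max_shapes": "data:2x2,shape:4x1",
    "trt_profile_opt_shapes": "data:2x2,shape:4x1",
}

sess = ort.InferenceSession(
    "trt_reshape.onnx",
    providers=[("TensorrtExecutionProvider", trt_options), "CPUExecutionProvider"],
)

data = np.array([[1.1, 1.2], [1.3, 1.4]], dtype=np.float32)
new_shape = np.array([4, 1], dtype=np.int64)  # consumed by TensorRT as a shape tensor
(reshaped,) = sess.run(["reshaped"], {"data": data, "shape": new_shape})
print(reshaped.shape)  # expected (4, 1)

Dropping the three trt_profile_* entries exercises the implicit-profile path, where the shape-tensor values are read back at run time via GetShapeOfShapeTensor, matching the second half of the C++ test.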