From 84bdf04b255661e8446ded49edda4dd8d730101f Mon Sep 17 00:00:00 2001
From: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Date: Fri, 3 Nov 2023 23:07:50 +0000
Subject: [PATCH] [TensorRT EP] Fix bug for shape tensor input (#18253)

When a model has a "shape tensor" as one of its inputs and the user provides
explicit profile shapes for it, TRT EP doesn't set the "shape tensor" input
correctly. There is also a bug in applying the explicit profile shapes to the
shape tensor input.

Note: A model with a shape tensor input seems to be a rare case; in most
cases, the inputs are all execution tensors.
---
 .../tensorrt/tensorrt_execution_provider.cc   | 54 ++++++++++++++++++-
 onnxruntime/test/shared_lib/test_inference.cc | 52 ++++++++++++++++++
 onnxruntime/test/testdata/trt_reshape.onnx    | 16 ++++++
 onnxruntime/test/testdata/trt_reshape_test.py | 42 +++++++++++++++
 4 files changed, 162 insertions(+), 2 deletions(-)
 create mode 100644 onnxruntime/test/testdata/trt_reshape.onnx
 create mode 100644 onnxruntime/test/testdata/trt_reshape_test.py

diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index a1fc67ff60b6f..7f5ab3a772305 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -365,6 +365,46 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
   return std::unique_lock<OrtMutex>(singleton);
 }
 
+Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
+                             std::vector<int32_t>& shape_values,
+                             nvinfer1::ICudaEngine* trt_engine,
+                             int binding_index,
+                             cudaStream_t stream) {
+  auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+  const auto tensor_shapes = tensor_info.GetShape();
+  const auto tensor_type = tensor_info.GetElementType();
+  nvinfer1::Dims dims = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
+  int nb_dims = dims.nbDims;
+  int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
+  shape_values.resize(shape_size, 1);
+
+  switch (tensor_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+      auto input = std::make_unique<int32_t[]>(shape_size);
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
+      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+      for (int j = 0; j < shape_size; ++j) {
+        shape_values[j] = input[j];
+      }
+      break;
+    }
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+      auto input = std::make_unique<int64_t[]>(shape_size);
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
+      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+      for (int j = 0; j < shape_size; ++j) {
+        shape_values[j] = static_cast<int32_t>(input[j]);
+      }
+      break;
+    }
+    default: {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
+    }
+  }
+  return Status::OK();
+}
+
 /*
  * Apply TensorRT optimization profile shapes from provider options.
  *
@@ -404,7 +444,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vector<nvinfer1::IOptimizationP
 
     // Shape tensor
     if (input->isShapeTensor()) {
-      auto shape_size = nb_dims;
+      int shape_size = nb_dims == 0 ? 1 : static_cast<int>(profile_min_shapes[input_name][i].size());
       std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
       LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size;
 
@@ -2758,7 +2798,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAn
         nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast<int>(binding_index));
         int nb_dims = dimensions.nbDims;
         if (trt_engine->isShapeBinding(binding_index)) {
-          trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
+          // Get shape of the shape tensor
+          std::vector<int32_t> shape_values;
+          if (!tensor_shape_values[input_name].empty()) {
+            shape_values = tensor_shape_values[input_name];
+          } else {
+            auto status = GetShapeOfShapeTensor(input_tensor, shape_values, trt_engine, binding_index, stream);
+            if (status != Status::OK()) {
+              return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+            }
+          }
+          trt_context->setInputShapeBinding(binding_index, &shape_values[0]);
         } else {
           for (int j = 0, end = nb_dims; j < end; ++j) {
             dimensions.d[j] = static_cast<int32_t>(tensor_shapes[j]);
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 33d50f90333cf..7dee0bc41a6f3 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -2832,6 +2832,58 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) {
 #endif
 
 #ifdef USE_TENSORRT
+TEST(TensorrtExecutionProviderTest, ShapeTensorTest) {
+  const auto& api = Ort::GetApi();
+
+  // Test input tensor which is shape tensor with explicit trt profile shapes
+  Ort::SessionOptions session_options;
+  OrtTensorRTProviderOptionsV2* trt_options;
+  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
+      rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);
+
+  const char* trt_profile_min_shapes = "data:2x2,shape:4x1";
+  const char* trt_profile_max_shapes = "data:2x2,shape:4x1";
+  const char* trt_profile_opt_shapes = "data:2x2,shape:4x1";
+  std::vector<const char*> keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"};
+  std::vector<const char*> values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes};
+  ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr);
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
+                  static_cast<OrtSessionOptions*>(session_options),
+                  rel_trt_options.get()) == nullptr);
+
+  auto model_path = ORT_TSTR("testdata/trt_reshape.onnx");
+
+  std::vector<float> input_value_0{1.1f, 1.2f, 1.3f, 1.4f};
+  std::vector<int64_t> input_shape_0{2, 2};
+  std::vector<int64_t> input_value_1{4, 1};
+  std::vector<int64_t> input_shape_1{2};
+
+  std::vector<const char*> input_names{"data", "shape"};
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, input_value_0.data(), input_value_0.size(), input_shape_0.data(), input_shape_0.size()));
+  ort_inputs.emplace_back(Ort::Value::CreateTensor<int64_t>(info, input_value_1.data(), input_value_1.size(), input_shape_1.data(), input_shape_1.size()));
+
+  const char* output_names[] = {"reshaped"};
+
+  Ort::Session session(*ort_env, model_path, session_options);
+  session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+  // Test input tensor which is shape tensor with implicit trt profile shapes
+  Ort::SessionOptions session_options_2;
+  OrtTensorRTProviderOptionsV2* trt_options_2;
+  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
+      rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions);
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
+                  static_cast<OrtSessionOptions*>(session_options_2),
+                  rel_trt_options_2.get()) == nullptr);
+  Ort::Session session_2(*ort_env, model_path, session_options_2);
+  session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+}
+
 TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) {
   const auto& api = Ort::GetApi();
   Ort::SessionOptions session_options;
diff --git a/onnxruntime/test/testdata/trt_reshape.onnx b/onnxruntime/test/testdata/trt_reshape.onnx
new file mode 100644
index 0000000000000..7d195af2ae204
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_reshape.onnx
@@ -0,0 +1,16 @@
+ :‰
+)
+data
+shapereshapedReshape"Reshapetrt_engine_wrapperZ
+data
+ 
+N
+Z
+shape
+
+
+b
+reshaped
+ 
+
+B
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/trt_reshape_test.py b/onnxruntime/test/testdata/trt_reshape_test.py
new file mode 100644
index 0000000000000..42777bd3d50c7
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_reshape_test.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import TensorProto, helper
+
+
+def generate_model(model_name):
+    nodes = [
+        helper.make_node(
+            "Reshape",
+            ["data", "shape"],
+            ["reshaped"],
+            "Reshape",
+        ),
+    ]
+
+    graph = helper.make_graph(
+        nodes,
+        "trt_engine_wrapper",
+        [  # input
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, ["N", 2]),
+            helper.make_tensor_value_info(
+                "shape",
+                TensorProto.INT64,
+                [
+                    2,
+                ],
+            ),
+        ],
+        [  # output
+            helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, [4, 1]),
+        ],
+    )
+
+    model = helper.make_model(graph)
+    onnx.save(model, model_name)
+
+
+if __name__ == "__main__":
+    generate_model("trt_reshape.onnx")
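
--
For reference beyond the patch itself, the scenario this fix targets can also be
driven from the Python API. The sketch below is not part of the change: it
assumes an onnxruntime-gpu build with the TensorRT EP enabled, and it reuses the
trt_reshape.onnx model and the same trt_profile_* option values exercised by the
C++ test above.

    # Minimal repro sketch (assumes a TensorRT-enabled onnxruntime build).
    # "shape" is a shape tensor input; combined with explicit profile shapes,
    # this is the code path the patch fixes.
    import numpy as np
    import onnxruntime as ort

    trt_options = {
        "trt_profile_min_shapes": "data:2x2,shape:4x1",
        "trt_profile_max_shapes": "data:2x2,shape:4x1",
        "trt_profile_opt_shapes": "data:2x2,shape:4x1",
    }
    sess = ort.InferenceSession(
        "testdata/trt_reshape.onnx",
        providers=[("TensorrtExecutionProvider", trt_options)],
    )
    data = np.array([[1.1, 1.2], [1.3, 1.4]], dtype=np.float32)
    shape = np.array([4, 1], dtype=np.int64)  # consumed by Reshape as a shape tensor
    (reshaped,) = sess.run(["reshaped"], {"data": data, "shape": shape})
    assert reshaped.shape == (4, 1)

Dropping the trt_profile_* entries exercises the implicit-profile path instead,
mirroring the second half of ShapeTensorTest.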