diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 0159c35d1941b..5de3e87cbb71c 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -483,6 +483,10 @@ if (NOT onnxruntime_MINIMAL_BUILD)
   list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
 endif()
 
+if (onnxruntime_TEST_TENSORRT_EP_PLUGIN)
+  list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_trt_ep_plugin.cc)
+endif()
+
 if(onnxruntime_RUN_ONNX_TESTS)
   list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_io_types.cc)
 endif()
diff --git a/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc b/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc
new file mode 100644
index 0000000000000..15361234678b8
--- /dev/null
+++ b/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc
@@ -0,0 +1,194 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/session/onnxruntime_c_api.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "gtest/gtest.h"
+
+#include <filesystem>
+#include <iostream>
+#include <memory>
+#include <thread>
+#include <vector>
+
+namespace onnxruntime {
+
+const ORTCHAR_T* ep_plugin_lib = "/home/lochi/repos/ort_for_docker_ep_plugin/samples/tensorRTEp/build/libTensorRTEp.so";  // hardcode path for now
+const ORTCHAR_T* ep_plugin_name = "tensorrtEp";
+const ORTCHAR_T* model_path = "testdata/trt_ep_test_model_static_input_shape.onnx";
+const ORTCHAR_T* model_path_2 = "testdata/trt_ep_test_model_dynamic_input_shape.onnx";
+
+// Note: despite the name, this helper prints the error message and aborts the process.
+inline void THROW_ON_ERROR(OrtStatus* status, const OrtApi* api) {
+  if (status != nullptr && api != nullptr) {
+    std::cout << "ErrorMessage:" << api->GetErrorMessage(status) << "\n";
+    abort();
+  }
+}
+
+void RegisterTrtEpPlugin(const OrtApi* api, OrtEnv* env, OrtSessionOptions* so) {
+  THROW_ON_ERROR(api->RegisterPluginExecutionProviderLibrary(ep_plugin_lib, env, ep_plugin_name), api);
+  std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};  // hardcode device id for now
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so, ep_plugin_name, env, keys.data(), values.data(), keys.size()), api);
+}
+
+bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") {
+  std::filesystem::path target_dir;
+  if (file_dir.empty()) {
+    target_dir = std::filesystem::current_path();
+  } else {
+    target_dir = std::filesystem::path(file_dir);
+  }
+
+  for (const auto& entry : std::filesystem::directory_iterator(target_dir)) {
+    if (entry.is_regular_file()) {
+      std::string filename = entry.path().filename().string();
+      if (filename.rfind(prefix, 0) == 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void ValidateOutputs(std::vector<Ort::Value>& ort_outputs,
+                     std::vector<int64_t>& expected_dims,
+                     std::vector<float>& expected_values) {
+  auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(type_info.GetShape(), expected_dims);
+  size_t total_len = type_info.GetElementCount();
+  ASSERT_EQ(expected_values.size(), total_len);
+
+  float* f = ort_outputs[0].GetTensorMutableData<float>();
+  for (size_t i = 0; i != total_len; ++i) {
+    ASSERT_EQ(expected_values[i], f[i]);
+  }
+}
+
+void RunWithOneSessionSingleThreadInference() {
+  // Use C API at first since EP plugin only supports C API for now
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+
+  RegisterTrtEpPlugin(api, env, so);
+
+  // Use C++ Wrapper
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+
+  OrtTensorRTProviderOptionsV2* trt_options;
+  ASSERT_TRUE(api->CreateTensorRTProviderOptions(&trt_options) == nullptr);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api->ReleaseTensorRTProviderOptions)>
+      rel_trt_options(trt_options, api->ReleaseTensorRTProviderOptions);
+  std::vector<const char*> keys{"trt_engine_cache_enable", "trt_engine_cache_prefix", "trt_dump_ep_context_model", "trt_ep_context_file_path"};
+  std::vector<const char*> values{"1", "TRTEP_Cache_Test", "1", "EP_Context_model.onnx"};
+  ASSERT_TRUE(api->UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr);
+
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+
+  // Run inference
+  // TRT engine will be created and cached
+  // TRT profile will be created and cached only for dynamic input shape
+  // Data in profile,
+  // X: 1, 3, 3, 2, 2, 2
+  // Y: 1, 3, 3, 2, 2, 2
+  // Z: 1, 3, 3, 2, 2, 2
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
+
+  // Verify on cache with customized prefix
+  ASSERT_TRUE(HasCacheFileWithPrefix("TRTEP_Cache_Test"));
+
+  // Verify EP context model with user provided name
+  ASSERT_TRUE(HasCacheFileWithPrefix("EP_Context_model.onnx"));
+}
+
+TEST(TensorrtExecutionProviderPluginTest, SmallModel) {
+  // Use C API at first since EP plugin only supports C API for now
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+
+  RegisterTrtEpPlugin(api, env, so);
+
+  // Use C++ Wrapper
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+
+  // Run inference
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
+
+  // Validate results
+  std::vector<int64_t> y_dims = {1, 3, 2};
+  std::vector<float> values_y = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
+  ValidateOutputs(ort_outputs, y_dims, values_y);
+}
+
+TEST(TensorrtExecutionProviderPluginTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
+  std::vector<std::thread> threads;
+  std::vector<int64_t> dims = {1, 3, 2};
+  int num_thread = 5;
+
+  for (int i = 0; i < num_thread; ++i)
+    threads.push_back(std::thread(RunWithOneSessionSingleThreadInference));
+
+  for (auto& th : threads)
+    th.join();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx b/onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx
new file mode 100644
index 0000000000000..4286222dd05bc
Binary files /dev/null and b/onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx differ
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 75fbf5d0851ae..2047324f4c15c 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -579,6 +579,7 @@ def convert_arg_line_to_args(self, arg_line):
     parser.add_argument(
         "--use_tvm_hash", action="store_true", help="Build ipp-crypto for hash generation. It is used by TVM EP only"
     )
+    parser.add_argument("--test_tensorrt_ep_plugin", action="store_true", help="Build with TensorRT EP plugin test app")
     parser.add_argument("--use_tensorrt", action="store_true", help="Build with TensorRT")
     parser.add_argument(
         "--use_tensorrt_builtin_parser", action="store_true", default=True, help="Use TensorRT builtin parser"
@@ -1027,6 +1028,7 @@ def generate_build_tree(
         "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"),
         "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"),
         "-Donnxruntime_USE_VITISAI=" + ("ON" if args.use_vitisai else "OFF"),
+        "-Donnxruntime_TEST_TENSORRT_EP_PLUGIN=" + ("ON" if args.test_tensorrt_ep_plugin else "OFF"),
         "-Donnxruntime_USE_TENSORRT=" + ("ON" if args.use_tensorrt else "OFF"),
         "-Donnxruntime_USE_TENSORRT_BUILTIN_PARSER="
         + ("ON" if args.use_tensorrt_builtin_parser and not args.use_tensorrt_oss_parser else "OFF"),
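
Usage sketch (not part of the diff): with the changes above, the new test could be built and run roughly as follows on Linux, assuming CUDA/TensorRT are installed, the hardcoded plugin-library path in test_trt_ep_plugin.cc has been adjusted for the local machine, and any additional flags required by the environment (for example --use_tensorrt) are supplied. The build.sh wrapper, --config/--parallel/--build_shared_lib options, the build output layout, and the onnxruntime_shared_lib_test binary name are existing ONNX Runtime conventions rather than something this diff introduces.

  ./build.sh --config RelWithDebInfo --parallel --build_shared_lib --test_tensorrt_ep_plugin
  ./build/Linux/RelWithDebInfo/onnxruntime_shared_lib_test --gtest_filter=TensorrtExecutionProviderPluginTest.*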