From 0e6a80cb53a543804ed172466249cd20b69057e4 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 17 Jul 2024 20:39:43 +0000 Subject: [PATCH 01/81] opaque pointer for graph --- include/onnxruntime/core/session/onnxruntime_c_api.h | 3 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5c61963a2f39c..16fef811d821c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -304,6 +304,7 @@ ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); +ORT_RUNTIME_CLASS(Graph); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -4825,6 +4826,8 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOpt */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); +ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope); + #ifdef __cplusplus } #endif diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5cf5ff9b3bd0a..52491e9367afe 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -22,6 +22,7 @@ #include "core/common/safeint.h" #include "core/graph/constants.h" #include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ort_value.h" @@ -2802,3 +2803,8 @@ DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata) + +ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope) { + ::onnxruntime::GraphViewer graph_viewer(*(reinterpret_cast(graph))); + return graph_viewer.IsConstantInitializer(std::string(name), check_outer_scope); +} From c30a639f37ec14ef8565834e823f8e9885e1a14a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 23 Jul 2024 00:21:51 +0000 Subject: [PATCH 02/81] ORT C API RegisterOrtExecutionProviderLibrary work --- .../onnxruntime/core/session/environment.h | 3 ++ .../core/session/onnxruntime_c_api.h | 13 +++++++ onnxruntime/core/session/environment.cc | 4 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 16 +++++++++ onnxruntime/core/session/ort_apis.h | 2 ++ onnxruntime/core/session/ort_env.cc | 4 +++ onnxruntime/core/session/ort_env.h | 2 ++ samples/c_test/CMakeLists.txt | 10 ++++++ samples/c_test/test.cpp | 14 ++++++++ samples/outTreeEp/CMakeLists.txt | 9 +++++ samples/outTreeEp/out_tree_ep.cc | 22 ++++++++++++ samples/outTreeEp/out_tree_ep.h | 35 +++++++++++++++++++ 12 files changed, 134 insertions(+) create mode 100644 samples/c_test/CMakeLists.txt create mode 100644 samples/c_test/test.cpp create mode 100644 samples/outTreeEp/CMakeLists.txt create mode 100644 samples/outTreeEp/out_tree_ep.cc create mode 100644 samples/outTreeEp/out_tree_ep.h diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 20c7cace152d5..e7e84d86dbf46 100644 --- a/include/onnxruntime/core/session/environment.h 
+++ b/include/onnxruntime/core/session/environment.h @@ -88,6 +88,8 @@ class Environment { */ Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); + void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); Status Initialize(std::unique_ptr logging_manager, @@ -99,5 +101,6 @@ class Environment { std::unique_ptr inter_op_thread_pool_; bool create_global_thread_pools_{false}; std::vector shared_allocators_; + std::map custom_ep_factories_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 16fef811d821c..e6bbb0f2429a9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -304,6 +304,8 @@ ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); +ORT_RUNTIME_CLASS(ExecutionProvider); +ORT_RUNTIME_CLASS(ExecutionProviderFactory); ORT_RUNTIME_CLASS(Graph); #ifdef _WIN32 @@ -682,6 +684,15 @@ struct OrtApiBase { const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; }; +typedef struct OrtExecutionProvider { + //void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraph* graph, _Out_ int* cnt, _Outptr_ OrtComputeCapability** compute_capability); + //void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraph* graph, const OrtNode* node, int size, _Out_ int* cnt, _Outptr_ OrtNodeComputeInfo** node_compute_info); +} OrtExecutionProvider; + +typedef struct OrtExecutionProviderFactory { + void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_); +} OrtExecutionProviderFactory; + typedef struct OrtApiBase OrtApiBase; /** \brief The Onnxruntime library's entry point to access the C API @@ -4666,6 +4677,8 @@ struct OrtApi { _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + + ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); }; /* diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 318c76645bdf5..24cf848541939 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -348,4 +348,8 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ return Status{ONNXRUNTIME, common::INVALID_ARGUMENT, provider_type + " is not implemented in CreateAndRegisterAllocatorV2()"}; } +void Environment::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { + custom_ep_factories_.insert({std::string(ep_name), ep_factory}); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 52491e9367afe..cc620369e30b7 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2354,6 +2354,20 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #endif } +ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ 
const char* ep_name) { + API_IMPL_BEGIN + void* handle = nullptr; + ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(ToPathString(lib_path), false, &handle)); + if (handle) { + OrtExecutionProviderFactory* (*symbol)(); + ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol)); + env->InsertCustomEp(ep_name, symbol()); + return nullptr; + } + return CreateStatus(ORT_RUNTIME_EXCEPTION, "cannot load the shared library for out-tree EP"); + API_IMPL_END +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2731,6 +2745,8 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::RegisterOrtExecutionProviderLibrary, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fcae173e6c162..0fa5fe2df09ab 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -523,4 +523,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); + +ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 3c178fd1e91d3..e3212b17dac20 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -110,3 +110,7 @@ onnxruntime::common::Status OrtEnv::UnregisterAllocator(const OrtMemoryInfo& mem onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { return value_->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg); } + +void OrtEnv::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { + value_->InsertCustomEp(ep_name, ep_factory); +} diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 444134d0612e9..42d33de3c8d39 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -65,6 +65,8 @@ struct OrtEnv { ~OrtEnv(); onnxruntime::common::Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); + void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); + private: static std::unique_ptr p_instance_; static onnxruntime::OrtMutex m_; diff --git a/samples/c_test/CMakeLists.txt b/samples/c_test/CMakeLists.txt new file mode 100644 index 0000000000000..068b7cf18be91 --- /dev/null +++ b/samples/c_test/CMakeLists.txt @@ -0,0 +1,10 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ +# cmake --build ./ +cmake_minimum_required(VERSION 3.26) +project(TestOutTreeEp) +add_executable(TestOutTreeEp test.cpp) + 
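+# Note: the link path below (and the model/library paths in test.cpp) are hardcoded
+# to the author's machine. A more portable sketch, assuming a hypothetical ORT_ROOT
+# cache variable that is not part of this patch:
+#   set(ORT_ROOT "/path/to/onnxruntime" CACHE PATH "onnxruntime repo root")
+#   target_link_libraries(TestOutTreeEp PUBLIC "${ORT_ROOT}/build/Linux/Debug/libonnxruntime.so")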
+target_include_directories(TestOutTreeEp PUBLIC "../../include/onnxruntime") +target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so") diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp new file mode 100644 index 0000000000000..bdf349b5147a2 --- /dev/null +++ b/samples/c_test/test.cpp @@ -0,0 +1,14 @@ +#include "core/session/onnxruntime_c_api.h" + +inline void THROW_ON_ERROR(OrtStatus* status) { + if (status != nullptr) abort(); +} + +int main() { + const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + OrtEnv* p_env = nullptr; + OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; + THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); + THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", p_env, "outTreeEp")); + return 0; +} diff --git a/samples/outTreeEp/CMakeLists.txt b/samples/outTreeEp/CMakeLists.txt new file mode 100644 index 0000000000000..72396ac761246 --- /dev/null +++ b/samples/outTreeEp/CMakeLists.txt @@ -0,0 +1,9 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ +# cmake --build ./ +cmake_minimum_required(VERSION 3.26) +project(outTreeEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +add_library(outTreeEp SHARED out_tree_ep.cc) +target_include_directories(outTreeEp PUBLIC "../../include/onnxruntime") diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc new file mode 100644 index 0000000000000..eede2ac209145 --- /dev/null +++ b/samples/outTreeEp/out_tree_ep.cc @@ -0,0 +1,22 @@ +#include "out_tree_ep.h" +#include +namespace onnxruntime { + +OutTreeEpFactory::OutTreeEpFactory() { + OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_) -> void* { + std::unique_ptr ret = std::make_unique(); + return ret.release(); }; +} + +} + +#ifdef __cplusplus +extern "C" { +#endif +OrtExecutionProviderFactory* RegisterCustomEp() { + std::unique_ptr ret = std::make_unique(); + return ret.release(); +} +#ifdef __cplusplus +} +#endif diff --git a/samples/outTreeEp/out_tree_ep.h b/samples/outTreeEp/out_tree_ep.h new file mode 100644 index 0000000000000..df47f40eaf1f1 --- /dev/null +++ b/samples/outTreeEp/out_tree_ep.h @@ -0,0 +1,35 @@ +#pragma once +#include "core/session/onnxruntime_c_api.h" +#include + +#ifdef _WIN32 +#define EXPORT_API __declspec(dllexport) +#else +#define EXPORT_API +#endif + +namespace onnxruntime { + +struct OutTreeEpInfo { + int int_property; + std::string str_property; +}; + +struct OutTreeEp : public OrtExecutionProvider { +OutTreeEp() {} +}; + +struct OutTreeEpFactory : public OrtExecutionProviderFactory { + OutTreeEpFactory(); +}; +} + +#ifdef __cplusplus +extern "C" { +#endif + +EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp(); + +#ifdef __cplusplus +} +#endif From 7bfe57e166503a19be3009010d45f499bea5fc1c Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 23 Jul 2024 21:55:30 +0000 Subject: [PATCH 03/81] ORT C-API SessionOptionsAppendOrtExecutionProvider work --- include/onnxruntime/core/session/environment.h | 4 +++- .../core/session/onnxruntime_c_api.h | 6 +++++- onnxruntime/core/framework/provider_adapter.h | 14 ++++++++++++++ onnxruntime/core/framework/session_options.h | 3 +++ onnxruntime/core/session/environment.cc | 3 ++- onnxruntime/core/session/inference_session.cc | 17 +++++++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 11 +++++++++++ onnxruntime/core/session/ort_apis.h | 3 +++ samples/c_test/test.cpp | 6 
++++++ samples/outTreeEp/out_tree_ep.cc | 13 ++++++++++--- samples/outTreeEp/out_tree_ep.h | 3 ++- 11 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/framework/provider_adapter.h diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index e7e84d86dbf46..473e0898443e1 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -90,6 +90,8 @@ class Environment { void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); + const std::unordered_map>& GetCustomEpFactories() const { return custom_ep_factories_; } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); Status Initialize(std::unique_ptr logging_manager, @@ -101,6 +103,6 @@ class Environment { std::unique_ptr inter_op_thread_pool_; bool create_global_thread_pools_{false}; std::vector shared_allocators_; - std::map custom_ep_factories_; + std::unordered_map> custom_ep_factories_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e6bbb0f2429a9..d7f715ebb0119 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -687,10 +687,11 @@ struct OrtApiBase { typedef struct OrtExecutionProvider { //void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraph* graph, _Out_ int* cnt, _Outptr_ OrtComputeCapability** compute_capability); //void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraph* graph, const OrtNode* node, int size, _Out_ int* cnt, _Outptr_ OrtNodeComputeInfo** node_compute_info); + const char* type; } OrtExecutionProvider; typedef struct OrtExecutionProviderFactory { - void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_); + void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); } OrtExecutionProviderFactory; typedef struct OrtApiBase OrtApiBase; @@ -4679,6 +4680,9 @@ struct OrtApi { size_t num_external_initializer_files); ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); + + ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, + _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); }; /* diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h new file mode 100644 index 0000000000000..38de442e43f67 --- /dev/null +++ b/onnxruntime/core/framework/provider_adapter.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
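+// ExecutionProviderAdapter (defined below, and fleshed out in later patches) wraps a
+// plugin's C-style OrtExecutionProvider so the session can drive it through the usual
+// C++ IExecutionProvider interface; ep->type supplies the provider type string.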
+
+#pragma once
+#include "core/session/onnxruntime_c_api.h"
+
+namespace onnxruntime {
+class ExecutionProviderAdapter : public IExecutionProvider {
+ public:
+  ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type), ep_impl_(ep) {}
+ private:
+  OrtExecutionProvider* ep_impl_;
+};
+}
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 46bfc3630303c..b82ff038c8d26 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -15,6 +15,7 @@
 #include "core/session/onnxruntime_c_api.h"
 #include "core/optimizer/graph_transformer_level.h"
 #include "core/util/thread_utils.h"
+#include "core/framework/provider_options.h"
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
 #include "core/framework/library_handles.h"
@@ -184,6 +185,8 @@ struct SessionOptions {
   // User specified logging func and param
   OrtLoggingFunction user_logging_function = nullptr;
   void* user_logging_param = nullptr;
+
+  ProviderOptionsMap custom_ep_options;
 };
 
 inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) {
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 24cf848541939..17d06fbe0dbf9 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -349,7 +349,8 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ
 }
 
 void Environment::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) {
-  custom_ep_factories_.insert({std::string(ep_name), ep_factory});
+  std::unique_ptr<OrtExecutionProviderFactory> p(ep_factory);
+  custom_ep_factories_.insert({ep_name, std::move(p)});  // TODO(leca): review
 }
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index f0eed91d70440..7fa25bf779fb5 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -89,6 +89,7 @@
 #include "core/framework/stream_execution_context.h"
 #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
 #endif
+#include "core/framework/provider_adapter.h"
 
 using namespace ONNX_NAMESPACE;
 using namespace onnxruntime::common;
@@ -1655,6 +1656,22 @@ common::Status InferenceSession::Initialize() {
     const Env& env = Env::Default();
     env.GetTelemetryProvider().LogSessionCreationStart();
 
+    const std::unordered_map<std::string, std::unique_ptr<OrtExecutionProviderFactory>>& custom_ep_factories = environment_.GetCustomEpFactories();
+    if (custom_ep_factories.size() > 0) {
+      for (auto const& [ep_name, ep_factory] : custom_ep_factories) {
+        if (session_options_.custom_ep_options.find(ep_name) != session_options_.custom_ep_options.end()) {
+          std::vector<const char*> keys, values;
+          for (auto const& [op_k, op_v] : session_options_.custom_ep_options[ep_name]) {
+            keys.push_back(op_k.c_str());
+            values.push_back(op_v.c_str());
+          }
+          OrtExecutionProvider* ep = reinterpret_cast<OrtExecutionProvider*>(ep_factory->CreateExecutionProvider(ep_factory.get(), keys.data(), values.data(), keys.size()));
+          std::unique_ptr<IExecutionProvider> ep_adapter = std::make_unique<ExecutionProviderAdapter>(ep);
+          ORT_RETURN_IF_ERROR(RegisterExecutionProvider(std::move(ep_adapter)));
+        }
+      }
+    }
+
     bool have_cpu_ep = false;
 
     {
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index cc620369e30b7..6fbd2031f6702 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2368,6 
+2368,16 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const cha API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, + _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + std::unordered_map kv; + for (size_t i = 0; i < num_keys; i++) { + kv.insert({provider_options_keys[i], provider_options_values[i]}); + } + options->value.custom_ep_options.insert({ep_name, kv}); + return nullptr; +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2747,6 +2757,7 @@ static constexpr OrtApi ort_api_1_to_19 = { // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::RegisterOrtExecutionProviderLibrary, + &OrtApis::SessionOptionsAppendOrtExecutionProvider, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0fa5fe2df09ab..efcf89a3375ff 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -525,4 +525,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); + +ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, + _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); } // namespace OrtApis diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index bdf349b5147a2..3ae0ca52befd1 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -1,4 +1,5 @@ #include "core/session/onnxruntime_c_api.h" +#include inline void THROW_ON_ERROR(OrtStatus* status) { if (status != nullptr) abort(); @@ -10,5 +11,10 @@ int main() { OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", p_env, "outTreeEp")); + + OrtSessionOptions* so = nullptr; + THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); + std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", keys.data(), values.data(), keys.size())); return 0; } diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index eede2ac209145..76f099f0b890f 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -3,9 +3,16 @@ namespace onnxruntime { OutTreeEpFactory::OutTreeEpFactory() { - OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_) -> void* { - std::unique_ptr ret = std::make_unique(); - return ret.release(); }; + OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> void* { + OutTreeEpInfo 
info; + for (size_t i = 0; i < option_size; i++) { + if (!strcmp(ep_option_keys[i], "int_property")) info.int_property = std::atoi(ep_option_values[i]); + else if (!strcmp(ep_option_keys[i], "str_property")) info.str_property = ep_option_values[i]; + // TODO(leca): else throw + } + std::unique_ptr ret = std::make_unique("outTreeEp", std::move(info)); + return ret.release(); + }; } } diff --git a/samples/outTreeEp/out_tree_ep.h b/samples/outTreeEp/out_tree_ep.h index df47f40eaf1f1..73681383e06e7 100644 --- a/samples/outTreeEp/out_tree_ep.h +++ b/samples/outTreeEp/out_tree_ep.h @@ -16,7 +16,8 @@ struct OutTreeEpInfo { }; struct OutTreeEp : public OrtExecutionProvider { -OutTreeEp() {} + OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { type = ep_type; } + OutTreeEpInfo info; }; struct OutTreeEpFactory : public OrtExecutionProviderFactory { From 8e7d28d25f07690626a2b2993307e17e9c938e2f Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Fri, 26 Jul 2024 18:37:50 +0000 Subject: [PATCH 04/81] Test Relu with compile based EP, build work, runtime error of loading EP as graph API is not exported by ORT. Need to put these graph API into ortapi structure --- .../core/session/onnxruntime_c_api.h | 73 ++++++++++++++--- onnxruntime/core/framework/provider_adapter.h | 75 +++++++++++++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 45 ++++++++++- samples/c_test/Relu.onnx | Bin 0 -> 109 bytes samples/c_test/test.cpp | 27 ++++++- samples/outTreeEp/CMakeLists.txt | 3 + samples/outTreeEp/out_tree_ep.cc | 60 ++++++++++++++ samples/outTreeEp/out_tree_ep.h | 2 +- 8 files changed, 266 insertions(+), 19 deletions(-) create mode 100644 samples/c_test/Relu.onnx diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d7f715ebb0119..1d6bbf00108cd 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -306,7 +306,8 @@ ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(ExecutionProvider); ORT_RUNTIME_CLASS(ExecutionProviderFactory); -ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(GraphViewer); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -684,16 +685,6 @@ struct OrtApiBase { const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; }; -typedef struct OrtExecutionProvider { - //void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraph* graph, _Out_ int* cnt, _Outptr_ OrtComputeCapability** compute_capability); - //void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraph* graph, const OrtNode* node, int size, _Out_ int* cnt, _Outptr_ OrtNodeComputeInfo** node_compute_info); - const char* type; -} OrtExecutionProvider; - -typedef struct OrtExecutionProviderFactory { - void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); -} OrtExecutionProviderFactory; - typedef struct OrtApiBase OrtApiBase; /** \brief The Onnxruntime library's entry point to access the C API @@ -702,6 +693,50 @@ typedef struct OrtApiBase OrtApiBase; */ ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; +typedef struct OrtMetaDef { + const char* name; + const char* domain; + int since_version; + + const char** inputs; + int input_len; + const char** outputs; + int output_len; + const char** 
constant_initializers; + int initializer_len; + + const char* doc_string; +} OrtMetaDef; + +typedef struct OrtIndexedSubGraph { + OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? + size_t* node_index; + size_t node_index_len; +} OrtIndexedSubGraph; + +typedef struct OrtComputeContext { + void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); + void(ORT_API_CALL* DestroyFunc)(void*, void*); + void* allocator_handle; + const char* node_name; +} OrtComputeContext; + +typedef struct OrtNodeComputeInfo { + int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void**); + OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, const OrtApi*, OrtKernelContext*); + void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); +} OrtNodeComputeInfo; + +typedef struct OrtExecutionProvider { + void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); + void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo*** node_compute_info); + const char* type; +} OrtExecutionProvider; + +typedef struct OrtExecutionProviderFactory { + void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); +} OrtExecutionProviderFactory; + /** \brief Thread work loop function * * Onnxruntime will provide the working loop on custom thread creation @@ -4843,7 +4878,21 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOpt */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); -ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope); +ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope); + +ORT_API(const size_t*, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, size_t* len); + +ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index); + +ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node); + +ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i); + +ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i); #ifdef __cplusplus } diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 38de442e43f67..5dc7fd3cb8d11 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -3,12 +3,83 @@ #pragma once #include "core/session/onnxruntime_c_api.h" +#include "core/framework/compute_capability.h" namespace onnxruntime { class ExecutionProviderAdapter : public IExecutionProvider { - public: +public: ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type), ep_impl_(ep) {} - private: + virtual std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const override { + size_t cnt = 0; + OrtIndexedSubGraph** indexed_subgraph = nullptr; + ep_impl_->GetCapability(ep_impl_, reinterpret_cast(&graph_viewer), &cnt, &indexed_subgraph); + + if (cnt == 0) return IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); + + std::vector> ret; 
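+    // The loop below copies each OrtIndexedSubGraph the plugin EP returned into the
+    // internal IndexedSubGraph/ComputeCapability form: first the node indices, then,
+    // if present, the MetaDef (name, domain, inputs, outputs, constant initializers).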
+    for (size_t i = 0; i < cnt; i++) {
+      std::unique_ptr<IndexedSubGraph> sb = std::make_unique<IndexedSubGraph>();
+      sb->nodes.reserve(indexed_subgraph[i]->node_index_len);
+      for (size_t j = 0; j < indexed_subgraph[i]->node_index_len; j++) sb->nodes.push_back((indexed_subgraph[i]->node_index)[j]);
+      if (indexed_subgraph[i]->meta_def != nullptr) {
+        std::unique_ptr<IndexedSubGraph::MetaDef> meta_def = std::make_unique<IndexedSubGraph::MetaDef>();
+        meta_def->name = indexed_subgraph[i]->meta_def->name;
+        meta_def->doc_string = indexed_subgraph[i]->meta_def->doc_string;
+        meta_def->domain = indexed_subgraph[i]->meta_def->domain;
+        meta_def->since_version = indexed_subgraph[i]->meta_def->since_version;
+
+        meta_def->inputs.reserve(indexed_subgraph[i]->meta_def->input_len);
+        for (int j = 0; j < indexed_subgraph[i]->meta_def->input_len; j++) meta_def->inputs.push_back(indexed_subgraph[i]->meta_def->inputs[j]);
+
+        meta_def->outputs.reserve(indexed_subgraph[i]->meta_def->output_len);
+        for (int j = 0; j < indexed_subgraph[i]->meta_def->output_len; j++) meta_def->outputs.push_back(indexed_subgraph[i]->meta_def->outputs[j]);
+
+        meta_def->constant_initializers.reserve(indexed_subgraph[i]->meta_def->initializer_len);
+        for (int j = 0; j < indexed_subgraph[i]->meta_def->initializer_len; j++) meta_def->constant_initializers.push_back(indexed_subgraph[i]->meta_def->constant_initializers[j]);
+
+        sb->SetMetaDef(std::move(meta_def));
+      }
+
+      ret.push_back(std::make_unique<ComputeCapability>(std::move(sb)));
+    }
+    return ret;
+  }
+
+  virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) override {
+    std::vector<const OrtGraphViewer*> ortGraphs;
+    std::vector<const OrtNode*> ortNodes;
+    for (auto& fused_node_graph : fused_nodes_and_graphs) {
+      const GraphViewer& graph_viewer = fused_node_graph.filtered_graph;
+      const Node& fused_node = fused_node_graph.fused_node;
+      ortGraphs.push_back(reinterpret_cast<const OrtGraphViewer*>(&graph_viewer));
+      ortNodes.push_back(reinterpret_cast<const OrtNode*>(&fused_node));
+    }
+    size_t count = fused_nodes_and_graphs.size();
+    OrtNodeComputeInfo** node_compute_info = new OrtNodeComputeInfo* [count];
+    ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &node_compute_info);
+
+    node_compute_funcs.reserve(count);
+    for (size_t i = 0; i < count; i++) {
+      NodeComputeInfo compute_info;
+      compute_info.create_state_func = [&](ComputeContext* context, void** state) {
+        OrtComputeContext occ;
+        occ.AllocateFunc = context->allocate_func;
+        occ.DestroyFunc = context->release_func;
+        occ.allocator_handle = context->allocator_handle;
+        occ.node_name = context->node_name;
+        return node_compute_info[i]->CreateFunctionStateFunc(&occ, state);  // TODO(leca): reinterpret_cast<OrtComputeContext*>(context)? 
+      };
+      compute_info.compute_func = [&](void* state, const OrtApi* api, OrtKernelContext* context) {
+        return ToStatus(node_compute_info[i]->ComputeFunc(state, api, context));
+      };
+      compute_info.release_state_func = [&](void* state) {
+        node_compute_info[i]->DestroyFunctionStateFunc(state);
+      };
+      node_compute_funcs.push_back(compute_info);
+    }
+    return Status::OK();
+  }
+private:
   OrtExecutionProvider* ep_impl_;
 };
 }
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 6fbd2031f6702..58253509838f7 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2831,7 +2831,46 @@ DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
 DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
 DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata)
 
-ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope) {
-  ::onnxruntime::GraphViewer graph_viewer(*(reinterpret_cast<const ::onnxruntime::Graph*>(graph)));
-  return graph_viewer.IsConstantInitializer(std::string(name), check_outer_scope);
+ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return graph_viewer->IsConstantInitializer(name, check_outer_scope);
+}
+
+ORT_API(const size_t*, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, size_t* len) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  const std::vector<size_t>& nodes = graph_viewer->GetNodesInTopologicalOrder();
+  *len = nodes.size();
+  return nodes.data();
+}
+
+ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return reinterpret_cast<const OrtNode*>(graph_viewer->GetNode(node_index));
+}
+
+ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->OpType().c_str();
+}
+
+ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->InputDefs().size();
+}
+
+ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->InputDefs().size());
+  return n->InputDefs()[i]->Name().c_str();
+}
+
+ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->OutputDefs().size();
+}
+
+ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->OutputDefs().size());
+  return n->OutputDefs()[i]->Name().c_str();
 }
diff --git a/samples/c_test/Relu.onnx b/samples/c_test/Relu.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..1b9af5e66dc7cc2ec22b099a60fd4c893fe13a94
GIT binary patch
literal 109
zcmd
diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -1,5 +1,6 @@
 #include "core/session/onnxruntime_c_api.h"
 #include <vector>
+#include <iostream>
 
 inline void THROW_ON_ERROR(OrtStatus* status) {
   if (status != nullptr) abort();
@@ -8,7 +9,7 @@ inline void THROW_ON_ERROR(OrtStatus* status) {
 
 int main() {
   const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
   OrtEnv* p_env = nullptr;
-  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO;
+  OrtLoggingLevel log_level = 
OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", p_env, "outTreeEp")); @@ -16,5 +17,29 @@ int main() { THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", keys.data(), values.data(), keys.size())); + + OrtSession* session = nullptr; + THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); + + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + float input_data[] = {-3.0f, 5.0f, -2.0f, 4.0f}; + const size_t input_len = 4 * sizeof(float); + const int64_t input_shape[] = {4}; + const size_t shape_len = sizeof(input_shape)/sizeof(input_shape[0]); + + OrtValue* input_tensor = nullptr; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); + + const char* input_names[] = {"x"}; + const char* output_names[] = {"graphOut"}; + OrtValue* output_tensor = nullptr; + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, 1, output_names, 1, &output_tensor)); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); + std::cout<<"Result:\n"; + for (size_t i = 0; i < 4; i++) std::cout< +#include namespace onnxruntime { +OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { + type = ep_type; + OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { + std::vector cache; + size_t nodes_count = 0; + const size_t* nodes_index = OrtGraph_GetNodesIndexInTopologicalOrder(graph, &nodes_count); + for (size_t i = 0; i < nodes_count; i++) { + const OrtNode* node = OrtGraph_GetOrtNode(graph, nodes_index[i]); + if (OrtNode_GetOpType(node) == "Relu") { + OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); + subgraph->node_index_len = 1; + subgraph->node_index = new size_t [subgraph->node_index_len]; + subgraph->node_index[0] = nodes_index[0]; + + subgraph->meta_def = new OrtMetaDef(); + subgraph->meta_def->name = "Relu_subgraph"; + subgraph->meta_def->input_len = OrtNode_GetInputSize(node); + subgraph->meta_def->inputs = new const char* [subgraph->meta_def->input_len]; + for (int j = 0; j < subgraph->meta_def->input_len; j++) subgraph->meta_def->inputs[j] = OrtNode_GetIthInputName(node, j); + + subgraph->meta_def->output_len = OrtNode_GetOutputSize(node); + subgraph->meta_def->outputs = new const char* [subgraph->meta_def->output_len]; + for (int j = 0; j < subgraph->meta_def->output_len; j++) subgraph->meta_def->outputs[j] = OrtNode_GetIthOutputName(node, j); + + cache.push_back(subgraph); + } + } + + *cnt = cache.size(); + *indexed_sub_graph = new OrtIndexedSubGraph* [*cnt]; + for (size_t i = 0; i < *cnt; i++) { + (*indexed_sub_graph)[i] = cache[i]; + } + }; + + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo*** node_compute_info) { + for 
(size_t i = 0; i < cnt; i++) { + (*node_compute_info)[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) ->OrtStatusPtr { + const OrtValue* input = nullptr; + api->KernelContext_GetInput(context, 0, &input); + std::vector dim(1,4); + OrtValue* output = nullptr; + api->KernelContext_GetOutput(context, 0, dim.data(), dim.size(), &output); + + float* input_raw = nullptr, *output_raw = nullptr; + api->GetTensorMutableData(const_cast(input), reinterpret_cast(&input_raw)); + api->GetTensorMutableData(output, reinterpret_cast(&output_raw)); + + for (int i = 0; i < 4; i++) { + output_raw[i] = input_raw[i]; + if (input_raw[i] < 0) output_raw[i] = 0; + } + + return nullptr; + }; + } + }; +} + OutTreeEpFactory::OutTreeEpFactory() { OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> void* { OutTreeEpInfo info; diff --git a/samples/outTreeEp/out_tree_ep.h b/samples/outTreeEp/out_tree_ep.h index 73681383e06e7..cd4b49cabd2c6 100644 --- a/samples/outTreeEp/out_tree_ep.h +++ b/samples/outTreeEp/out_tree_ep.h @@ -16,7 +16,7 @@ struct OutTreeEpInfo { }; struct OutTreeEp : public OrtExecutionProvider { - OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { type = ep_type; } + OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info); OutTreeEpInfo info; }; From 808bfc3bb11cccd731070d1e165e259e5d46e29d Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 29 Jul 2024 17:49:42 +0000 Subject: [PATCH 05/81] prototype works with hardcode node_compute_info's index in ExecutionProviderAdapter::Compile() --- .../core/session/onnxruntime_c_api.h | 40 +++---- onnxruntime/core/framework/provider_adapter.h | 38 ++++--- onnxruntime/core/session/onnxruntime_c_api.cc | 105 ++++++++++-------- onnxruntime/core/session/ort_apis.h | 16 +++ samples/c_test/CMakeLists.txt | 2 +- samples/outTreeEp/CMakeLists.txt | 2 +- samples/outTreeEp/out_tree_ep.cc | 30 +++-- 7 files changed, 144 insertions(+), 89 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 1d6bbf00108cd..5cd0e92cf6126 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -699,11 +699,11 @@ typedef struct OrtMetaDef { int since_version; const char** inputs; - int input_len; + size_t input_len; const char** outputs; - int output_len; + size_t output_len; const char** constant_initializers; - int initializer_len; + size_t initializer_len; const char* doc_string; } OrtMetaDef; @@ -4718,7 +4718,23 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); -}; + + ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); + + ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); + + ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); + + ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); + + 
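+  /* Usage sketch for these graph accessors (it mirrors samples/outTreeEp/out_tree_ep.cc;
+     assumes 'graph' is the OrtGraphViewer* passed to OrtExecutionProvider::GetCapability):
+       const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+       size_t len = 0;
+       const size_t* idx = NULL;
+       api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, &len, &idx);
+       for (size_t i = 0; i < len; i++) {
+         const OrtNode* node = NULL;
+         api->OrtGraph_GetOrtNode(graph, idx[i], &node);
+         const char* op_type = NULL;
+         api->OrtNode_GetOpType(node, &op_type);
+       }
+  */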
ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); + + ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); + + ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); + + ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); +}; // struct OrtApi /* * Steps to use a custom op: @@ -4878,22 +4894,6 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOpt */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); -ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope); - -ORT_API(const size_t*, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, size_t* len); - -ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index); - -ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node); - -ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node); - -ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i); - -ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node); - -ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i); - #ifdef __cplusplus } #endif diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 5dc7fd3cb8d11..e1ee6ed704c46 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -23,19 +23,19 @@ class ExecutionProviderAdapter : public IExecutionProvider { for (size_t j = 0; j < indexed_subgraph[i]->node_index_len; j++) sb->nodes.push_back((indexed_subgraph[i]->node_index)[j]); if (indexed_subgraph[i]->meta_def != nullptr) { std::unique_ptr meta_def = std::make_unique(); - meta_def->name = indexed_subgraph[i]->meta_def->name; - meta_def->doc_string = indexed_subgraph[i]->meta_def->doc_string; - meta_def->domain = indexed_subgraph[i]->meta_def->domain; + meta_def->name = indexed_subgraph[i]->meta_def->name ? indexed_subgraph[i]->meta_def->name : ""; + meta_def->doc_string = indexed_subgraph[i]->meta_def->doc_string ? indexed_subgraph[i]->meta_def->doc_string : ""; + meta_def->domain = indexed_subgraph[i]->meta_def->domain ? 
indexed_subgraph[i]->meta_def->domain : ""; meta_def->since_version = indexed_subgraph[i]->meta_def->since_version; meta_def->inputs.reserve(indexed_subgraph[i]->meta_def->input_len); - for (int j = 0; j < indexed_subgraph[i]->meta_def->input_len; j++) meta_def->inputs.push_back(indexed_subgraph[i]->meta_def->inputs[j]); + for (size_t j = 0; j < indexed_subgraph[i]->meta_def->input_len; j++) meta_def->inputs.push_back(indexed_subgraph[i]->meta_def->inputs[j]); meta_def->outputs.reserve(indexed_subgraph[i]->meta_def->output_len); - for (int j = 0; j < indexed_subgraph[i]->meta_def->output_len; j++) meta_def->outputs.push_back(indexed_subgraph[i]->meta_def->outputs[j]); + for (size_t j = 0; j < indexed_subgraph[i]->meta_def->output_len; j++) meta_def->outputs.push_back(indexed_subgraph[i]->meta_def->outputs[j]); meta_def->constant_initializers.reserve(indexed_subgraph[i]->meta_def->initializer_len); - for (int j = 0; j < indexed_subgraph[i]->meta_def->initializer_len; j++) meta_def->constant_initializers.push_back(indexed_subgraph[i]->meta_def->constant_initializers[j]); + for (size_t j = 0; j < indexed_subgraph[i]->meta_def->initializer_len; j++) meta_def->constant_initializers.push_back(indexed_subgraph[i]->meta_def->constant_initializers[j]); sb->SetMetaDef(std::move(meta_def)); } @@ -55,25 +55,30 @@ class ExecutionProviderAdapter : public IExecutionProvider { ortNodes.push_back(reinterpret_cast(&fused_node)); } size_t count = fused_nodes_and_graphs.size(); - OrtNodeComputeInfo** node_compute_info = new OrtNodeComputeInfo* [count]; - ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &node_compute_info); + node_compute_info_ = new OrtNodeComputeInfo* [count]; + ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &node_compute_info_); node_compute_funcs.reserve(count); for (size_t i = 0; i < count; i++) { NodeComputeInfo compute_info; compute_info.create_state_func = [&](ComputeContext* context, void** state) { - OrtComputeContext occ; - occ.AllocateFunc = context->allocate_func; - occ.DestroyFunc = context->release_func; - occ.allocator_handle = context->allocator_handle; - occ.node_name = context->node_name; - return node_compute_info[i]->CreateFunctionStateFunc(&occ, state); // TODO(leca): reinterpret_cast(context)? + if (node_compute_info_[0]->CreateFunctionStateFunc) { + OrtComputeContext occ; + occ.AllocateFunc = context->allocate_func; + occ.DestroyFunc = context->release_func; + occ.allocator_handle = context->allocator_handle; + occ.node_name = context->node_name; + return node_compute_info_[0]->CreateFunctionStateFunc(&occ, state); // TODO(leca): reinterpret_cast(context)? 
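+        // Note: node_compute_info_[0] is the hardcoded index this commit's subject refers
+        // to; every fused node currently reuses entry 0. Patch 06 removes the hardcode by
+        // capturing the loop index i in these lambdas.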
+      }
+      return 0;
+    };
     compute_info.compute_func = [&](void* state, const OrtApi* api, OrtKernelContext* context) {
-      return ToStatus(node_compute_info[i]->ComputeFunc(state, api, context));
+      return ToStatus(node_compute_info_[0]->ComputeFunc(state, api, context));
     };
     compute_info.release_state_func = [&](void* state) {
-      node_compute_info[i]->DestroyFunctionStateFunc(state);
+      if (node_compute_info_[0]->DestroyFunctionStateFunc) {
+        node_compute_info_[0]->DestroyFunctionStateFunc(state);
+      }
     };
     node_compute_funcs.push_back(compute_info);
   }
@@ -81,5 +86,6 @@ class ExecutionProviderAdapter : public IExecutionProvider {
   }
 private:
   OrtExecutionProvider* ep_impl_;
+  OrtNodeComputeInfo** node_compute_info_;
 };
 }
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 58253509838f7..f41214905ea30 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2378,6 +2378,58 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtS
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *ret = graph_viewer->IsConstantInitializer(name, check_outer_scope);
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  const std::vector<size_t>& nodes = graph_viewer->GetNodesInTopologicalOrder();
+  *len = nodes.size();
+  *nodes_index_in_topological_order = nodes.data();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *node = reinterpret_cast<const OrtNode*>(graph_viewer->GetNode(node_index));
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *op_type = n->OpType().c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *input_size = n->InputDefs().size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->InputDefs().size());
+  *ith_input_name = n->InputDefs()[i]->Name().c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *output_size = n->OutputDefs().size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->OutputDefs().size());
+  *ith_output_name = n->OutputDefs()[i]->Name().c_str();
+  return nullptr;
+}
+
 static constexpr OrtApiBase ort_api_base = {
     &OrtApis::GetApi,
     &OrtApis::GetVersionString};
@@ -2758,6 +2810,15 @@ static constexpr OrtApi ort_api_1_to_19 = { 
&OrtApis::RegisterOrtExecutionProviderLibrary, &OrtApis::SessionOptionsAppendOrtExecutionProvider, + + &OrtApis::OrtGraph_IsConstantInitializer, + &OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, + &OrtApis::OrtGraph_GetOrtNode, + &OrtApis::OrtNode_GetOpType, + &OrtApis::OrtNode_GetInputSize, + &OrtApis::OrtNode_GetIthInputName, + &OrtApis::OrtNode_GetOutputSize, + &OrtApis::OrtNode_GetIthOutputName, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2830,47 +2891,3 @@ DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata) - -ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->IsConstantInitializer(name, check_outer_scope); -} - -ORT_API(const size_t*, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, size_t* len) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(); - *len = nodes.size(); - return nodes.data(); -} - -ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return reinterpret_cast(graph_viewer->GetNode(node_index)); -} - -ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->OpType().c_str(); -} - -ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->InputDefs().size(); -} - -ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - assert(i < n->InputDefs().size()); - return n->InputDefs()[i]->Name().c_str(); -} - -ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->OutputDefs().size(); -} - -ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - assert(i < n->OutputDefs().size()); - return n->OutputDefs()[i]->Name().c_str(); -} diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index efcf89a3375ff..1132dd2446fc1 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -528,4 +528,20 @@ ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* l ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); + +ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); + +ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, 
_Outptr_ const OrtNode** node); + +ORT_API_STATUS_IMPL(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); + +ORT_API_STATUS_IMPL(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); + +ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); + +ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); + +ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); } // namespace OrtApis diff --git a/samples/c_test/CMakeLists.txt b/samples/c_test/CMakeLists.txt index 068b7cf18be91..c8bf77b99c1a0 100644 --- a/samples/c_test/CMakeLists.txt +++ b/samples/c_test/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug # cmake --build ./ cmake_minimum_required(VERSION 3.26) project(TestOutTreeEp) diff --git a/samples/outTreeEp/CMakeLists.txt b/samples/outTreeEp/CMakeLists.txt index 9fbe3f0596e9d..d4193f6f8fffa 100644 --- a/samples/outTreeEp/CMakeLists.txt +++ b/samples/outTreeEp/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug # cmake --build ./ cmake_minimum_required(VERSION 3.26) project(outTreeEp VERSION 1.0) diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index a45aa36dbebd6..0300f1760081d 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -6,12 +6,17 @@ namespace onnxruntime { OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { type = ep_type; OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); std::vector cache; size_t nodes_count = 0; - const size_t* nodes_index = OrtGraph_GetNodesIndexInTopologicalOrder(graph, &nodes_count); + const size_t* nodes_index = nullptr; + api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, &nodes_count, &nodes_index); for (size_t i = 0; i < nodes_count; i++) { - const OrtNode* node = OrtGraph_GetOrtNode(graph, nodes_index[i]); - if (OrtNode_GetOpType(node) == "Relu") { + const OrtNode* node = nullptr; + api->OrtGraph_GetOrtNode(graph, nodes_index[i], &node); + const char* node_op_type; + api->OrtNode_GetOpType(node, &node_op_type); + if (!strcmp(node_op_type, "Relu")) { OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); subgraph->node_index_len = 1; subgraph->node_index = new size_t [subgraph->node_index_len]; @@ -19,13 +24,21 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(e subgraph->meta_def = new OrtMetaDef(); subgraph->meta_def->name = "Relu_subgraph"; - subgraph->meta_def->input_len = OrtNode_GetInputSize(node); + subgraph->meta_def->input_len = 0; + api->OrtNode_GetInputSize(node, &(subgraph->meta_def->input_len)); subgraph->meta_def->inputs = new const char* [subgraph->meta_def->input_len]; - for (int j = 0; j < subgraph->meta_def->input_len; j++) subgraph->meta_def->inputs[j] = OrtNode_GetIthInputName(node, j); + for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { + subgraph->meta_def->inputs[j] = nullptr; + api->OrtNode_GetIthInputName(node, j, &(subgraph->meta_def->inputs[j])); + } - subgraph->meta_def->output_len = OrtNode_GetOutputSize(node); + subgraph->meta_def->output_len = 0; + api->OrtNode_GetOutputSize(node, 
&(subgraph->meta_def->output_len)); subgraph->meta_def->outputs = new const char* [subgraph->meta_def->output_len]; - for (int j = 0; j < subgraph->meta_def->output_len; j++) subgraph->meta_def->outputs[j] = OrtNode_GetIthOutputName(node, j); + for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { + subgraph->meta_def->outputs[j] = nullptr; + api->OrtNode_GetIthOutputName(node, j, &(subgraph->meta_def->outputs[j])); + } cache.push_back(subgraph); } @@ -40,6 +53,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(e OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo*** node_compute_info) { for (size_t i = 0; i < cnt; i++) { + (*node_compute_info)[i] = new OrtNodeComputeInfo(); (*node_compute_info)[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) ->OrtStatusPtr { const OrtValue* input = nullptr; api->KernelContext_GetInput(context, 0, &input); @@ -54,6 +68,8 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(e for (int i = 0; i < 4; i++) { output_raw[i] = input_raw[i]; if (input_raw[i] < 0) output_raw[i] = 0; + + output_raw[i] = 1.0; } return nullptr; From 49e396c6b32eba31c394bd427a8e1d4a3f01f565 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 29 Jul 2024 23:47:31 +0000 Subject: [PATCH 06/81] prototype works without hardcode --- onnxruntime/core/framework/provider_adapter.h | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index e1ee6ed704c46..c85e87d5bb6b1 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -61,27 +61,32 @@ class ExecutionProviderAdapter : public IExecutionProvider { node_compute_funcs.reserve(count); for (size_t i = 0; i < count; i++) { NodeComputeInfo compute_info; - compute_info.create_state_func = [&](ComputeContext* context, void** state) { - if (node_compute_info_[0]->CreateFunctionStateFunc) { + compute_info.create_state_func = [&, i](ComputeContext* context, void** state) { + if (node_compute_info_[i]->CreateFunctionStateFunc) { OrtComputeContext occ; occ.AllocateFunc = context->allocate_func; occ.DestroyFunc = context->release_func; occ.allocator_handle = context->allocator_handle; occ.node_name = context->node_name; - return node_compute_info_[0]->CreateFunctionStateFunc(&occ, state); // TODO(leca): reinterpret_cast(context)? + return node_compute_info_[i]->CreateFunctionStateFunc(&occ, state); // TODO(leca): reinterpret_cast(context)? 
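+          // note: capturing the loop index i by value is the actual fix in this commit; each fused
+          // node must dispatch through its own node_compute_info_[i] rather than always entry 0.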
} return 0; }; - compute_info.compute_func = [&](void* state, const OrtApi* api, OrtKernelContext* context) { - return ToStatus(node_compute_info_[0]->ComputeFunc(state, api, context)); + compute_info.compute_func = [&, i](void* state, const OrtApi* api, OrtKernelContext* context) { + return ToStatus(node_compute_info_[i]->ComputeFunc(state, api, context)); }; - compute_info.release_state_func = [&](void* state) { - if (node_compute_info_[0]->DestroyFunctionStateFunc) { - node_compute_info_[0]->DestroyFunctionStateFunc(state); + compute_info.release_state_func = [&, i](void* state) { + if (node_compute_info_[i]->DestroyFunctionStateFunc) { + node_compute_info_[i]->DestroyFunctionStateFunc(state); } }; node_compute_funcs.push_back(compute_info); } + +/* node_compute_funcs.resize(count); + NodeComputeInfo* + ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, reinterpret_cast<>(&node_compute_funcs.data())); +*/ return Status::OK(); } private: From e790105c3791cbd08bb4dfefc4a66c2604359c2a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 31 Jul 2024 20:38:45 +0000 Subject: [PATCH 07/81] fix comments for Compile function --- .../core/session/onnxruntime_c_api.h | 2 +- onnxruntime/core/framework/provider_adapter.h | 35 +++++++------------ samples/outTreeEp/out_tree_ep.cc | 5 ++- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5cd0e92cf6126..8f052163ece43 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -729,7 +729,7 @@ typedef struct OrtNodeComputeInfo { typedef struct OrtExecutionProvider { void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); - void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo*** node_compute_info); + void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info); const char* type; } OrtExecutionProvider; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index c85e87d5bb6b1..e9e4b0a656e12 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -55,42 +55,31 @@ class ExecutionProviderAdapter : public IExecutionProvider { ortNodes.push_back(reinterpret_cast(&fused_node)); } size_t count = fused_nodes_and_graphs.size(); - node_compute_info_ = new OrtNodeComputeInfo* [count]; - ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &node_compute_info_); - + std::vector cache; + cache.resize(count); + OrtNodeComputeInfo* cache_data = cache.data(); + ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &cache_data); node_compute_funcs.reserve(count); for (size_t i = 0; i < count; i++) { NodeComputeInfo compute_info; - compute_info.create_state_func = [&, i](ComputeContext* context, void** state) { - if (node_compute_info_[i]->CreateFunctionStateFunc) { - OrtComputeContext occ; - occ.AllocateFunc = context->allocate_func; - occ.DestroyFunc = context->release_func; - occ.allocator_handle = context->allocator_handle; - occ.node_name = context->node_name; - return node_compute_info_[i]->CreateFunctionStateFunc(&occ, state); // TODO(leca): 
reinterpret_cast(context)? - } + compute_info.create_state_func = [&, cache, i](ComputeContext* context, void** state) { + if (cache[i].CreateFunctionStateFunc) return cache[i].CreateFunctionStateFunc(reinterpret_cast(context), state); return 0; }; - compute_info.compute_func = [&, i](void* state, const OrtApi* api, OrtKernelContext* context) { - return ToStatus(node_compute_info_[i]->ComputeFunc(state, api, context)); + compute_info.compute_func = [&, cache, i](void* state, const OrtApi* api, OrtKernelContext* context) { + return ToStatus(cache[i].ComputeFunc(state, api, context)); }; - compute_info.release_state_func = [&, i](void* state) { - if (node_compute_info_[i]->DestroyFunctionStateFunc) { - node_compute_info_[i]->DestroyFunctionStateFunc(state); + compute_info.release_state_func = [&, cache, i](void* state) { + if (cache[i].DestroyFunctionStateFunc) { + cache[i].DestroyFunctionStateFunc(state); } }; - node_compute_funcs.push_back(compute_info); + node_compute_funcs.emplace_back(std::move(compute_info)); } -/* node_compute_funcs.resize(count); - NodeComputeInfo* - ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, reinterpret_cast<>(&node_compute_funcs.data())); -*/ return Status::OK(); } private: OrtExecutionProvider* ep_impl_; - OrtNodeComputeInfo** node_compute_info_; }; } diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index 0300f1760081d..fc9d81f9bb005 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -51,10 +51,9 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(e } }; - OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo*** node_compute_info) { + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) { for (size_t i = 0; i < cnt; i++) { - (*node_compute_info)[i] = new OrtNodeComputeInfo(); - (*node_compute_info)[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) ->OrtStatusPtr { + node_compute_info[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { const OrtValue* input = nullptr; api->KernelContext_GetInput(context, 0, &input); std::vector dim(1,4); From 92f529d09e2f00c6801b504f8ee9d871d409f544 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 1 Aug 2024 17:42:05 +0000 Subject: [PATCH 08/81] add provider_factory_adapter.h --- .../onnxruntime/core/session/environment.h | 2 +- .../core/session/onnxruntime_c_api.h | 2 +- .../core/framework/provider_factory_adapter.h | 21 +++++++++++++++++++ onnxruntime/core/framework/session_options.h | 3 --- onnxruntime/core/session/environment.cc | 6 ++++++ onnxruntime/core/session/inference_session.cc | 17 --------------- onnxruntime/core/session/onnxruntime_c_api.cc | 11 +++++----- onnxruntime/core/session/ort_apis.h | 2 +- onnxruntime/core/session/ort_env.cc | 4 ++++ onnxruntime/core/session/ort_env.h | 2 ++ samples/c_test/test.cpp | 2 +- 11 files changed, 43 insertions(+), 29 deletions(-) create mode 100644 onnxruntime/core/framework/provider_factory_adapter.h diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 473e0898443e1..08a3730827835 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -90,7 +90,7 @@ class 
Environment { void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); - const std::unordered_map>& GetCustomEpFactories() const { return custom_ep_factories_; } + OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const std::string& ep_name); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8f052163ece43..7b4aa583e7d4e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4716,7 +4716,7 @@ struct OrtApi { ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); - ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, + ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); diff --git a/onnxruntime/core/framework/provider_factory_adapter.h b/onnxruntime/core/framework/provider_factory_adapter.h new file mode 100644 index 0000000000000..71b2aabf60a5a --- /dev/null +++ b/onnxruntime/core/framework/provider_factory_adapter.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/providers.h" +#include "provider_adapter.h" + +namespace onnxruntime { +struct ExecutionProviderFactoryAdapter : IExecutionProviderFactory { +ExecutionProviderFactoryAdapter(OrtExecutionProviderFactory* ep_factory, const char* const* provider_option_keys, const char* const* provider_option_values, size_t provider_option_length) + : ep_factory_(ep_factory), provider_option_keys_(provider_option_keys), provider_option_values_(provider_option_values), provider_option_length_(provider_option_length) {} +std::unique_ptr CreateProvider() override { + void* ep = ep_factory_->CreateExecutionProvider(ep_factory_, provider_option_keys_, provider_option_values_, provider_option_length_); + return std::make_unique(reinterpret_cast(ep)); +} +OrtExecutionProviderFactory* ep_factory_; +const char* const* provider_option_keys_; +const char* const* provider_option_values_; +size_t provider_option_length_; +}; +} diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b82ff038c8d26..46bfc3630303c 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -15,7 +15,6 @@ #include "core/session/onnxruntime_c_api.h" #include "core/optimizer/graph_transformer_level.h" #include "core/util/thread_utils.h" -#include "core/framework/provider_options.h" #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) #include "core/framework/library_handles.h" @@ -185,8 +184,6 @@ struct SessionOptions { // User specified logging func and param OrtLoggingFunction user_logging_function = nullptr; void* user_logging_param = nullptr; - - ProviderOptionsMap custom_ep_options; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git 
a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 17d06fbe0dbf9..8083a473211d7 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -353,4 +353,10 @@ void Environment::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactor custom_ep_factories_.insert({ep_name, std::move(p)}); // TODO(leca): review } +OrtExecutionProviderFactory* Environment::GetOrtExecutionProviderFactory(const std::string& ep_name) { + std::unordered_map>::const_iterator it = custom_ep_factories_.find(ep_name); + if (it == custom_ep_factories_.end()) return nullptr; + return it->second.get(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 7fa25bf779fb5..f0eed91d70440 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -89,7 +89,6 @@ #include "core/framework/stream_execution_context.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #endif -#include "core/framework/provider_adapter.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -1656,22 +1655,6 @@ common::Status InferenceSession::Initialize() { const Env& env = Env::Default(); env.GetTelemetryProvider().LogSessionCreationStart(); - const std::unordered_map>& custom_ep_factories = environment_.GetCustomEpFactories(); - if (custom_ep_factories.size() > 0) { - for (auto const& [ep_name, ep_factory] : custom_ep_factories) { - if (session_options_.custom_ep_options.find(ep_name) != session_options_.custom_ep_options.end()) { - std::vector keys, values; - for (auto const& [op_k, op_v] : session_options_.custom_ep_options[ep_name]) { - keys.push_back(op_k.c_str()); - values.push_back(op_v.c_str()); - } - OrtExecutionProvider* ep = reinterpret_cast(ep_factory->CreateExecutionProvider(ep_factory.get(), keys.data(), values.data(), keys.size())); - std::unique_ptr ep_adapter = std::make_unique(ep); - ORT_RETURN_IF_ERROR(RegisterExecutionProvider(std::move(ep_adapter))); - } - } - } - bool have_cpu_ep = false; { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f41214905ea30..165edf6c48bfe 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -39,6 +39,7 @@ #include "core/framework/TensorSeq.h" #include "core/platform/ort_mutex.h" #include "core/common/string_helper.h" +#include "core/framework/provider_factory_adapter.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -2368,13 +2369,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const cha API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { - std::unordered_map kv; - for (size_t i = 0; i < num_keys; i++) { - kv.insert({provider_options_keys[i], provider_options_values[i]}); + OrtExecutionProviderFactory* ep_factory = env->GetOrtExecutionProviderFactory(ep_name); + if (ep_factory) { + std::shared_ptr factory = std::make_shared(ep_factory, provider_options_keys, 
provider_options_values, num_keys); + options->provider_factories.push_back(std::move(factory)); } - options->value.custom_ep_options.insert({ep_name, kv}); return nullptr; } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 1132dd2446fc1..496dd013bb00f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -526,7 +526,7 @@ ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); -ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, +ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index e3212b17dac20..188b276d12a6d 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -114,3 +114,7 @@ onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocatorV2(const std::stri void OrtEnv::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { value_->InsertCustomEp(ep_name, ep_factory); } + +OrtExecutionProviderFactory* OrtEnv::GetOrtExecutionProviderFactory(const char* ep_name) { + return value_->GetOrtExecutionProviderFactory(ep_name); +} diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 42d33de3c8d39..31a31b21ef54c 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -67,6 +67,8 @@ struct OrtEnv { void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); + OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const char* ep_name); + private: static std::unique_ptr p_instance_; static onnxruntime::OrtMutex m_; diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index f5b503b6b5217..ad668350d2307 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -16,7 +16,7 @@ int main() { OrtSessionOptions* so = nullptr; THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", keys.data(), values.data(), keys.size())); + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", p_env, keys.data(), values.data(), keys.size())); OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); From 3d83ed1cafebdf343fb11365250bdec6327b4522 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 5 Aug 2024 19:27:04 +0000 Subject: [PATCH 09/81] fix crash after introducing kernel based EP --- .../core/session/onnxruntime_c_api.h | 9 +- onnxruntime/core/framework/provider_adapter.h | 10 +- .../core/framework/provider_factory_adapter.h | 23 ++++- onnxruntime/core/session/custom_ops.cc | 94 ++++++++++++++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 11 +++ onnxruntime/core/session/ort_apis.h | 2 + 
samples/c_test/test.cpp | 19 +++- samples/outTreeEp/out_tree_ep.cc | 4 +- 8 files changed, 157 insertions(+), 15 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 7b4aa583e7d4e..125aa54e00b81 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -308,6 +308,7 @@ ORT_RUNTIME_CLASS(ExecutionProvider); ORT_RUNTIME_CLASS(ExecutionProviderFactory); ORT_RUNTIME_CLASS(Node); ORT_RUNTIME_CLASS(GraphViewer); +ORT_RUNTIME_CLASS(KernelRegistry); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -728,13 +729,17 @@ typedef struct OrtNodeComputeInfo { } OrtNodeComputeInfo; typedef struct OrtExecutionProvider { +#ifdef __cplusplus + OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr} {} +#endif void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info); + void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); const char* type; } OrtExecutionProvider; typedef struct OrtExecutionProviderFactory { - void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); + OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); } OrtExecutionProviderFactory; /** \brief Thread work loop function @@ -4734,6 +4739,8 @@ struct OrtApi { ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); + + ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op); }; // struct OrtApi /* diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index e9e4b0a656e12..6b3e0f8fe419e 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -8,7 +8,12 @@ namespace onnxruntime { class ExecutionProviderAdapter : public IExecutionProvider { public: - ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type), ep_impl_(ep) {} + ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type), ep_impl_(ep) { + if (ep_impl_->RegisterKernels) { + kernel_registry_ = std::make_shared(); + ep_impl_->RegisterKernels(reinterpret_cast(kernel_registry_.get())); + } + } virtual std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const override { size_t cnt = 0; OrtIndexedSubGraph** indexed_subgraph = nullptr; @@ -79,7 +84,10 @@ class ExecutionProviderAdapter : public IExecutionProvider { return Status::OK(); } + + virtual std::shared_ptr GetKernelRegistry() const override { return kernel_registry_; } private: OrtExecutionProvider* ep_impl_; + std::shared_ptr kernel_registry_; // TODO(leca): should be static local }; } diff --git a/onnxruntime/core/framework/provider_factory_adapter.h b/onnxruntime/core/framework/provider_factory_adapter.h index 
71b2aabf60a5a..9cfa68ecaa864 100644 --- a/onnxruntime/core/framework/provider_factory_adapter.h +++ b/onnxruntime/core/framework/provider_factory_adapter.h @@ -8,14 +8,27 @@ namespace onnxruntime { struct ExecutionProviderFactoryAdapter : IExecutionProviderFactory { ExecutionProviderFactoryAdapter(OrtExecutionProviderFactory* ep_factory, const char* const* provider_option_keys, const char* const* provider_option_values, size_t provider_option_length) - : ep_factory_(ep_factory), provider_option_keys_(provider_option_keys), provider_option_values_(provider_option_values), provider_option_length_(provider_option_length) {} + : ep_factory_(ep_factory), provider_option_length_(provider_option_length) { + provider_option_keys_.reserve(provider_option_length); + provider_option_values_.reserve(provider_option_length); + keys_.reserve(provider_option_length); + values_.reserve(provider_option_length); + for (size_t i = 0; i < provider_option_length; i++) { + provider_option_keys_.push_back(provider_option_keys[i]); + provider_option_values_.push_back(provider_option_values[i]); + keys_.push_back(provider_option_keys_[i].c_str()); + values_.push_back(provider_option_values_[i].c_str()); + } + } + std::unique_ptr CreateProvider() override { - void* ep = ep_factory_->CreateExecutionProvider(ep_factory_, provider_option_keys_, provider_option_values_, provider_option_length_); - return std::make_unique(reinterpret_cast(ep)); + return std::make_unique(ep_factory_->CreateExecutionProvider(ep_factory_, keys_.data(), values_.data(), provider_option_length_)); } OrtExecutionProviderFactory* ep_factory_; -const char* const* provider_option_keys_; -const char* const* provider_option_values_; +//const char* const* provider_option_keys_; +//const char* const* provider_option_values_; +std::vector provider_option_keys_, provider_option_values_; +std::vector keys_, values_; size_t provider_option_length_; }; } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 4c782f647371e..04b11848d6981 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -49,9 +49,9 @@ static constexpr uint32_t min_ort_version_with_shape_inference = 17; #endif #if !defined(DISABLE_FLOAT8_TYPES) -#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv9() +#define SUPPORTED_TENSOR_TYPES onnxruntime::DataTypeImpl::AllTensorTypesIRv9() #else -#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv4() +#define SUPPORTED_TENSOR_TYPES onnxruntime::DataTypeImpl::AllTensorTypesIRv4() #endif #if defined(ORT_MINIMAL_BUILD) @@ -1331,3 +1331,93 @@ common::Status CreateCustomRegistry(gsl::span op_domai } // namespace onnxruntime #endif // ENABLE_CUSTOM_OP_API + +//namespace onnxruntime { +class FuncManager; +class OpKernelInfo; +onnxruntime::KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op) { + const size_t input_count = op->GetInputTypeCount(op); + const size_t output_count = op->GetOutputTypeCount(op); + + onnxruntime::KernelDefBuilder def_builder; + def_builder.SetName(op->GetName(op)) + .SetDomain(domain); + + if (op->version >= min_ort_version_with_custom_version) { + if (op->GetStartVersion && op->GetEndVersion) { + def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); + } else if (op->GetStartVersion) { + def_builder.SinceVersion(op->GetStartVersion(op)); + } else { + def_builder.SinceVersion(1); + } + } else { + def_builder.SinceVersion(1); + } + + // GetInputMemoryType was 
introduced in ver 13. This check allows custom ops compiled using older versions + // to work with newer versions (> 12) of the ORT binary. + if (op->version > 12) { + for (size_t i = 0; i < input_count; i++) { + def_builder.InputMemoryType(op->GetInputMemoryType(op, i), gsl::narrow_cast(i)); + } + } + + for (size_t i = 0; i < input_count; i++) { + const auto input_type = op->GetInputType(op, i); + const auto input_name = "Input" + std::to_string(i); + if (input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + def_builder.TypeConstraint(input_name, SUPPORTED_TENSOR_TYPES); + } else { + def_builder.TypeConstraint(input_name, + onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(static_cast(input_type))->AsTensorType()); + } + } + + for (size_t i = 0; i < output_count; i++) { + const auto output_type = op->GetOutputType(op, i); + const auto output_name = "Output" + std::to_string(i); + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + def_builder.TypeConstraint(output_name, SUPPORTED_TENSOR_TYPES); + } else { + def_builder.TypeConstraint(output_name, + onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(static_cast(output_type))->AsTensorType()); + } + } + + if (const char* provider_type = op->GetExecutionProviderType(op)) { + def_builder.Provider(provider_type); + } else { + def_builder.Provider(onnxruntime::kCpuExecutionProvider); + } + + if (op->version >= 18 && op->GetMayInplace != nullptr) { + int* input_index = nullptr; + int* output_index = nullptr; + size_t len = op->GetMayInplace(&input_index, &output_index); + if (len > 0) { + for (size_t i = 0; i < len; i++) def_builder.MayInplace(input_index[i], output_index[i]); + op->ReleaseMayInplace(input_index, output_index); + } + } + + if (op->version >= 18 && op->GetAliasMap != nullptr) { + int* input_index = nullptr; + int* output_index = nullptr; + size_t len = op->GetAliasMap(&input_index, &output_index); + if (len > 0) { + for (size_t i = 0; i < len; i++) def_builder.Alias(input_index[i], output_index[i]); + op->ReleaseAliasMap(input_index, output_index); + } + } + + onnxruntime::KernelCreateFn kernel_create_fn = [op](onnxruntime::FuncManager&, const onnxruntime::OpKernelInfo& info, + std::unique_ptr& out) -> onnxruntime::common::Status { + out = std::make_unique(info, *op); + return onnxruntime::common::Status::OK(); + }; + + return onnxruntime::KernelCreateInfo(def_builder.Build(), kernel_create_fn); +// return onnxruntime::KernelCreateInfo(def_builder.Build(), nullptr); +} +//} // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 165edf6c48bfe..8dad6c1f8133e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -40,6 +40,7 @@ #include "core/platform/ort_mutex.h" #include "core/common/string_helper.h" #include "core/framework/provider_factory_adapter.h" +#include "core/framework/kernel_registry.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -114,6 +115,9 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); +// TODO(leca): try: namespace onnxruntime { KernelCreateInfo CreateKernelCreateInfo2(..); }, then define this function inside onnxruntime namespace +KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op); + ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* 
logid, _Outptr_ OrtEnv** out) { @@ -2431,6 +2435,12 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op) { + KernelRegistry* kr = reinterpret_cast(kernel_registry); + KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op); + return ToOrtStatus(kr->Register(std::move(kci))); +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2820,6 +2830,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtNode_GetIthInputName, &OrtApis::OrtNode_GetOutputSize, &OrtApis::OrtNode_GetIthOutputName, + &OrtApis::OrtKernelRegistry_RegisterKernel, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 496dd013bb00f..473ddc59b1c5d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -544,4 +544,6 @@ ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); + +ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op); } // namespace OrtApis diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index ad668350d2307..c4b8ef925670d 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -6,17 +6,28 @@ inline void THROW_ON_ERROR(OrtStatus* status) { if (status != nullptr) abort(); } +void TestCompileBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { + THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", env, "outTreeEp")); + std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", env, keys.data(), values.data(), keys.size())); +} + +void TestKernelBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { + THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp_kernel/build/libkernelEp.so", env, "kernelEp")); + std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "kernelEp", env, keys.data(), values.data(), keys.size())); +} + int main() { const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtEnv* p_env = nullptr; OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); - THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", p_env, "outTreeEp")); - OrtSessionOptions* so = nullptr; THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); - std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", p_env, keys.data(), values.data(), keys.size())); + + TestCompileBasedEp(g_ort, p_env, so); + //TestKernelBasedEp(g_ort, p_env, so); OrtSession* session = nullptr; 
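   // Creating the session below is what triggers the registered EP's GetCapability()/Compile() callbacks for Relu.onnx.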
THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index fc9d81f9bb005..8aff8072de68d 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -3,7 +3,7 @@ #include namespace onnxruntime { -OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { +OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExecutionProvider(), info(ep_info) { type = ep_type; OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); @@ -78,7 +78,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(e } OutTreeEpFactory::OutTreeEpFactory() { - OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> void* { + OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { OutTreeEpInfo info; for (size_t i = 0; i < option_size; i++) { if (!strcmp(ep_option_keys[i], "int_property")) info.int_property = std::atoi(ep_option_values[i]); From e29499a196fd115f45e556534bebff393a7d2e7a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 6 Aug 2024 15:41:47 +0000 Subject: [PATCH 10/81] kernel based EP work with type constraint check commented out --- onnxruntime/core/framework/kernel_registry.cc | 167 +++++++++--------- onnxruntime/core/framework/provider_adapter.h | 2 +- samples/c_test/test.cpp | 4 +- samples/outTreeEp_kernel/CMakeLists.txt | 12 ++ samples/outTreeEp_kernel/kernel_ep.cc | 89 ++++++++++ samples/outTreeEp_kernel/kernel_ep.h | 36 ++++ 6 files changed, 224 insertions(+), 86 deletions(-) create mode 100644 samples/outTreeEp_kernel/CMakeLists.txt create mode 100644 samples/outTreeEp_kernel/kernel_ep.cc create mode 100644 samples/outTreeEp_kernel/kernel_ep.h diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index d695e0e04c2b0..d42c3aefa87f3 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -14,91 +14,92 @@ namespace onnxruntime { namespace { -bool IsTypeProtoCompatible(gsl::span enabled_types, const ONNX_NAMESPACE::TypeProto& actual_type, - std::string& mismatch_reason) { - const bool is_type_compatible = std::any_of( - enabled_types.begin(), enabled_types.end(), - [&actual_type](const DataTypeImpl* expected_type) { - bool rc = expected_type->IsCompatible(actual_type); // for easier debugging - return rc; - }); - - if (!is_type_compatible) { - std::ostringstream ostr; - ostr << "This op has been implemented only for the following types ("; - for (const auto& enabled_type : enabled_types) { - ostr << DataTypeImpl::ToString(enabled_type) << ","; - } - ostr << "),"; - const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(actual_type)); - ostr << " but the node in the model has the following type (" << actual_type_str << ")"; - mismatch_reason = ostr.str(); - return false; - } - - return true; -} +//bool IsTypeProtoCompatible(gsl::span enabled_types, const 
ONNX_NAMESPACE::TypeProto& actual_type, +// std::string& mismatch_reason) { +// const bool is_type_compatible = std::any_of( +// enabled_types.begin(), enabled_types.end(), +// [&actual_type](const DataTypeImpl* expected_type) { +// bool rc = expected_type->IsCompatible(actual_type); // for easier debugging +// return rc; +// }); +// +// if (!is_type_compatible) { +// std::ostringstream ostr; +// ostr << "This op has been implemented only for the following types ("; +// for (const auto& enabled_type : enabled_types) { +// ostr << DataTypeImpl::ToString(enabled_type) << ","; +// } +// ostr << "),"; +// const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(actual_type)); +// ostr << " but the node in the model has the following type (" << actual_type_str << ")"; +// mismatch_reason = ostr.str(); +// return false; +// } +// +// return true; +//} // match the kernel using type info from the Node's args -bool MatchKernelDefTypes(const Node& node, - const std::unordered_map>& kernel_type_constraints, - const IKernelTypeStrResolver& kernel_type_str_resolver, - std::string& mismatch_reason) { - const auto actual_inputs = node.InputDefs(); - const auto actual_outputs = node.OutputDefs(); - const auto& actual_input_arg_counts = node.InputArgCount(); - const auto actual_input_arg_offsets = [&actual_input_arg_counts]() { - InlinedVector offsets{}; - offsets.reserve(actual_input_arg_counts.size()); - // std::exclusive_scan() is not supported until GCC 9.3 - // std::exclusive_scan(actual_input_arg_counts.begin(), actual_input_arg_counts.end(), - // std::back_inserter(offsets), 0); - int current_offset = 0; - for (size_t i = 0; i < actual_input_arg_counts.size(); ++i) { - offsets.push_back(current_offset); - current_offset += actual_input_arg_counts[i]; - } - return offsets; - }(); - - // for each type constraint - // map type constraint to arg - // check arg type against type constraint enabled types - for (const auto& [kernel_type_str, enabled_types] : kernel_type_constraints) { - gsl::span constraint_args{}; - ORT_THROW_IF_ERROR(kernel_type_str_resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args)); - - for (const auto& [arg_type, formal_arg_idx] : constraint_args) { - const NodeArg* arg; - if (arg_type == ArgType::kInput) { - if (formal_arg_idx >= actual_input_arg_counts.size() || - actual_input_arg_counts[formal_arg_idx] == 0) { - arg = nullptr; - } else { - const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx]; - ORT_ENFORCE(static_cast(first_arg_idx) < actual_inputs.size()); - arg = actual_inputs[first_arg_idx]; - } - } else { - arg = formal_arg_idx < actual_outputs.size() ? 
actual_outputs[formal_arg_idx] : nullptr; - } - - if (arg && arg->Exists()) { - const ONNX_NAMESPACE::TypeProto* type_proto = arg->TypeAsProto(); - ORT_ENFORCE(type_proto != nullptr); - - if (!IsTypeProtoCompatible(enabled_types, *type_proto, mismatch_reason)) { - return false; - } - - // found a match, don't need to check other args with this constraint - break; - } - } - } - - return true; -} +bool MatchKernelDefTypes(const Node&, const std::unordered_map>&, const IKernelTypeStrResolver&, std::string&) { return true; } +//bool MatchKernelDefTypes(const Node& node, +// const std::unordered_map>& kernel_type_constraints, +// const IKernelTypeStrResolver& kernel_type_str_resolver, +// std::string& mismatch_reason) { +// const auto actual_inputs = node.InputDefs(); +// const auto actual_outputs = node.OutputDefs(); +// const auto& actual_input_arg_counts = node.InputArgCount(); +// const auto actual_input_arg_offsets = [&actual_input_arg_counts]() { +// InlinedVector offsets{}; +// offsets.reserve(actual_input_arg_counts.size()); +// // std::exclusive_scan() is not supported until GCC 9.3 +// // std::exclusive_scan(actual_input_arg_counts.begin(), actual_input_arg_counts.end(), +// // std::back_inserter(offsets), 0); +// int current_offset = 0; +// for (size_t i = 0; i < actual_input_arg_counts.size(); ++i) { +// offsets.push_back(current_offset); +// current_offset += actual_input_arg_counts[i]; +// } +// return offsets; +// }(); +// +// // for each type constraint +// // map type constraint to arg +// // check arg type against type constraint enabled types +// for (const auto& [kernel_type_str, enabled_types] : kernel_type_constraints) { +// gsl::span constraint_args{}; +// ORT_THROW_IF_ERROR(kernel_type_str_resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args)); +// +// for (const auto& [arg_type, formal_arg_idx] : constraint_args) { +// const NodeArg* arg; +// if (arg_type == ArgType::kInput) { +// if (formal_arg_idx >= actual_input_arg_counts.size() || +// actual_input_arg_counts[formal_arg_idx] == 0) { +// arg = nullptr; +// } else { +// const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx]; +// ORT_ENFORCE(static_cast(first_arg_idx) < actual_inputs.size()); +// arg = actual_inputs[first_arg_idx]; +// } +// } else { +// arg = formal_arg_idx < actual_outputs.size() ? 
actual_outputs[formal_arg_idx] : nullptr; +// } +// +// if (arg && arg->Exists()) { +// const ONNX_NAMESPACE::TypeProto* type_proto = arg->TypeAsProto(); +// ORT_ENFORCE(type_proto != nullptr); +// +// if (!IsTypeProtoCompatible(enabled_types, *type_proto, mismatch_reason)) { +// return false; +// } +// +// // found a match, don't need to check other args with this constraint +// break; +// } +// } +// } +// +// return true; +//} bool MatchKernelDefTypes(const std::unordered_map>& kernel_type_constraints, const KernelRegistry::TypeConstraintMap& type_constraints) { diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 6b3e0f8fe419e..2b2ffe8101f89 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -17,7 +17,7 @@ class ExecutionProviderAdapter : public IExecutionProvider { virtual std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const override { size_t cnt = 0; OrtIndexedSubGraph** indexed_subgraph = nullptr; - ep_impl_->GetCapability(ep_impl_, reinterpret_cast(&graph_viewer), &cnt, &indexed_subgraph); + if (ep_impl_->GetCapability) ep_impl_->GetCapability(ep_impl_, reinterpret_cast(&graph_viewer), &cnt, &indexed_subgraph); if (cnt == 0) return IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index c4b8ef925670d..7a6824d7a8c04 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -26,8 +26,8 @@ int main() { OrtSessionOptions* so = nullptr; THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); - TestCompileBasedEp(g_ort, p_env, so); - //TestKernelBasedEp(g_ort, p_env, so); + //TestCompileBasedEp(g_ort, p_env, so); + TestKernelBasedEp(g_ort, p_env, so); OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); diff --git a/samples/outTreeEp_kernel/CMakeLists.txt b/samples/outTreeEp_kernel/CMakeLists.txt new file mode 100644 index 0000000000000..592719d630796 --- /dev/null +++ b/samples/outTreeEp_kernel/CMakeLists.txt @@ -0,0 +1,12 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug +# cmake --build ./ +cmake_minimum_required(VERSION 3.26) +project(kernelEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +add_library(kernelEp SHARED kernel_ep.cc) +target_include_directories(kernelEp PUBLIC "../../include/onnxruntime") + +# looks we need this in Win as in Windows you cannot build shared library with undefined symbol +#target_link_libraries(outTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so") diff --git a/samples/outTreeEp_kernel/kernel_ep.cc b/samples/outTreeEp_kernel/kernel_ep.cc new file mode 100644 index 0000000000000..f0c176a63ea6f --- /dev/null +++ b/samples/outTreeEp_kernel/kernel_ep.cc @@ -0,0 +1,89 @@ +#include "kernel_ep.h" +//#include "core/session/onnxruntime_lite_custom_op.h" +#include +#include +namespace onnxruntime { + +struct MyRelu : OrtCustomOp { + MyRelu() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::GetName = [](const struct OrtCustomOp* op) { return "Relu"; }; + OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return "KernelEp"; }; + OrtCustomOp::CreateKernelV2 = [](const struct OrtCustomOp* op, const OrtApi* api, const OrtKernelInfo* info, void** kernel) -> OrtStatusPtr { + return nullptr; + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, 
OrtKernelContext* context) -> OrtStatusPtr {
+      const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+      const OrtValue* input = nullptr;
+      api->KernelContext_GetInput(context, 0, &input);
+      std::vector<int64_t> dim(1,4);
+      OrtValue* output = nullptr;
+      api->KernelContext_GetOutput(context, 0, dim.data(), dim.size(), &output);
+
+      float* input_raw = nullptr, *output_raw = nullptr;
+      api->GetTensorMutableData(const_cast<OrtValue*>(input), reinterpret_cast<void**>(&input_raw));
+      api->GetTensorMutableData(output, reinterpret_cast<void**>(&output_raw));
+
+      for (int i = 0; i < 4; i++) {
+        output_raw[i] = input_raw[i];
+        if (input_raw[i] < 0) output_raw[i] = 0;
+
+        output_raw[i] = 2.0;
+      }
+      return nullptr;
+    };
+    OrtCustomOp::GetInputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; };
+    OrtCustomOp::GetOutputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; };
+    OrtCustomOp::GetInputMemoryType = [](const struct OrtCustomOp* op, size_t index) { return OrtMemType::OrtMemTypeDefault; };
+    OrtCustomOp::GetInputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+    OrtCustomOp::GetOutputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+    OrtCustomOp::GetStartVersion = [](const struct OrtCustomOp* op) { return 14; };
+  }
+};
+
+//void MyRelu(const Ort::Custom::Tensor<float>& X, Ort::Custom::Tensor<float>& Y) {
+//  const auto& shape = X.Shape();
+//  auto X_raw = X.Data();
+//  auto Y_raw = Y.Allocate(shape);
+//  auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
+//  for (int64_t i = 0; i < total; i++) {
+//    Y_raw[i] = X_raw[i] > 0 ? X_raw[i] : 0;
+//  }
+//  std::cout<<"In MyRelu()\n";
+//}
+
+KernelEp::KernelEp(const char* ep_type, const KernelEpInfo& ep_info) : info(ep_info) {
+  type = ep_type;
+  OrtExecutionProvider::RegisterKernels = [](OrtKernelRegistry* kernel_registry) {
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    //Ort::Custom::OrtLiteCustomOp* op = Ort::Custom::CreateLiteCustomOp("Relu", "kernel_ep", MyRelu);
+    OrtCustomOp* op = new MyRelu();
+    api->OrtKernelRegistry_RegisterKernel(kernel_registry, op);
+  };
+}
+
+KernelEpFactory::KernelEpFactory() {
+  OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* {
+    KernelEpInfo info;
+    for (size_t i = 0; i < option_size; i++) {
+      if (!strcmp(ep_option_keys[i], "int_property")) info.int_property = std::atoi(ep_option_values[i]);
+      else if (!strcmp(ep_option_keys[i], "str_property")) info.str_property = ep_option_values[i];
+      // TODO(leca): else throw
+    }
+    std::unique_ptr<KernelEp> ret = std::make_unique<KernelEp>("KernelEp", std::move(info));
+    return ret.release();
+  };
+}
+
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+OrtExecutionProviderFactory* RegisterCustomEp() {
+  std::unique_ptr<KernelEpFactory> ret = std::make_unique<KernelEpFactory>();
+  return ret.release();
+}
+#ifdef __cplusplus
+}
+#endif
diff --git a/samples/outTreeEp_kernel/kernel_ep.h b/samples/outTreeEp_kernel/kernel_ep.h
new file mode 100644
index 0000000000000..85c0bec6c302e
--- /dev/null
+++ b/samples/outTreeEp_kernel/kernel_ep.h
@@ -0,0 +1,36 @@
+#pragma once
+#include "core/session/onnxruntime_c_api.h"
+#include <string>
+
+#ifdef _WIN32
+#define EXPORT_API __declspec(dllexport)
+#else
+#define EXPORT_API
+#endif
+
+namespace onnxruntime {
+
+struct KernelEpInfo {
+  int int_property;
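+  // both properties are parsed from the key/value options handed to KernelEpFactory::CreateExecutionProvider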
+ std::string str_property; +}; + +struct KernelEp : public OrtExecutionProvider { + KernelEp(const char* ep_type, const KernelEpInfo& ep_info); + KernelEpInfo info; +}; + +struct KernelEpFactory : public OrtExecutionProviderFactory { + KernelEpFactory(); +}; +} + +#ifdef __cplusplus +extern "C" { +#endif + +EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp(); + +#ifdef __cplusplus +} +#endif From f3678c47588ea8b5204b8be49388c62c17e1c712 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 7 Aug 2024 00:34:17 +0000 Subject: [PATCH 11/81] add kernel type constraints from out tree EP --- .../core/framework/ort_type_constraints.h | 15 ++ .../core/session/onnxruntime_c_api.h | 7 +- onnxruntime/core/framework/kernel_registry.cc | 167 +++++++++--------- .../core/framework/ort_type_constraints.cc | 14 ++ onnxruntime/core/session/custom_ops.cc | 27 +-- onnxruntime/core/session/onnxruntime_c_api.cc | 19 +- onnxruntime/core/session/ort_apis.h | 6 +- samples/outTreeEp_kernel/kernel_ep.cc | 6 +- 8 files changed, 150 insertions(+), 111 deletions(-) create mode 100644 include/onnxruntime/core/framework/ort_type_constraints.h create mode 100644 onnxruntime/core/framework/ort_type_constraints.cc diff --git a/include/onnxruntime/core/framework/ort_type_constraints.h b/include/onnxruntime/core/framework/ort_type_constraints.h new file mode 100644 index 0000000000000..1224e56d58fb9 --- /dev/null +++ b/include/onnxruntime/core/framework/ort_type_constraints.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/session/onnxruntime_c_api.h" +#include +#include +#include + +struct OrtTypeConstraints { + bool AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type); + inline const std::unordered_map>& GetTypeConstraints() const { return type_constraints_; }; +private: + std::unordered_map> type_constraints_; +}; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 125aa54e00b81..723354a342209 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -309,6 +309,7 @@ ORT_RUNTIME_CLASS(ExecutionProviderFactory); ORT_RUNTIME_CLASS(Node); ORT_RUNTIME_CLASS(GraphViewer); ORT_RUNTIME_CLASS(KernelRegistry); +ORT_RUNTIME_CLASS(TypeConstraints); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -4740,7 +4741,11 @@ struct OrtApi { ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); - ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op); + ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); + + ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); + + ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); }; // struct OrtApi /* diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index d42c3aefa87f3..d695e0e04c2b0 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -14,92 +14,91 @@ namespace onnxruntime { namespace { -//bool IsTypeProtoCompatible(gsl::span enabled_types, const 
ONNX_NAMESPACE::TypeProto& actual_type,
-//                           std::string& mismatch_reason) {
-//  const bool is_type_compatible = std::any_of(
-//      enabled_types.begin(), enabled_types.end(),
-//      [&actual_type](const DataTypeImpl* expected_type) {
-//        bool rc = expected_type->IsCompatible(actual_type);  // for easier debugging
-//        return rc;
-//      });
-//
-//  if (!is_type_compatible) {
-//    std::ostringstream ostr;
-//    ostr << "This op has been implemented only for the following types (";
-//    for (const auto& enabled_type : enabled_types) {
-//      ostr << DataTypeImpl::ToString(enabled_type) << ",";
-//    }
-//    ostr << "),";
-//    const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(actual_type));
-//    ostr << " but the node in the model has the following type (" << actual_type_str << ")";
-//    mismatch_reason = ostr.str();
-//    return false;
-//  }
-//
-//  return true;
-//}
+bool IsTypeProtoCompatible(gsl::span<const MLDataType> enabled_types, const ONNX_NAMESPACE::TypeProto& actual_type,
+                           std::string& mismatch_reason) {
+  const bool is_type_compatible = std::any_of(
+      enabled_types.begin(), enabled_types.end(),
+      [&actual_type](const DataTypeImpl* expected_type) {
+        bool rc = expected_type->IsCompatible(actual_type);  // for easier debugging
+        return rc;
+      });
+
+  if (!is_type_compatible) {
+    std::ostringstream ostr;
+    ostr << "This op has been implemented only for the following types (";
+    for (const auto& enabled_type : enabled_types) {
+      ostr << DataTypeImpl::ToString(enabled_type) << ",";
+    }
+    ostr << "),";
+    const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(actual_type));
+    ostr << " but the node in the model has the following type (" << actual_type_str << ")";
+    mismatch_reason = ostr.str();
+    return false;
+  }
+
+  return true;
+}
 
 // match the kernel using type info from the Node's args
-bool MatchKernelDefTypes(const Node&, const std::unordered_map<std::string, std::vector<MLDataType>>&, const IKernelTypeStrResolver&, std::string&) { return true; }
-//bool MatchKernelDefTypes(const Node& node,
-//                         const std::unordered_map<std::string, std::vector<MLDataType>>& kernel_type_constraints,
-//                         const IKernelTypeStrResolver& kernel_type_str_resolver,
-//                         std::string& mismatch_reason) {
-//  const auto actual_inputs = node.InputDefs();
-//  const auto actual_outputs = node.OutputDefs();
-//  const auto& actual_input_arg_counts = node.InputArgCount();
-//  const auto actual_input_arg_offsets = [&actual_input_arg_counts]() {
-//    InlinedVector<int> offsets{};
-//    offsets.reserve(actual_input_arg_counts.size());
-//    // std::exclusive_scan() is not supported until GCC 9.3
-//    // std::exclusive_scan(actual_input_arg_counts.begin(), actual_input_arg_counts.end(),
-//    //                     std::back_inserter(offsets), 0);
-//    int current_offset = 0;
-//    for (size_t i = 0; i < actual_input_arg_counts.size(); ++i) {
-//      offsets.push_back(current_offset);
-//      current_offset += actual_input_arg_counts[i];
-//    }
-//    return offsets;
-//  }();
-//
-//  // for each type constraint
-//  //   map type constraint to arg
-//  //   check arg type against type constraint enabled types
-//  for (const auto& [kernel_type_str, enabled_types] : kernel_type_constraints) {
-//    gsl::span<const ArgTypeAndIndex> constraint_args{};
-//    ORT_THROW_IF_ERROR(kernel_type_str_resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args));
-//
-//    for (const auto& [arg_type, formal_arg_idx] : constraint_args) {
-//      const NodeArg* arg;
-//      if (arg_type == ArgType::kInput) {
-//        if (formal_arg_idx >= actual_input_arg_counts.size() ||
-//            actual_input_arg_counts[formal_arg_idx] == 0) {
-//          arg = nullptr;
-//        } else {
-//          const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx];
-//          ORT_ENFORCE(static_cast<size_t>(first_arg_idx) < actual_inputs.size());
-//          arg = actual_inputs[first_arg_idx];
-//        }
-//      } else {
-//        arg = formal_arg_idx < actual_outputs.size() ? actual_outputs[formal_arg_idx] : nullptr;
-//      }
-//
-//      if (arg && arg->Exists()) {
-//        const ONNX_NAMESPACE::TypeProto* type_proto = arg->TypeAsProto();
-//        ORT_ENFORCE(type_proto != nullptr);
-//
-//        if (!IsTypeProtoCompatible(enabled_types, *type_proto, mismatch_reason)) {
-//          return false;
-//        }
-//
-//        // found a match, don't need to check other args with this constraint
-//        break;
-//      }
-//    }
-//  }
-//
-//  return true;
-//}
+bool MatchKernelDefTypes(const Node& node,
+                         const std::unordered_map<std::string, std::vector<MLDataType>>& kernel_type_constraints,
+                         const IKernelTypeStrResolver& kernel_type_str_resolver,
+                         std::string& mismatch_reason) {
+  const auto actual_inputs = node.InputDefs();
+  const auto actual_outputs = node.OutputDefs();
+  const auto& actual_input_arg_counts = node.InputArgCount();
+  const auto actual_input_arg_offsets = [&actual_input_arg_counts]() {
+    InlinedVector<int> offsets{};
+    offsets.reserve(actual_input_arg_counts.size());
+    // std::exclusive_scan() is not supported until GCC 9.3
+    // std::exclusive_scan(actual_input_arg_counts.begin(), actual_input_arg_counts.end(),
+    //                     std::back_inserter(offsets), 0);
+    int current_offset = 0;
+    for (size_t i = 0; i < actual_input_arg_counts.size(); ++i) {
+      offsets.push_back(current_offset);
+      current_offset += actual_input_arg_counts[i];
+    }
+    return offsets;
+  }();
+
+  // for each type constraint
+  //   map type constraint to arg
+  //   check arg type against type constraint enabled types
+  for (const auto& [kernel_type_str, enabled_types] : kernel_type_constraints) {
+    gsl::span<const ArgTypeAndIndex> constraint_args{};
+    ORT_THROW_IF_ERROR(kernel_type_str_resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args));
+
+    for (const auto& [arg_type, formal_arg_idx] : constraint_args) {
+      const NodeArg* arg;
+      if (arg_type == ArgType::kInput) {
+        if (formal_arg_idx >= actual_input_arg_counts.size() ||
+            actual_input_arg_counts[formal_arg_idx] == 0) {
+          arg = nullptr;
+        } else {
+          const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx];
+          ORT_ENFORCE(static_cast<size_t>(first_arg_idx) < actual_inputs.size());
+          arg = actual_inputs[first_arg_idx];
+        }
+      } else {
+        arg = formal_arg_idx < actual_outputs.size() ? actual_outputs[formal_arg_idx] : nullptr;
+      }
+
+      if (arg && arg->Exists()) {
+        const ONNX_NAMESPACE::TypeProto* type_proto = arg->TypeAsProto();
+        ORT_ENFORCE(type_proto != nullptr);
+
+        if (!IsTypeProtoCompatible(enabled_types, *type_proto, mismatch_reason)) {
+          return false;
+        }
+
+        // found a match, don't need to check other args with this constraint
+        break;
+      }
+    }
+  }
+
+  return true;
+}
 
 bool MatchKernelDefTypes(const std::unordered_map<std::string, std::vector<MLDataType>>& kernel_type_constraints,
                          const KernelRegistry::TypeConstraintMap& type_constraints) {
diff --git a/onnxruntime/core/framework/ort_type_constraints.cc b/onnxruntime/core/framework/ort_type_constraints.cc
new file mode 100644
index 0000000000000..108bd3bef55f5
--- /dev/null
+++ b/onnxruntime/core/framework/ort_type_constraints.cc
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
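+//
+// OrtTypeConstraints (declared in core/framework/ort_type_constraints.h) maps a
+// type symbol such as "T" to the set of ONNX tensor element types a kernel
+// accepts. AddTypeConstraint returns true when the (symbol, type) pair is newly
+// inserted and false for a duplicate. A minimal illustrative sketch, not part of
+// this patch:
+//
+//   OrtTypeConstraints tc;
+//   tc.AddTypeConstraint("T", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);  // true: new entry
+//   tc.AddTypeConstraint("T", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);  // false: already present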
+
+#include "core/framework/ort_type_constraints.h"
+
+bool OrtTypeConstraints::AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type) {
+  std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>::iterator iter = type_constraints_.find(type_symbol);
+  if (iter == type_constraints_.end()) {
+    std::set<ONNXTensorElementDataType> types{type};
+    type_constraints_[type_symbol] = types;
+    return true;
+  }
+  return (iter->second).insert(type).second;
+}
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 04b11848d6981..6dd50634d11c3 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -25,6 +25,7 @@
 #include "core/session/inference_session.h"
 #include "core/session/ort_apis.h"
 #include "core/platform/threadpool.h"
+#include "core/framework/ort_type_constraints.h"
 
 // NOTE: OrtKernelContext is used by both custom ops and compiled kernels.
 // In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels,
@@ -1335,9 +1336,8 @@ common::Status CreateCustomRegistry(gsl::span op_domai
 //namespace onnxruntime {
 class FuncManager;
 class OpKernelInfo;
-onnxruntime::KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op) {
+onnxruntime::KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op, OrtTypeConstraints* type_constraints) {
   const size_t input_count = op->GetInputTypeCount(op);
-  const size_t output_count = op->GetOutputTypeCount(op);
 
   onnxruntime::KernelDefBuilder def_builder;
   def_builder.SetName(op->GetName(op))
@@ -1363,25 +1363,10 @@ onnxruntime::KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain,
     }
   }
 
-  for (size_t i = 0; i < input_count; i++) {
-    const auto input_type = op->GetInputType(op, i);
-    const auto input_name = "Input" + std::to_string(i);
-    if (input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
-      def_builder.TypeConstraint(input_name, SUPPORTED_TENSOR_TYPES);
-    } else {
-      def_builder.TypeConstraint(input_name,
-                                 onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(input_type))->AsTensorType());
-    }
-  }
-
-  for (size_t i = 0; i < output_count; i++) {
-    const auto output_type = op->GetOutputType(op, i);
-    const auto output_name = "Output" + std::to_string(i);
-    if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
-      def_builder.TypeConstraint(output_name, SUPPORTED_TENSOR_TYPES);
-    } else {
-      def_builder.TypeConstraint(output_name,
-                                 onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(output_type))->AsTensorType());
+  const std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>& tc = type_constraints->GetTypeConstraints();
+  for (const auto& [type_symbol, types] : tc) {
+    for (const auto& type : types) {
+      def_builder.TypeConstraint(type_symbol, onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->AsTensorType());
     }
   }
 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 8dad6c1f8133e..b0a71c6fadabb 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -41,6 +41,7 @@
 #include "core/common/string_helper.h"
 #include "core/framework/provider_factory_adapter.h"
 #include "core/framework/kernel_registry.h"
+#include "core/framework/ort_type_constraints.h"
 
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_factory.h"
@@ -116,7 +117,7 @@ using namespace onnxruntime;
   auto tensor = v->GetMutable<Tensor>();
 
 // TODO(leca): try: namespace onnxruntime { KernelCreateInfo CreateKernelCreateInfo2(..); }, then define this function inside onnxruntime namespace
-KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op);
+KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, const OrtCustomOp* op, OrtTypeConstraints* type_constraints);
 
 ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function,
                     _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid,
@@ -2435,12 +2436,22 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size
   return nullptr;
 }
 
-ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op) {
+ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints) {
   KernelRegistry* kr = reinterpret_cast<KernelRegistry*>(kernel_registry);
-  KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op);
+  KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op, type_constraints);
   return ToOrtStatus(kr->Register(std::move(kci)));
 }
 
+ORT_API_STATUS_IMPL(OrtApis::CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints) {
+  *type_constraints = new OrtTypeConstraints();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type) {
+  type_constraints->AddTypeConstraint(type_symbol, type);
+  return nullptr;
+}
+
 static constexpr OrtApiBase ort_api_base = {
     &OrtApis::GetApi,
     &OrtApis::GetVersionString};
@@ -2831,6 +2842,8 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtNode_GetOutputSize,
     &OrtApis::OrtNode_GetIthOutputName,
     &OrtApis::OrtKernelRegistry_RegisterKernel,
+    &OrtApis::CreateOrtTypeConstraints,
+    &OrtApis::AddTypeConstraint,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
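Taken together, these additions let an out-of-tree EP declare its kernels' type constraints at registration time. A minimal sketch of the intended call sequence, modeled on the samples/outTreeEp_kernel sample shown below (MyRelu is that sample's OrtCustomOp; error handling omitted):

    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    OrtTypeConstraints* type_constraints = nullptr;
    api->CreateOrtTypeConstraints(&type_constraints);
    // Constrain symbol "T" to float tensors; call AddTypeConstraint once per allowed type.
    api->AddTypeConstraint(type_constraints, "T", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
    OrtCustomOp* op = new MyRelu();  // the sample EP's custom op
    api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints);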
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 473ddc59b1c5d..c7cfd62b6a127 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -545,5 +545,9 @@ ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* ou ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); -ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op); +ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); + +ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); + +ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); } // namespace OrtApis diff --git a/samples/outTreeEp_kernel/kernel_ep.cc b/samples/outTreeEp_kernel/kernel_ep.cc index f0c176a63ea6f..a613e62f1c17d 100644 --- a/samples/outTreeEp_kernel/kernel_ep.cc +++ b/samples/outTreeEp_kernel/kernel_ep.cc @@ -57,8 +57,12 @@ KernelEp::KernelEp(const char* ep_type, const KernelEpInfo& ep_info) : info(ep_i OrtExecutionProvider::RegisterKernels = [](OrtKernelRegistry* kernel_registry) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); //Ort::Custom::OrtLiteCustomOp* op = Ort::Custom::CreateLiteCustomOp("Relu", "kernel_ep", MyRelu); + + OrtTypeConstraints* type_constraints = nullptr; + api->CreateOrtTypeConstraints(&type_constraints); + api->AddTypeConstraint(type_constraints, "T", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); OrtCustomOp* op = new MyRelu(); - api->OrtKernelRegistry_RegisterKernel(kernel_registry, op); + api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints); }; } From ac5ae0ae99906aba59f3701e50d5e4a1894c561a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 7 Aug 2024 16:27:40 +0000 Subject: [PATCH 12/81] add API ReleaseOrtTypeConstraints --- include/onnxruntime/core/session/onnxruntime_c_api.h | 2 ++ onnxruntime/core/session/custom_ops.cc | 1 - onnxruntime/core/session/onnxruntime_c_api.cc | 9 ++++++++- onnxruntime/core/session/ort_apis.h | 2 ++ samples/outTreeEp_kernel/kernel_ep.cc | 1 + 5 files changed, 13 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 723354a342209..d590f13da1243 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4746,6 +4746,8 @@ struct OrtApi { ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); + + ORT_API2_STATUS(ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints); }; // struct OrtApi /* diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 6dd50634d11c3..74cd822718b2f 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -1403,6 +1403,5 @@ onnxruntime::KernelCreateInfo CreateKernelCreateInfo2(const std::string& domain, }; return onnxruntime::KernelCreateInfo(def_builder.Build(), kernel_create_fn); -// return onnxruntime::KernelCreateInfo(def_builder.Build(), 
nullptr);
 }
 
 //} // namespace onnxruntime
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index b0a71c6fadabb..90af21f1337c1 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2443,7 +2443,8 @@ ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry
 }
 
 ORT_API_STATUS_IMPL(OrtApis::CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints) {
-  *type_constraints = new OrtTypeConstraints();
+  std::unique_ptr<OrtTypeConstraints> otc = std::make_unique<OrtTypeConstraints>();
+  *type_constraints = otc.release();
   return nullptr;
 }
 
@@ -2452,6 +2453,11 @@ ORT_API_STATUS_IMPL(OrtApis::AddTypeConstraint, _In_ OrtTypeConstraints* type_co
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints) {
+  delete type_constraints;
+  return nullptr;
+}
+
 static constexpr OrtApiBase ort_api_base = {
     &OrtApis::GetApi,
     &OrtApis::GetVersionString};
@@ -2844,6 +2850,7 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtKernelRegistry_RegisterKernel,
     &OrtApis::CreateOrtTypeConstraints,
     &OrtApis::AddTypeConstraint,
+    &OrtApis::ReleaseOrtTypeConstraints,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index c7cfd62b6a127..10d42b49c4124 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -550,4 +550,6 @@ ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_
 ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints);
 
 ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);
+
+ORT_API_STATUS_IMPL(ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints);
 }  // namespace OrtApis
diff --git a/samples/outTreeEp_kernel/kernel_ep.cc b/samples/outTreeEp_kernel/kernel_ep.cc
index a613e62f1c17d..f50536756aa6d 100644
--- a/samples/outTreeEp_kernel/kernel_ep.cc
+++ b/samples/outTreeEp_kernel/kernel_ep.cc
@@ -63,6 +63,7 @@ KernelEp::KernelEp(const char* ep_type, const KernelEpInfo& ep_info) : info(ep_i
     api->AddTypeConstraint(type_constraints, "T", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
     OrtCustomOp* op = new MyRelu();
     api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints);
+    api->ReleaseOrtTypeConstraints(type_constraints);
   };
 }
 
From 0cc78e8a95e8c7cbb9128233a2250b6fd7622fd5 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Mon, 12 Aug 2024 21:24:23 +0000
Subject: [PATCH 13/81] introduce qnn ep

---
 .../core/session/onnxruntime_c_api.h          |  14 +
 samples/qnnEp/CMakeLists.txt                  |  30 +
 samples/qnnEp/builder/qnn_def.cc              | 553 +++++++++++++++
 samples/qnnEp/builder/qnn_def.h               | 506 ++++++++++++++
 samples/qnnEp/builder/qnn_model_wrapper.cc    | 627 ++++++++++++++++++
 samples/qnnEp/builder/qnn_model_wrapper.h     | 285 ++++++++
 .../qnnEp/builder/qnn_quant_params_wrapper.cc | 266 ++++++++
 .../qnnEp/builder/qnn_quant_params_wrapper.h  | 147 ++++
 samples/qnnEp/builder/qnn_utils.cc            | 557 ++++++++++++++++
 samples/qnnEp/builder/qnn_utils.h             | 110 +++
 samples/qnnEp/qnn_execution_provider.cc       |  54 ++
 samples/qnnEp/qnn_execution_provider.h        |  33 +
 12 files changed, 3182 insertions(+)
 create mode 100644 samples/qnnEp/CMakeLists.txt
 create mode 100644 samples/qnnEp/builder/qnn_def.cc
 create
mode 100644 samples/qnnEp/builder/qnn_def.h create mode 100644 samples/qnnEp/builder/qnn_model_wrapper.cc create mode 100644 samples/qnnEp/builder/qnn_model_wrapper.h create mode 100644 samples/qnnEp/builder/qnn_quant_params_wrapper.cc create mode 100644 samples/qnnEp/builder/qnn_quant_params_wrapper.h create mode 100644 samples/qnnEp/builder/qnn_utils.cc create mode 100644 samples/qnnEp/builder/qnn_utils.h create mode 100644 samples/qnnEp/qnn_execution_provider.cc create mode 100644 samples/qnnEp/qnn_execution_provider.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d590f13da1243..8db7a5401f53d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -310,6 +310,7 @@ ORT_RUNTIME_CLASS(Node); ORT_RUNTIME_CLASS(GraphViewer); ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(TypeConstraints); +ORT_RUNTIME_CLASS(NodeUnit); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -743,6 +744,19 @@ typedef struct OrtExecutionProviderFactory { OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); } OrtExecutionProviderFactory; +typedef struct OrtNodeUnit { + enum Type { + SingleNode, + QDQGroup, + } type; + OrtNode** dq_nodes; + size_t dq_nodes_len; + OrtNode** q_nodes; + size_t q_nodes_len; + OrtNode* target_node; + size_t input_edge_count; +} OrtNodeUnit; + /** \brief Thread work loop function * * Onnxruntime will provide the working loop on custom thread creation diff --git a/samples/qnnEp/CMakeLists.txt b/samples/qnnEp/CMakeLists.txt new file mode 100644 index 0000000000000..3319516ed3f1c --- /dev/null +++ b/samples/qnnEp/CMakeLists.txt @@ -0,0 +1,30 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug +# cmake --build ./ +cmake_minimum_required(VERSION 3.26) +project(QnnEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +if (MSVC) +else() + set(QNN_ARCH_ABI aarch64-android) +endif() + +add_definitions(-DONNX_NAMESPACE=onnx) +add_definitions(-DONNX_ML) +file(GLOB_RECURSE qnn_src "./*.cc") +#message(STATUS "qnn_src=${qnn_src}") +add_library(QnnEp SHARED ${qnn_src}) +target_include_directories(QnnEp PUBLIC "../../include/onnxruntime" + "/home/leca/qnn-v2.25.0.240728/include/QNN" + "../../build/Linux/Debug/_deps/gsl-src/include" + "../../build/Linux/Debug/_deps/onnx-src" + "../../build/Linux/Debug/_deps/onnx-build" + "../../build/Linux/Debug/_deps/protobuf-src/src") + +# looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol +target_link_libraries(QnnEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotocd.a") diff --git a/samples/qnnEp/builder/qnn_def.cc b/samples/qnnEp/builder/qnn_def.cc new file mode 100644 index 0000000000000..06a52ca6fddc4 --- /dev/null +++ b/samples/qnnEp/builder/qnn_def.cc @@ -0,0 +1,553 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
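+//
+// qnn_def.cc provides version-agnostic setters/getters over the QNN structs
+// (Qnn_Tensor_t v1/v2) plus wrapper classes used while composing a QNN graph.
+// As an illustration only (not part of this patch), a graph-input tensor would
+// be described with the setters defined below roughly like this:
+//
+//   Qnn_Tensor_t t = QNN_TENSOR_INIT;
+//   std::vector<uint32_t> dims{1, 3, 224, 224};
+//   SetQnnTensorType(t, QNN_TENSOR_TYPE_APP_WRITE);   // graph input written by the app
+//   SetQnnTensorName(t, "input_0");
+//   SetQnnTensorDataType(t, QNN_DATATYPE_FLOAT_32);
+//   SetQnnTensorDim(t, dims);
+//   SetQnnTensorMemType(t, QNN_TENSORMEMTYPE_RAW);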
+ +#include "qnn_def.h" +#include "qnn_utils.h" +#include +#include +#include + +namespace onnxruntime { +namespace qnn { + +size_t memscpy(void* dst, size_t dst_size, const void* src, size_t copy_size) { + if (!dst || !src || !dst_size || !copy_size) return 0; + + size_t min_size = dst_size < copy_size ? dst_size : copy_size; + + memcpy(dst, src, min_size); + + return min_size; +} + +void SetQnnTensorType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorType_t tensor_type) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.type = tensor_type; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.type = tensor_type; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorName(Qnn_Tensor_t& qnn_tensor, const char* name) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.name = name; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.name = name; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorDataFormat(Qnn_Tensor_t& qnn_tensor, Qnn_TensorDataFormat_t data_format) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.dataFormat = data_format; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.dataFormat = data_format; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorDataType(Qnn_Tensor_t& qnn_tensor, Qnn_DataType_t data_type) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.dataType = data_type; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.dataType = data_type; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector& dimensions) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.rank = static_cast(dimensions.size()); + qnn_tensor.v1.dimensions = const_cast(dimensions.data()); + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.rank = static_cast(dimensions.size()); + qnn_tensor.v2.dimensions = const_cast(dimensions.data()); + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.memType = mem_type; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.memType = mem_type; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + auto size = client_buf.size() * sizeof(uint8_t); + qnn_tensor.v1.clientBuf.data = const_cast(static_cast(client_buf.data())); + qnn_tensor.v1.clientBuf.dataSize = static_cast(size); + return; + } + +#ifdef 
QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + auto size = client_buf.size() * sizeof(uint8_t); + qnn_tensor.v2.clientBuf.data = const_cast(static_cast(client_buf.data())); + qnn_tensor.v2.clientBuf.dataSize = static_cast(size); + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + auto size = client_buf.size() * sizeof(uint32_t); + qnn_tensor.v1.clientBuf.data = const_cast(static_cast(client_buf.data())); + qnn_tensor.v1.clientBuf.dataSize = static_cast(size); + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + auto size = client_buf.size() * sizeof(uint32_t); + qnn_tensor.v2.clientBuf.data = const_cast(static_cast(client_buf.data())); + qnn_tensor.v2.clientBuf.dataSize = static_cast(size); + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.clientBuf.data = buf_data; + qnn_tensor.v1.clientBuf.dataSize = buf_size; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.clientBuf.data = buf_data; + qnn_tensor.v2.clientBuf.dataSize = buf_size; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.clientBuf.dataSize = client_buf_size; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.clientBuf.dataSize = client_buf_size; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.clientBuf.data = client_buf_data; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.clientBuf.data = client_buf_data; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.quantizeParams = quantize_params; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.quantizeParams = quantize_params; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +uint32_t GetQnnTensorID(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.id; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.id; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +Qnn_TensorType_t 
GetQnnTensorType(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.type; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.type; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +const char* GetQnnTensorName(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.name; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.name; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +Qnn_TensorDataFormat_t GetQnnTensorDataFormat(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.dataFormat; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.dataFormat; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +Qnn_DataType_t GetQnnTensorDataType(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.dataType; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.dataType; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +Qnn_TensorMemType_t GetQnnTensorMemType(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.memType; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.memType; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +uint32_t GetQnnTensorRank(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.rank; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.rank; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.dimensions; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.dimensions; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.clientBuf; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.clientBuf; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.quantizeParams; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.quantizeParams; + } +#endif // QNN_TENSOR_V2_INIT + + 
ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + +Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1, + float& scale_diff, int32_t& offset_diff) { + scale_diff = 0.0f; + offset_diff = 0; + + ORT_RETURN_IF_NOT((qparam0.encodingDefinition == qparam1.encodingDefinition && + qparam0.quantizationEncoding == qparam1.quantizationEncoding), + "Expected quantization parameters to be the same type."); + + if (qparam0.encodingDefinition == QNN_DEFINITION_DEFINED) { + switch (qparam0.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: { + scale_diff = std::abs(qparam0.scaleOffsetEncoding.scale - qparam1.scaleOffsetEncoding.scale); + offset_diff = std::abs(qparam0.scaleOffsetEncoding.offset - qparam1.scaleOffsetEncoding.offset); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported quantization encoding: ", qparam0.quantizationEncoding); + } + } + + return Status::OK(); +} + +bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + const std::string& node_name, + const std::string& tensor_name, + Qnn_Tensor_t& qnn_tensor, + std::unordered_map& tensors_created_table, + std::string& error_msg) { + if (tensors_created_table.find(tensor_name) != tensors_created_table.end()) { + error_msg = "Tensor created already: " + tensor_name; + return true; + } + + auto qnn_data_type = GetQnnTensorDataType(qnn_tensor); + size_t data_size = utils::GetElementSizeByType(qnn_data_type); + + std::stringstream ss; + if (0 == data_size) { + ss << "Invalid QNN data type provided, " + << qnn_data_type << ", for tensor " << tensor_name + << " on node " << node_name; + error_msg = ss.str(); + return false; + } + + // sanity check tensor data if AddTensor used for static tensor + auto qnn_tensor_type = GetQnnTensorType(qnn_tensor); + if (qnn_tensor_type == QNN_TENSOR_TYPE_STATIC) { + if (GetQnnTensorMemType(qnn_tensor) != QNN_TENSORMEMTYPE_RAW) { + ss << "Expected raw memType in provided static tensor " + << tensor_name << "for node " << node_name; + error_msg = ss.str(); + return false; + } + // verify size expressed by the dims matches the raw tensor size + auto qnn_tensor_dims = GetQnnTensorDims(qnn_tensor); + auto qnn_tensor_rank = GetQnnTensorRank(qnn_tensor); + uint32_t qnn_tensor_size = std::accumulate(qnn_tensor_dims, + qnn_tensor_dims + qnn_tensor_rank, + static_cast(data_size), + std::multiplies()); + auto qnn_tensor_buf_size = GetQnnTensorClientBuf(qnn_tensor).dataSize; + if (qnn_tensor_size != qnn_tensor_buf_size) { + ss << "Data length mismatch for static tensor. node_name: " << node_name + << " tensor_name: " << tensor_name + << ". 
size calculated from shape: " << qnn_tensor_size + << ", tensor.clientBuf.dataSize: " << qnn_tensor_buf_size; + error_msg = ss.str(); + return false; + } + } + + auto tensor_create_result = qnn_interface.tensorCreateGraphTensor(graph, &qnn_tensor); + if (tensor_create_result != QNN_TENSOR_NO_ERROR) { + ss << "Failed to create tensor for node: " << node_name + << " tensor_name: " << tensor_name + << " error code: " << tensor_create_result; + error_msg = ss.str(); + return false; + } + + tensors_created_table.emplace(tensor_name, true); + return true; +} + +bool QnnParamWrapper::CreateQnnGraphParam(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + const std::string& node_name, + std::unordered_map& tensors_created_table, + std::string& error_msg) { + std::stringstream ss; + switch (qnn_param_.paramType) { + case QNN_PARAMTYPE_TENSOR: { + return CreateTensorInQnnGraph(qnn_interface, graph, node_name, tensor_name_, + qnn_param_.tensorParam, tensors_created_table, error_msg); + } + case QNN_PARAMTYPE_SCALAR: { + ss << "Add scalar parameter: " << name_; + error_msg = ss.str(); + return true; + } + default: { + ss << "Unknown param type passed for param: " + << name_ << " on node: " << node_name; + error_msg = ss.str(); + return true; + } + } + + return true; +} + +void QnnOpConfigWrapper::SetNames(const char* op_name, + const char* package_name, + const char* type_name) { + if (QNN_OPCONFIG_VERSION_1 == op_config_.version) { + op_config_.v1.name = op_name; + op_config_.v1.packageName = package_name; + op_config_.v1.typeName = type_name; + } else { + ORT_THROW("QNN OpConfig version not supported, QNN OpConfig version: ", op_config_.version); + } +} + +void QnnOpConfigWrapper::SetNums(uint32_t num_inputs, + uint32_t num_outputs, + uint32_t num_params) { + if (QNN_OPCONFIG_VERSION_1 == op_config_.version) { + op_config_.v1.numOfInputs = num_inputs; + op_config_.v1.numOfOutputs = num_outputs; + op_config_.v1.numOfParams = num_params; + } else { + ORT_THROW("QNN OpConfig version not supported, QNN OpConfig version: ", op_config_.version); + } +} + +void QnnOpConfigWrapper::SetData(Qnn_Tensor_t* input_tensors, + Qnn_Tensor_t* output_tensors, + Qnn_Param_t* params) { + if (QNN_OPCONFIG_VERSION_1 == op_config_.version) { + op_config_.v1.inputTensors = input_tensors; + op_config_.v1.outputTensors = output_tensors; + op_config_.v1.params = params; + } else { + ORT_THROW("QNN OpConfig version not supported, QNN OpConfig version: ", op_config_.version); + } +} + +bool QnnOpConfigWrapper::QnnGraphOpValidation(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_BackendHandle_t& backend_handle, + std::string& error_msg) { + auto validation_status = qnn_interface.backendValidateOpConfig(backend_handle, op_config_); + if (QNN_SUCCESS != validation_status) { + std::ostringstream oss; + oss << "QNN.backendValidateOpConfig() failed for node `" << name_ << "` of type `" + << type_name_ << "` with error code " << validation_status << std::endl; + error_msg = oss.str(); + return false; + } + + return true; +} + +bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + std::string& error_msg) { + auto status = qnn_interface.graphAddNode(graph, op_config_); + if (QNN_GRAPH_NO_ERROR != status) { + std::ostringstream oss; + oss << "QNN.graphAddNode() failed for node `" << name_ << "` of type `" << type_name_ + << "` with error code " << status << std::endl; + error_msg = oss.str(); + return false; + } + + return true; +} 
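+
+// Typical call sequence for QnnOpConfigWrapper (an illustrative sketch, not part
+// of this patch; the interface/backend/graph handles come from the EP's QNN
+// setup elsewhere, and "qti.aisw" is the stock QNN op package name):
+//
+//   QnnOpConfigWrapper op("relu_0", "qti.aisw", "Relu",
+//                         std::move(inputs), std::move(outputs), std::move(params));
+//   std::string err;
+//   if (!op.QnnGraphOpValidation(qnn_interface, backend_handle, err)) {
+//     // node rejected during partitioning
+//   } else if (!op.CreateQnnGraphOp(qnn_interface, graph, err)) {
+//     // failed to add the node while compiling the graph
+//   }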
+ +bool IsNpuBackend(QnnBackendType backend_type) { + return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/samples/qnnEp/builder/qnn_def.h b/samples/qnnEp/builder/qnn_def.h new file mode 100644 index 0000000000000..f4ac96db9657b --- /dev/null +++ b/samples/qnnEp/builder/qnn_def.h @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "QnnInterface.h" +#include +#include +#include +#include +#include +#include +#include "core/graph/basic_types.h" +#include "core/common/common.h" +#include "qnn_quant_params_wrapper.h" + +namespace onnxruntime { +namespace qnn { +// QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose +// except the following flags for dlopen, others should be done only +// when we really need them +// DL_NOW is MUST +// DL_LOCAL is enabled if not specified +enum class DlOpenFlag : int { + DL_NOW = 0x0001, + DL_LOCAL = 0x0002, + DL_GLOBAL = 0x0004, +}; + +// specify this address to distinguish from NULL pointer +#define DL_DEFAULT (void*)(0x4) + +enum class ProfilingLevel : uint8_t { + OFF = 0, + BASIC, + DETAILED, + INVALID +}; + +// Defines performance modes available for HTP backend. +enum class HtpPerformanceMode : uint8_t { + kHtpDefault = 0, + kHtpSustainedHighPerformance, + kHtpBurst, + kHtpHighPerformance, + kHtpPowerSaver, + kHtpLowPowerSaver, + kHtpHighPowerSaver, + kHtpLowBalanced, + kHtpBalanced, + kHtpExtremePowerSaver, +}; + +enum class ContextPriority : uint8_t { + LOW = 0, + NORMAL, + NORMAL_HIGH, + HIGH, + UNDEFINED +}; + +// Defines the graph optimization strategy used by the HTP backend. +enum class HtpGraphFinalizationOptimizationMode : uint8_t { + kDefault = 0, + kMode1 = 1, // Faster preparation time, less optimal graph + kMode2 = 2, // Longer preparation time, more optimal graph + kMode3 = 3, // Longest preparation time, most likely even more optimal graph. 
+}; + +enum class QnnBackendType : uint8_t { + CPU = 0, + GPU, + DSP, + HTP, + HTP_FP16 +}; + +bool IsNpuBackend(QnnBackendType backend_type); + +// constexpr config values +constexpr const int kSleepMinLatency = 40; +constexpr const int kSleepLowLatency = 100; +constexpr const int kSleepMediumLatency = 1000; +constexpr const int kSleepHighLatency = 2000; +constexpr const int kDcvsDisable = 0; +constexpr const int kDcvsEnable = 1; + +struct OnnxTensorInfo { + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OnnxTensorInfo); + OnnxTensorInfo(size_t index, int32_t data_type, std::vector&& shape) : index_(index), data_type_(data_type), shape_(std::move(shape)) {} + size_t index_; + const int32_t data_type_; // Uses TensorProto::DataType + const std::vector shape_; +}; + +size_t memscpy(void* dst, size_t dst_size, const void* src, size_t copy_size); + +void SetQnnTensorType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorType_t tensor_type); +void SetQnnTensorName(Qnn_Tensor_t& qnn_tensor, const char* name); +void SetQnnTensorDataFormat(Qnn_Tensor_t& qnn_tensor, Qnn_TensorDataFormat_t data_format); +void SetQnnTensorDataType(Qnn_Tensor_t& qnn_tensor, Qnn_DataType_t data_type); +void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector& dimensions); +void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type); +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size); +void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size); +void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data); +void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params); +bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + const std::string& node_name, + const std::string& tensor_name, + Qnn_Tensor_t& qnn_tensor, + std::unordered_map& tensors_created_table, + std::string& error_msg); + +uint32_t GetQnnTensorID(const Qnn_Tensor_t& qnn_tensor); +Qnn_TensorType_t GetQnnTensorType(const Qnn_Tensor_t& qnn_tensor); +const char* GetQnnTensorName(const Qnn_Tensor_t& qnn_tensor); +Qnn_TensorDataFormat_t GetQnnTensorDataFormat(const Qnn_Tensor_t& qnn_tensor); +Qnn_DataType_t GetQnnTensorDataType(const Qnn_Tensor_t& qnn_tensor); +Qnn_TensorMemType_t GetQnnTensorMemType(const Qnn_Tensor_t& qnn_tensor); +uint32_t GetQnnTensorRank(const Qnn_Tensor_t& qnn_tensor); +uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor); +const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor); +const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor); + +/** + * Compares two sets of quantization parameters. Sets the parameters `scale_diff` and `offset_diff` + * to the absolute differences. Returns an error status if the quantization parameters are not + * of the same type, or if the type is not supported. + * + * \param qparam0 The first set of quantization parameters. + * \param qparam1 The second set of quantization parameters. + * \param scale_diff Set to the absolute value of the difference in scale value. + * \param offset_diff Set to the absolute value of the difference in offset value. + * \return Status indicating success. 
+ */ +Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1, + float& max_scale_diff, int32_t& max_offset_diff); + +// TODO: split out separate files for Wrappers +class QnnTensorWrapper { + public: + QnnTensorWrapper(const std::string& name, + Qnn_TensorType_t tensor_type, + Qnn_DataType_t data_type, + QnnQuantParamsWrapper&& quantize_params, + std::vector&& shape, + std::vector&& client_buf = {}, + Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW) : tensor_name_(name), + dimensions_(std::move(shape)), + client_buf_(std::move(client_buf)), + quant_params_(quantize_params) { + SetQnnTensorType(qnn_tensor_, tensor_type); + SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); + SetQnnTensorDataType(qnn_tensor_, data_type); + SetQnnTensorDim(qnn_tensor_, dimensions_); + SetQnnTensorMemType(qnn_tensor_, mem_type); + if (QNN_TENSOR_TYPE_STATIC == tensor_type) { + SetQnnTensorClientBuf(qnn_tensor_, client_buf_); + } + + if (mem_type != QNN_TENSORMEMTYPE_RAW) { + ORT_THROW("mem_type not supported for now."); + } + + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); + } + + // Initialize from a raw Qnn_Tensor_t. This method is currently used for graph inputs/outputs + // when deserializing from cached context object. Possible return errors due to: + // - Unexpected Qnn_TensorType_t: only handle graph inputs/outputs, not static initializers with data buffers. + // - Unexpected quantization encoding. + Status Init(const Qnn_Tensor_t& qnn_tensor) { + Qnn_TensorType_t tensor_type = GetQnnTensorType(qnn_tensor); + ORT_RETURN_IF(tensor_type == QNN_TENSOR_TYPE_STATIC, + "QnnTensorWrapper::Init(const Qnn_Tensor_t&) does not support static initializers"); + + tensor_name_ = GetQnnTensorName(qnn_tensor); + client_buf_.clear(); + + qnn_tensor_ = qnn_tensor; + SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); + + const Qnn_QuantizeParams_t& src_quantize_param = GetQnnTensorQParams(qnn_tensor); + ORT_RETURN_IF_ERROR(quant_params_.Init(src_quantize_param)); + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); + + uint32_t shape_rank = GetQnnTensorRank(qnn_tensor); + uint32_t* shape_data = GetQnnTensorDims(qnn_tensor); + dimensions_.assign(shape_data, shape_data + shape_rank); + SetQnnTensorDim(qnn_tensor_, dimensions_); + + SetQnnTensorMemType(qnn_tensor_, QNN_TENSORMEMTYPE_RAW); + + return Status::OK(); + } + + QnnTensorWrapper() = default; + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnTensorWrapper); + + QnnTensorWrapper(QnnTensorWrapper&& other) noexcept { + SwapOther(std::move(other)); + } + + QnnTensorWrapper& operator=(QnnTensorWrapper&& other) noexcept { + if (this != &other) { + SwapOther(std::move(other)); + } + + return *this; + } + + ~QnnTensorWrapper() = default; + + const Qnn_Tensor_t& GetQnnTensor() const { + return qnn_tensor_; + } + + Qnn_Tensor_t& GetQnnTensor() { + return qnn_tensor_; + } + + const QnnQuantParamsWrapper& GetQnnQuantParams() const { + return quant_params_; + } + + QnnQuantParamsWrapper& GetQnnQuantParams() { + return quant_params_; + } + + const std::string& GetName() const { return tensor_name_; } + + Qnn_TensorType_t GetTensorType() const { return GetQnnTensorType(qnn_tensor_); } + Qnn_DataType_t GetTensorDataType() const { return GetQnnTensorDataType(qnn_tensor_); } + uint32_t GetTensorRank() const { return static_cast(dimensions_.size()); } + const std::vector& GetTensorDims() const { return dimensions_; } + + bool CreateQnnGraphTensor(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + 
const std::string& node_name, + std::unordered_map& tensors_created_table, + std::string& error_msg) { + return CreateTensorInQnnGraph(qnn_interface, graph, node_name, tensor_name_, + qnn_tensor_, tensors_created_table, error_msg); + } + + private: + void SwapOther(QnnTensorWrapper&& other) noexcept { + std::swap(tensor_name_, other.tensor_name_); + std::swap(dimensions_, other.dimensions_); + std::swap(client_buf_, other.client_buf_); + std::swap(quant_params_, other.quant_params_); + std::swap(qnn_tensor_, other.qnn_tensor_); + SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); + SetQnnTensorDim(qnn_tensor_, dimensions_); + SetQnnTensorClientBuf(qnn_tensor_, client_buf_); + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); + } + + std::string tensor_name_; + std::vector dimensions_; + std::vector client_buf_; + Qnn_Tensor_t qnn_tensor_ = QNN_TENSOR_INIT; + QnnQuantParamsWrapper quant_params_; +}; + +class QnnParamWrapper { + public: + QnnParamWrapper(NodeIndex node_index, + const std::string& node_name, + const std::string& name, + Qnn_Scalar_t scalarParam) : name_(name), shape_({}), param_data_({}) { + qnn_param_.paramType = QNN_PARAMTYPE_SCALAR; + qnn_param_.name = name_.c_str(); + std::stringstream ss; + ss << node_name << "_" << node_index << "_" << name; + tensor_name_ = ss.str(); + qnn_param_.scalarParam = scalarParam; + } + + QnnParamWrapper(NodeIndex node_index, + const std::string& node_name, + const std::string& name, + std::vector&& shape, + std::vector&& param_data, + bool is_signed = false) : name_(name), shape_(std::move(shape)), param_data_(std::move(param_data)) { + qnn_param_.paramType = QNN_PARAMTYPE_TENSOR; + qnn_param_.name = name_.c_str(); + std::stringstream ss; + ss << node_name << "_" << node_index << "_" << name; + tensor_name_ = ss.str(); + qnn_param_.tensorParam = QNN_TENSOR_INIT; + SetQnnTensorType(qnn_param_.tensorParam, QNN_TENSOR_TYPE_STATIC); + SetQnnTensorName(qnn_param_.tensorParam, tensor_name_.c_str()); + SetQnnTensorDataType(qnn_param_.tensorParam, is_signed ? 
QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32); + SetQnnTensorDim(qnn_param_.tensorParam, shape_); + SetQnnTensorMemType(qnn_param_.tensorParam, QNN_TENSORMEMTYPE_RAW); + SetQnnTensorClientBuf(qnn_param_.tensorParam, param_data_); + } + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnParamWrapper); + QnnParamWrapper(QnnParamWrapper&& other) noexcept { + std::swap(name_, other.name_); + std::swap(tensor_name_, other.tensor_name_); + std::swap(shape_, other.shape_); + std::swap(param_data_, other.param_data_); + std::swap(qnn_param_, other.qnn_param_); + qnn_param_.name = name_.c_str(); + if (qnn_param_.paramType == QNN_PARAMTYPE_TENSOR) { + SetQnnTensorName(qnn_param_.tensorParam, tensor_name_.c_str()); + SetQnnTensorDim(qnn_param_.tensorParam, shape_); + SetQnnTensorClientBuf(qnn_param_.tensorParam, param_data_); + } + } + + ~QnnParamWrapper() = default; + + const std::string& GetName() const { + return name_; + } + + const std::string& GetParamTensorName() const { + return tensor_name_; + } + + const Qnn_Param_t& GetQnnParam() const { + return qnn_param_; + } + + Qnn_Param_t& GetQnnParam() { + return qnn_param_; + } + + bool CreateQnnGraphParam(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + const std::string& node_name, + std::unordered_map& tensors_created_table, + std::string& error_msg); + + private: + std::string name_; + std::string tensor_name_; + std::vector shape_; + std::vector param_data_; + Qnn_Param_t qnn_param_ = QNN_PARAM_INIT; +}; + +class QnnOpConfigWrapper { + public: + QnnOpConfigWrapper(const std::string& name, + const std::string& package_name, + const std::string& type_name, + std::vector&& inputs, + std::vector&& outputs, + std::vector&& params) : name_(name), + package_name_(package_name), + type_name_(type_name), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + params_(std::move(params)) { + SetNames(name_.c_str(), package_name_.c_str(), type_name_.c_str()); + SetNums(static_cast(inputs_.size()), + static_cast(outputs_.size()), + static_cast(params_.size())); + SetData(inputs_.data(), outputs_.data(), params_.data()); + } + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnOpConfigWrapper); + + QnnOpConfigWrapper(QnnOpConfigWrapper&& other) noexcept { + std::swap(this->op_config_, other.op_config_); + std::swap(name_, other.name_); + std::swap(package_name_, other.package_name_); + std::swap(type_name_, other.type_name_); + std::swap(inputs_, other.inputs_); + std::swap(outputs_, other.outputs_); + std::swap(params_, other.params_); + SetNames(name_.c_str(), package_name_.c_str(), type_name_.c_str()); + SetData(inputs_.data(), outputs_.data(), params_.data()); + } + + ~QnnOpConfigWrapper() = default; + + const Qnn_OpConfig_t& GetQnnOpConfig() { return op_config_; } + + void SetNames(const char* op_name, + const char* package_name, + const char* type_name); + void SetNums(uint32_t num_inputs, + uint32_t num_outputs, + uint32_t num_params); + void SetData(Qnn_Tensor_t* input_tensors, + Qnn_Tensor_t* output_tensors, + Qnn_Param_t* params); + + const std::string& GetOpName() const { return name_; } + const std::string& GetPackageName() const { return package_name_; } + const std::string& GetTypeName() const { return type_name_; } + uint32_t GetInputsNum() const { return static_cast(inputs_.size()); } + uint32_t GetOutputsNum() const { return static_cast(outputs_.size()); } + uint32_t GetParamsNum() const { return static_cast(params_.size()); } + const Qnn_Tensor_t* GetInputTensors() const { return inputs_.data(); } + const Qnn_Tensor_t* 
GetOutputTensors() const { return outputs_.data(); } + const Qnn_Param_t* GetParams() const { return params_.data(); } + + bool QnnGraphOpValidation(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_BackendHandle_t& backend_handle, + std::string& error_msg); + + bool CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_GraphHandle_t& graph, + std::string& error_msg); + + private: + std::string name_; + std::string package_name_; + std::string type_name_; + std::vector inputs_; + std::vector outputs_; + std::vector params_; + Qnn_OpConfig_t op_config_ = QNN_OPCONFIG_INIT; +}; + +class QnnOpProperty { + public: + QnnOpProperty(const std::string& node_name, + const std::string& package_name, + const std::string& node_type, + std::vector&& input_names, + std::vector&& outputs_names, + std::vector&& param_tensor_names) : qnn_node_name_(node_name), + package_name_(package_name), + qnn_node_type_(node_type), + input_names_(std::move(input_names)), + output_names_(std::move(outputs_names)), + param_tensor_names_(std::move(param_tensor_names)) {} + + const std::string& GetNodeName() const { return qnn_node_name_; } + const std::string& GetPackageName() const { return package_name_; } + const std::string& GetNodeType() const { return qnn_node_type_; } + const std::vector& GetInputNames() const { return input_names_; } + const std::vector& GetOutputNames() const { return output_names_; } + const std::vector& GetParamTensorNames() const { return param_tensor_names_; } + + QnnOpProperty(QnnOpProperty&& other) noexcept { + std::swap(qnn_node_name_, other.qnn_node_name_); + std::swap(package_name_, other.package_name_); + std::swap(qnn_node_type_, other.qnn_node_type_); + std::swap(input_names_, other.input_names_); + std::swap(output_names_, other.output_names_); + std::swap(param_tensor_names_, other.param_tensor_names_); + } + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnOpProperty); + + private: + std::string qnn_node_name_; + std::string package_name_; + std::string qnn_node_type_; + std::vector input_names_; + std::vector output_names_; + std::vector param_tensor_names_; +}; + +class GraphInfo { + public: + GraphInfo(const Qnn_GraphHandle_t graph, + const std::string& name, + std::vector&& input_tensors, + std::vector&& output_tensors) : graph_name_(name), + graph_(graph), + input_tensors_(std::move(input_tensors)), + output_tensors_(std::move(output_tensors)) { + } + + size_t NumInputTensors() const { return input_tensors_.size(); } + size_t NumOutputTensors() const { return output_tensors_.size(); } + const std::string& Name() const { return graph_name_; } + const std::vector& InputTensors() const { return input_tensors_; } + const std::vector& OutputTensors() const { return output_tensors_; } + const Qnn_GraphHandle_t& Graph() const { return graph_; } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphInfo); + + private: + std::string graph_name_; + Qnn_GraphHandle_t graph_; + std::vector input_tensors_; + std::vector output_tensors_; +}; + +typedef GraphInfo* GraphInfoPtr_t; + +typedef struct GraphConfigInfo { + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphConfigInfo); + const char* graphName; + const QnnGraph_Config_t** graphConfigs; +} GraphConfigInfo_t; + +static const std::vector nchw2hwcn_perm{2, 3, 1, 0}; +static const std::vector nchw2hwcn_perm_3d{2, 3, 4, 1, 0}; +static const std::vector cnhw2hwcn_perm{2, 3, 0, 1}; +static const std::vector cnhw2hwcn_perm_3d{2, 3, 4, 0, 1}; + +} // namespace qnn +} // namespace onnxruntime diff --git 
a/samples/qnnEp/builder/qnn_model_wrapper.cc b/samples/qnnEp/builder/qnn_model_wrapper.cc
new file mode 100644
index 0000000000000..a041b1d599e93
--- /dev/null
+++ b/samples/qnnEp/builder/qnn_model_wrapper.cc
@@ -0,0 +1,627 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "qnn_model_wrapper.h"
+//#include "core/common/safeint.h"
+//#include "core/framework/tensorprotoutils.h"
+//#include "core/providers/shared/utils/utils.h"
+#include "qnn_utils.h"
+#include "core/framework/int4.h"
+#include "onnx/onnx_pb.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+bool QnnModelWrapper::CreateQnnGraph(const Qnn_ContextHandle_t& context,
+                                     const std::string& graph_name,
+                                     const QnnGraph_Config_t** graph_configs) {
+  if (!graph_name_.empty()) {
+    // only one graph is allowed per QnnModel
+//    LOGS(logger_, ERROR) << "Graph " << graph_name << " already initialized.";
+    return false;
+  }
+  if (context == nullptr) {
+//    LOGS(logger_, ERROR) << "Invalid Qnn context.";
+    return false;
+  }
+  if (graph_name.length() == 0) {
+//    LOGS(logger_, ERROR) << "Empty graph name.";
+    return false;
+  }
+
+  graph_name_ = graph_name;
+  auto rt = qnn_interface_.graphCreate(context, graph_name_.c_str(), graph_configs, &graph_);
+  if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) {
+    rt = qnn_interface_.graphRetrieve(context, graph_name_.c_str(), &graph_);
+    if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) {
+//      LOGS(logger_, ERROR) << "Failed to create Qnn graph: " << graph_name;
+      return false;
+    }
+  }
+//  LOGS(logger_, VERBOSE) << "Created Qnn graph: " << graph_name;
+
+  return true;
+}
+
+bool QnnModelWrapper::IsQnnTensorWrapperExist(const std::string& name) const {
+  return model_tensors_map_.find(name) != model_tensors_map_.end();
+}
+
+bool QnnModelWrapper::IsQnnParamExit(const std::string& param_tensor_name) const {
+  return model_params_map_.find(param_tensor_name) != model_params_map_.end();
+}
+
+//Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensorWrapper& tensor_wrapper) const {
+//  const std::string& tensor_name = tensor.node_arg.Name();
+//
+//  TensorInfo tensor_info = {};
+//  ORT_RETURN_IF_ERROR(GetTensorInfo(tensor, tensor_info));
+//
+//  std::vector<uint8_t> unpacked_tensor;
+//  if (tensor_info.is_initializer) {
+//    ORT_RETURN_IF_ERROR(UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor));
+//  }
+//
+//  tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type,
+//                                    std::move(tensor_info.quant_param), std::move(tensor_info.shape),
+//                                    std::move(unpacked_tensor));
+//  return Status::OK();
+//}
+
+bool QnnModelWrapper::AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper) {
+  // Keep a copy of the tensor name since it will be moved with the wrapper into model_tensors_map_
+  std::string tensor_name = tensor_wrapper.GetName();
+  if (tensor_name.length() == 0) {
+//    LOGS(logger_, ERROR) << "Invalid tensor encountered: empty name.";
+    return false;
+  }
+
+  if (IsQnnTensorWrapperExist(tensor_name) == true) {
+//    LOGS(logger_, VERBOSE) << "Tensor already exists: " << tensor_name;
+    return true;
+  }
+
+  const Qnn_TensorType_t& qnn_tensor_type = tensor_wrapper.GetTensorType();
+  // save created tensors for later lookup to populate graph node construction
+  model_tensors_map_.emplace(tensor_name, std::move(tensor_wrapper));
+
+  // save network input/output tensors to use for setting the Qnn graph's
+  // input and output
tensors for populating GraphInfo for caller + if (qnn_tensor_type == QNN_TENSOR_TYPE_APP_WRITE) { + model_input_names_.push_back(tensor_name); + } else if (qnn_tensor_type == QNN_TENSOR_TYPE_APP_READ) { + model_output_names_.push_back(tensor_name); + } + + return true; +} + +bool QnnModelWrapper::AddParamWrapper(QnnParamWrapper&& param_wrapper) { + // Keep a copy of tensor name sine it will be moved with the wrapper into model_params_map_ + std::string param_tensor_name = param_wrapper.GetParamTensorName(); + if (param_tensor_name.length() == 0) { +// LOGS(logger_, ERROR) << "Invalid parameter encountered empty name."; + return false; + } + + if (IsQnnParamExit(param_tensor_name) == true) { + return true; + } + + // save created tensors for later lookup to populate graph node construction + model_params_map_.emplace(param_tensor_name, std::move(param_wrapper)); + + return true; +} + +const QnnTensorWrapper& QnnModelWrapper::GetQnnTensorWrapper(const std::string& tensor_name) { + auto map_iter = model_tensors_map_.find(tensor_name); + if (map_iter != model_tensors_map_.end()) { + return (map_iter->second); + } + + ORT_THROW("Qnn tensor not exist: ", tensor_name); +} + +bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_name, + const std::vector& tensor_names, + std::vector& qnn_tensors, + bool do_op_validation) { + for (const auto& tensor_name : tensor_names) { + auto it = model_tensors_map_.find(tensor_name); + if (it == model_tensors_map_.end()) { +// LOGS(logger_, ERROR) << "Input name not exist: " << tensor_name; + return false; + } + + // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor + // We only need to creat the Qnn graph tensor during Compile to create Qnn graph + if (!do_op_validation) { + std::string error_string; + auto rt = it->second.CreateQnnGraphTensor(qnn_interface_, graph_, qnn_node_name, tensor_created_map_, error_string); + if (!rt) { +// LOGS(logger_, ERROR) << error_string; + return false; + } +// LOGS(logger_, VERBOSE) << "Tensor: " << tensor_name << " created. " << error_string; + } + + qnn_tensors.push_back(it->second.GetQnnTensor()); + } + return true; +} + +bool QnnModelWrapper::CreateQnnParamTensors(const std::string& qnn_node_name, + const std::vector& param_tensor_names, + std::vector& qnn_params, + bool do_op_validation) { + for (const auto& param_tensor_name : param_tensor_names) { + auto it = model_params_map_.find(param_tensor_name); + if (it == model_params_map_.end()) { +// LOGS(logger_, ERROR) << "Parameter name not exist: " << param_tensor_name; + return false; + } + +// LOGS(logger_, VERBOSE) << "Add parameter tensor: " << it->second.GetName(); + if (!do_op_validation) { + std::string error_string; + auto rt = it->second.CreateQnnGraphParam(qnn_interface_, graph_, qnn_node_name, tensor_created_map_, error_string); + if (!rt) { +// LOGS(logger_, ERROR) << error_string; + return false; + } +// LOGS(logger_, VERBOSE) << "Tensor: " << param_tensor_name << " created. 
" << error_string; + } + + qnn_params.push_back(it->second.GetQnnParam()); + } + + return true; +} + +Status QnnModelWrapper::ValidateQnnNode(const std::string& node_name, + const std::string& package_name, + const std::string& qnn_op_type, + std::vector&& input_tensors, + std::vector&& output_tensors, + std::vector&& params) const { + QnnOpConfigWrapper op_config_wrapper(node_name, + package_name, + qnn_op_type, + std::move(input_tensors), + std::move(output_tensors), + std::move(params)); + + std::string error_msg; + ORT_RETURN_IF_NOT(op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg), error_msg); + + return Status::OK(); +} + +bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, + const std::string& package_name, + const std::string& qnn_node_type, + std::vector&& input_names, + std::vector&& output_names, + std::vector&& param_tensor_names, + bool do_op_validation) { + if (do_op_validation) { + std::vector input_tensors; + std::vector output_tensors; + std::vector params; + if (!CreateQnnInputOutputTensors(qnn_node_name, input_names, input_tensors, do_op_validation)) { + return false; + } + + if (!CreateQnnInputOutputTensors(qnn_node_name, output_names, output_tensors, do_op_validation)) { + return false; + } + + if (!CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) { + return false; + } + + QnnOpConfigWrapper op_config_wrapper(qnn_node_name, + package_name, + qnn_node_type, + std::move(input_tensors), + std::move(output_tensors), + std::move(params)); + + using namespace onnxruntime::qnn::utils; +// LOGS(logger_, VERBOSE) << op_config_wrapper; + + std::string error_msg; + bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); + if (!rt) { +// LOGS(logger_, WARNING) << error_msg; + } + return rt; + } else { + QnnOpProperty qnn_op(qnn_node_name, package_name, qnn_node_type, + std::move(input_names), std::move(output_names), std::move(param_tensor_names)); + qnn_op_property_list_.push_back(std::move(qnn_op)); + return true; + } +} + +bool QnnModelWrapper::ComposeQnnGraph() { +// LOGS(logger_, VERBOSE) << "Compose Qnn Graph."; + // ORT_RETURN_IF(qnn_op_property_list_.empty(), "Empty Qnn op list, no graph to compose."); + if (qnn_op_property_list_.empty()) { + return false; + } + + for (const auto& op_property : qnn_op_property_list_) { + std::vector input_tensors; + std::vector output_tensors; + std::vector params; + if (!CreateQnnInputOutputTensors(op_property.GetNodeName(), op_property.GetInputNames(), input_tensors)) { + return false; + } + + if (!CreateQnnInputOutputTensors(op_property.GetNodeName(), op_property.GetOutputNames(), output_tensors)) { + return false; + } + + if (!CreateQnnParamTensors(op_property.GetNodeName(), op_property.GetParamTensorNames(), params)) { + return false; + } + + QnnOpConfigWrapper op_config_wrapper(op_property.GetNodeName(), + op_property.GetPackageName(), + op_property.GetNodeType(), + std::move(input_tensors), + std::move(output_tensors), + std::move(params)); + + using namespace onnxruntime::qnn::utils; +// LOGS(logger_, VERBOSE) << op_config_wrapper; + + std::string error_msg; + bool rt = op_config_wrapper.CreateQnnGraphOp(qnn_interface_, graph_, error_msg); + if (!rt) { +// LOGS(logger_, ERROR) << error_msg; + return false; + } + } + + return true; +} + +//bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector& shape) { +// const auto* shape_proto = node_arg.Shape(); +// if (shape_proto == nullptr) { +// return 
false; +// } +// +// // For Scalar data, we need to set shape to 1 for QNN +// if (shape_proto->dim_size() < 1) { +// shape.push_back(1); +// return true; +// } +// +// // We already checked the shape has no dynamic dimension +// for (const auto& dim : shape_proto->dim()) { +//// shape.push_back(SafeInt(dim.dim_value())); +// } +// +// return true; +//} + +Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name, + /*out*/ std::vector& zero_points, + /*out*/ int32_t& onnx_data_type) const { +// const auto& graph_initializers = GetInitializerTensors(); +// auto iter = graph_initializers.find(initializer_name); +// ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ", +// initializer_name.c_str()); +// gsl::not_null zp_tensor_proto = iter->second; +// +// ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(), +// " to have a proto data type."); +// +// onnx_data_type = zp_tensor_proto->data_type(); +// std::vector initializer_bytes; +// +// ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes)); +// +// switch (onnx_data_type) { +// // QNN use -offset for some reason +// case ONNX_NAMESPACE::TensorProto_DataType_INT4: // INT4 zero-points are unpacked as 8-bit values for QNN +// case ONNX_NAMESPACE::TensorProto_DataType_INT8: { +// auto int8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points), +// [](int8_t zp) -> int32_t { +// return -static_cast(zp); +// }); +// break; +// } +// case ONNX_NAMESPACE::TensorProto_DataType_UINT4: // UINT4 zero-points are unpacked as 8-bit values for QNN +// case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { +// auto uint8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points), +// [](uint8_t zp) -> int32_t { +// return -static_cast(zp); +// }); +// break; +// } +// case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { +// auto uint16_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(uint16_span.begin(), uint16_span.end(), std::back_inserter(zero_points), +// [](uint16_t zp) -> int32_t { +// return -static_cast(zp); +// }); +// break; +// } +// case ONNX_NAMESPACE::TensorProto_DataType_INT16: { +// auto int16_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(int16_span.begin(), int16_span.end(), std::back_inserter(zero_points), +// [](int16_t zp) -> int32_t { +// return -static_cast(zp); +// }); +// break; +// } +// case ONNX_NAMESPACE::TensorProto_DataType_INT32: { +// auto int32_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(int32_span.begin(), int32_span.end(), std::back_inserter(zero_points), +// [](int32_t zp) -> int32_t { +// return -zp; +// }); +// break; +// } +// case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { +// auto uint32_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); +// std::transform(uint32_span.begin(), uint32_span.end(), std::back_inserter(zero_points), +// [](uint32_t zp) -> int32_t { +// return -static_cast(zp); +// }); +// break; +// } +// default: { +// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Zero-point ONNX data type `", onnx_data_type, +// "` is not supported."); +// } +// } + + return Status::OK(); +} + +Status QnnModelWrapper::UnpackScales(const std::string& initializer_name, std::vector& 
scales) const { +// const auto& graph_initializers = GetInitializerTensors(); +// auto iter = graph_initializers.find(initializer_name); +// ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for scale(s): ", +// initializer_name.c_str()); +// gsl::not_null scale_tensor_proto = iter->second; +// +// ORT_RETURN_IF_NOT(scale_tensor_proto->has_data_type(), "Expected scale initializer ", initializer_name.c_str(), +// " to have a proto data type."); +// ORT_RETURN_IF_NOT(scale_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT, +// "Expected scale initializer to be of type FLOAT"); +// +// std::vector initializer_bytes; +// +// ORT_RETURN_IF_ERROR(UnpackInitializerData(*scale_tensor_proto, initializer_bytes)); +// +// gsl::span src = gsl::make_span(reinterpret_cast(initializer_bytes.data()), +// initializer_bytes.size() / sizeof(float)); +// +// scales.insert(scales.end(), src.begin(), src.end()); + return Status::OK(); +} + +// Checks if a tensor in the ONNX graph is per-channel quantized. +//Status QnnModelWrapper::IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, +// /*out*/ bool& is_per_channel, +// /*out*/ int64_t& axis) const { +// if (!io_def.quant_param) { +// is_per_channel = false; +// return Status::OK(); +// } +// +// const std::string& scale_name = io_def.quant_param->scale.Name(); +// const auto& graph_initializers = GetInitializerTensors(); +// auto iter = graph_initializers.find(scale_name); +// ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for scale(s): ", +// scale_name.c_str()); +// gsl::not_null scale_tensor_proto = iter->second; +// TensorShape scale_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*scale_tensor_proto); +// +// // Check the number of scale values to determine if the tensor is per-channel. +// // This is consistent with CPU EP's Quant/Dequant logic. We can't use the presence of an axis because even a +// // per-channel DQ/Q op may not have an explicit axis attribute (assumed to default to 1 if missing). +// const bool is_scalar_or_1_elem_vector = scale_shape.NumDimensions() == 0 || +// (scale_shape.NumDimensions() == 1 && scale_shape.Size() == 1); +// +// is_per_channel = !is_scalar_or_1_elem_vector; +// +// if (is_per_channel) { +// axis = io_def.quant_param->axis.value_or(1); // 1 is default axis for Q/DQ ops. +// } +// +// return Status::OK(); +//} + +//Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& tensor_info) const { +// const std::string& name = input.node_arg.Name(); +// +// // Fill in quantization param info. +// ORT_RETURN_IF_ERROR(tensor_info.quant_param.Init(*this, input)); +// +// // Fill in QNN data type. +// tensor_info.qnn_data_type = QNN_DATATYPE_FLOAT_32; +// ORT_RETURN_IF_ERROR(utils::GetQnnDataType(input.quant_param.has_value(), input.node_arg.TypeAsProto(), +// tensor_info.qnn_data_type)); +// +// // Fill in shape. +// ORT_RETURN_IF_NOT(GetOnnxShape(input.node_arg, tensor_info.shape), "Cannot get shape"); +// +// // Fill in initializer info. 
+// tensor_info.is_initializer = IsInitializerInput(name); +// if (tensor_info.is_initializer) { +// tensor_info.initializer_tensor = GetInitializerTensors().at(name); +// } +// +// return Status::OK(); +//} + +Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) { + // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors. + // We could technically support this by shifting the quantization param's axis value, but + // we don't need this right now. + ORT_RETURN_IF(quantize_param.IsPerChannel(), + "Do not support inserted Reshape nodes with per-channel quantization"); + QnnTensorWrapper input_tensorwrapper(input_name, + is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, + quantize_param.Copy(), + std::vector(input_shape)); + ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)), + "QNN EP: Failed to add input tensor for inserted Reshape."); + + Qnn_TensorType_t tensor_type = is_for_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensorwrapper(output_name, + tensor_type, + tensor_data_type, + quantize_param.Copy(), + std::vector(output_shape)); + ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), + "QNN EP: Failed to add output tensor for inserted Reshape."); + + ORT_RETURN_IF_NOT(CreateQnnNode(output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_RESHAPE, + {input_name}, + {output_name}, + {}, + do_op_validation), + "QNN EP: Failed to create manually inserted Qnn Reshape node."); + + return Status::OK(); +} + +Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, + const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& transpose_perm, + const std::vector& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) { + // Do not allow QNN EP to insert transpose nodes with per-channel quantization on dynamic tensors. + // We could technically support this by transposing the quantization param's axis value, but + // we don't need this right now. 
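// Illustrative sketch only (not part of this change): if per-channel support were
// added here, the quantization axis would need to be remapped through the inverse
// permutation. Assuming a forward permutation where output dim j takes input dim
// perm[j], the old channel axis `a` moves to the output position j whose perm[j]
// equals `a`:
//
//   static int32_t RemapPerChannelAxis(gsl::span<const uint32_t> perm, int32_t a) {
//     for (size_t j = 0; j < perm.size(); ++j) {
//       if (static_cast<int32_t>(perm[j]) == a) {
//         return static_cast<int32_t>(j);  // e.g. perm = {0, 2, 3, 1}, a = 1 -> 3
//       }
//     }
//     return a;  // unreachable for a valid permutation
//   }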
+ ORT_RETURN_IF(quantize_param.IsPerChannel(), + "Do not support inserted Transpose nodes with per-channel quantization"); + // No need to add this for output nodes as it is added as output tensor for previous node + if (is_for_input) { + Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_APP_WRITE; + QnnTensorWrapper input_tensorwrapper(input_name, + tensor_type, + tensor_data_type, + quantize_param.Copy(), + std::vector<uint32_t>(input_shape)); + ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + + uint32_t perm_size = static_cast<uint32_t>(transpose_perm.size()); + std::vector<uint32_t> perm_dim{perm_size}; + std::vector<uint32_t> transpose_perm_copy = transpose_perm; + const std::string& node_name = output_name; + QnnParamWrapper transpose_param(node_index, node_name, QNN_OP_TRANSPOSE_PARAM_PERM, std::move(perm_dim), std::move(transpose_perm_copy)); + std::string param_tensor_name(transpose_param.GetParamTensorName()); + ORT_RETURN_IF_NOT(AddParamWrapper(std::move(transpose_param)), "Failed to add parameter."); + Qnn_TensorType_t tensor_type = (false == is_for_output) ? QNN_TENSOR_TYPE_NATIVE : QNN_TENSOR_TYPE_APP_READ; + std::vector<uint32_t> output_shape_copy = output_shape; + QnnTensorWrapper output_tensorwrapper(output_name, + tensor_type, + tensor_data_type, + quantize_param.Copy(), + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); + const static std::string qnn_node_type = "Transpose"; + + CreateQnnNode(output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + qnn_node_type, + {input_name}, + {output_name}, + {param_tensor_name}, + do_op_validation); + + return Status::OK(); +} + +void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector<std::string>& tensor_name_list, + std::vector<QnnTensorWrapper>& wrappers_list) { + for (const auto& tensor_name : tensor_name_list) { + auto it = model_tensors_map_.find(tensor_name); + if (it == model_tensors_map_.end()) { +// LOGS(logger_, ERROR) << "Model input or output name does not exist: " << tensor_name +// << ". Could cause execution error."; + break; + } + // It's safe to move QnnTensorWrapper out of model_tensors_map_ + // since this call happens at QnnModelWrapper end of life + wrappers_list.push_back(std::move(it->second)); + model_tensors_map_.erase(tensor_name); + } + + return; +} + +Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, + std::vector<uint8_t>& unpacked_tensor) const { +// if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) { +// ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), +// unpacked_tensor)); +// } else { +// ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor)); +// } + + int32_t onnx_data_type = initializer.data_type(); + + // If this is an int4, we need to unpack it because QNN treats int4 as a full int8. 
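// Worked size math for the int4 case (illustrative): a tensor with 5 INT4 elements
// arrives packed as ceil(5 / 2) = 3 bytes (two 4-bit nibbles per byte, see Int4x2),
// and the commented-out code below would widen it to 5 one-byte elements so QNN can
// read one int8 per value.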
+// if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { +// TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); +// const size_t num_elems = shape.Size(); +// std::vector packed_int4_bytes = std::move(unpacked_tensor); +// unpacked_tensor = std::vector(num_elems); +// +// auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); +// auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); +// ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); +// } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { +// TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); +// const size_t num_elems = shape.Size(); +// std::vector packed_int4_bytes = std::move(unpacked_tensor); +// unpacked_tensor = std::vector(num_elems); +// +// auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); +// auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); +// ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); +// } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/samples/qnnEp/builder/qnn_model_wrapper.h b/samples/qnnEp/builder/qnn_model_wrapper.h new file mode 100644 index 0000000000000..24f9cffa0da7c --- /dev/null +++ b/samples/qnnEp/builder/qnn_model_wrapper.h @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" +#include "QnnInterface.h" +#include "qnn_def.h" +//#include "core/common/logging/logging.h" +//#include "core/framework/node_unit.h" +//#include "core/graph/graph_viewer.h" +//#include "core/providers/shared/utils/utils.h" +#include "qnn_quant_params_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Stores information about an ONNX input or output tensor. +// Filled out by QnnModelWrapper::GetTensorInfo() +struct TensorInfo { + std::vector shape; + Qnn_DataType_t qnn_data_type; + QnnQuantParamsWrapper quant_param; + bool is_initializer; + const ONNX_NAMESPACE::TensorProto* initializer_tensor; +}; + +class QnnModelWrapper { + public: + QnnModelWrapper(//const GraphViewer& graph_viewer, +// const logging::Logger& logger, + const QNN_INTERFACE_VER_TYPE& qnn_interface, + const Qnn_BackendHandle_t& backend_handle, + const std::unordered_map& input_index_map, + const std::unordered_map& output_index_map, + const std::unordered_set& initializer_lookup, + QnnBackendType qnn_backend_type) + : //graph_viewer_(graph_viewer), +// logger_(logger), + qnn_interface_(qnn_interface), + backend_handle_(backend_handle), + input_index_map_(input_index_map), + output_index_map_(output_index_map), + initializer_lookup_(initializer_lookup), + qnn_backend_type_(qnn_backend_type) { + } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModelWrapper); + + ~QnnModelWrapper() = default; + + bool CreateQnnGraph(const Qnn_ContextHandle_t& context, + const std::string& graph_name, + const QnnGraph_Config_t** graph_configs = nullptr); + + // Make a QnnTensorWrapper from an onnx input or output. 
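// Typical build flow with this class (illustrative sketch only; the handles, maps,
// and tensor wrappers below are hypothetical placeholders, not part of this change):
//
//   QnnModelWrapper wrapper(qnn_interface, backend_handle,
//                           input_index_map, output_index_map,
//                           initializer_names, QnnBackendType::CPU);
//   wrapper.CreateQnnGraph(context, "fused_subgraph");
//   wrapper.AddTensorWrapper(std::move(input_wrapper));    // QNN_TENSOR_TYPE_APP_WRITE
//   wrapper.AddTensorWrapper(std::move(output_wrapper));   // QNN_TENSOR_TYPE_APP_READ
//   wrapper.CreateQnnNode("add_0", QNN_OP_PACKAGE_NAME_QTI_AISW, "ElementWiseAdd",
//                         {"input_0", "input_1"}, {"output_0"}, {});
//   wrapper.ComposeQnnGraph();  // creates the queued tensors and ops in the QNN graph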
+// Status MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensorWrapper& tensor_wrapper) const; + + // Add to internal tensor wrapper table + bool AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper); + + // Add to internal param wrapper table + bool AddParamWrapper(QnnParamWrapper&& param_wrapper); + + const QnnTensorWrapper& GetQnnTensorWrapper(const std::string& tensor_name); + + // Utility function to validate a QNN node. Does not modify this object's state. + Status ValidateQnnNode(const std::string& node_name, + const std::string& package_name, + const std::string& qnn_op_type, + std::vector<Qnn_Tensor_t>&& input_tensors, + std::vector<Qnn_Tensor_t>&& output_tensors, + std::vector<Qnn_Param_t>&& params) const; + + bool CreateQnnNode(const std::string& name, + const std::string& package_name, + const std::string& type, + std::vector<std::string>&& input_names, + std::vector<std::string>&& output_names, + std::vector<std::string>&& param_tensor_names, + bool do_op_validation = false); + + bool ComposeQnnGraph(); + + Qnn_GraphHandle_t GetQnnGraph() { return graph_; } + + std::string GetQnnGraphName() const { return graph_name_; } + + // Move input tensor wrappers to GraphInfo at QnnModelWrapper end of life + std::vector<QnnTensorWrapper>&& GetGraphInputTensorWrappers() { + GetGraphInputOutputTensorWrapper(model_input_names_, model_input_tensor_wrappers_); + return std::move(model_input_tensor_wrappers_); + } + + // Move output tensor wrappers to GraphInfo at QnnModelWrapper end of life + std::vector<QnnTensorWrapper>&& GetGraphOutputTensorWrappers() { + GetGraphInputOutputTensorWrapper(model_output_names_, model_output_tensor_wrappers_); + return std::move(model_output_tensor_wrappers_); + } + +// const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + + bool IsInitializerInput(std::string input_name) const { + return initializer_lookup_.find(input_name) != initializer_lookup_.end(); + } + +// static bool GetOnnxShape(const NodeArg& node_arg, std::vector<uint32_t>& shape); + + bool IsQnnTensorWrapperExist(const std::string& tensor_name) const; + + bool IsGraphOutput(const std::string& tensor_name) const { + return output_index_map_.find(tensor_name) != output_index_map_.end(); + } + + bool IsGraphInput(const std::string& tensor_name) const { + return input_index_map_.find(tensor_name) != input_index_map_.end(); + } + + Qnn_TensorType_t GetTensorType(const std::string& tensor_name) const { + if (IsInitializerInput(tensor_name)) { + return QNN_TENSOR_TYPE_STATIC; + } else if (IsGraphInput(tensor_name)) { + return QNN_TENSOR_TYPE_APP_WRITE; + } else if (IsGraphOutput(tensor_name)) { + return QNN_TENSOR_TYPE_APP_READ; + } else { + return QNN_TENSOR_TYPE_NATIVE; + } + } + + //Status GetTensorInfo(const NodeUnitIODef& input, TensorInfo& input_info) const; + + Status AddReshapeNode(const std::string& input_name, + const std::string& output_name, + const std::vector<uint32_t>& input_shape, + const std::vector<uint32_t>& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input = true, + bool is_for_output = false); + + Status AddTransposeNode(NodeIndex node_index, + const std::string& input_name, + const std::string& output_name, + const std::vector<uint32_t>& input_shape, + const std::vector<uint32_t>& transpose_perm, + const std::vector<uint32_t>& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input = true, + bool is_for_output = false); + + // Transpose NCHW->HWCN for QNN weight, + // e.g. a Conv weight of shape [N, C, H, W] = [64, 3, 7, 7] becomes [H, W, C, N] = [7, 7, 3, 64]. + Status AddNchwToHwcnTranspose(NodeIndex
node_index, + const std::string& input_name, + const std::string& output_name, + const std::vector<uint32_t>& input_shape, + const std::vector<uint32_t>& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input = true, + bool is_for_output = false, + bool is_3d = false) { +// LOGS(logger_, VERBOSE) << "Add NCHW->HWCN Transpose node after Conv weight input: " << input_name +// << " -> " << output_name; + auto perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; + std::vector<uint32_t> transpose_perm; + transpose_perm.resize(perm.size()); + std::transform(perm.begin(), perm.end(), + transpose_perm.begin(), [](size_t item) -> uint32_t { + return gsl::narrow<uint32_t>(item); + }); + return AddTransposeNode(node_index, input_name, output_name, input_shape, transpose_perm, output_shape, + tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output); + } + + // Transpose CNHW->HWCN for QNN weight, + // e.g. a ConvTranspose weight of shape [C, N, H, W] = [3, 64, 7, 7] becomes [H, W, C, N] = [7, 7, 3, 64]. + Status AddCnhwToHwcnTranspose(NodeIndex node_index, + const std::string& input_name, + const std::string& output_name, + const std::vector<uint32_t>& input_shape, + const std::vector<uint32_t>& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input = true, + bool is_for_output = false, + bool is_3d = false) { +// LOGS(logger_, VERBOSE) << "Add CNHW->HWCN Transpose node after ConvTranspose weight input: " << input_name +// << " -> " << output_name; + auto perm = is_3d ? cnhw2hwcn_perm_3d : cnhw2hwcn_perm; + std::vector<uint32_t> transpose_perm; + transpose_perm.resize(perm.size()); + std::transform(perm.begin(), perm.end(), + transpose_perm.begin(), [](size_t item) -> uint32_t { + return gsl::narrow<uint32_t>(item); + }); + return AddTransposeNode(node_index, input_name, output_name, input_shape, transpose_perm, output_shape, + tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output); + } + + Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, + std::vector<uint8_t>& unpacked_tensor) const; + + QnnBackendType GetQnnBackendType() const { return qnn_backend_type_; } + +// const GraphViewer& GetGraphViewer() const { return graph_viewer_; } + + // Unpack float scales from initializer (1 scale for per-tensor, > 1 for per-axis). + Status UnpackScales(const std::string& initializer_name, std::vector<float>& scales) const; + + // Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel). + Status UnpackZeroPoints(const std::string& initializer_name, + /*out*/ std::vector<int32_t>& zero_points, + /*out*/ int32_t& onnx_data_type) const; + + // Checks if a tensor in the ONNX graph is per-channel quantized.
+// Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, +// /*out*/ bool& is_per_channel, +// /*out*/ int64_t& axis) const; + + private: + bool CreateQnnInputOutputTensors(const std::string& qnn_node_name, + const std::vector<std::string>& names, + std::vector<Qnn_Tensor_t>& tensor_wrappers, + bool do_op_validation = false); + + bool IsQnnParamExit(const std::string& param_tensor_name) const; + + bool CreateQnnParamTensors(const std::string& qnn_node_name, + const std::vector<std::string>& param_tensor_names, + std::vector<Qnn_Param_t>& qnn_params, + bool do_op_validation = false); + +// bool IsQDQNode(const Node& node) const { +// if (node.OpType() == "QuantizeLinear" || node.OpType() == "DequantizeLinear") { +// return true; +// } +// return false; +// } + + bool IsQnnTensorCreated(const std::string& tensor_name) { + auto pos = tensor_created_map_.find(tensor_name); + if (pos == tensor_created_map_.end()) { + return false; + } + return pos->second; + } + + void GetGraphInputOutputTensorWrapper(const std::vector<std::string>& names, + std::vector<QnnTensorWrapper>& wrappers_list); + +// const GraphViewer& graph_viewer_; +// const logging::Logger& logger_; + const QNN_INTERFACE_VER_TYPE& qnn_interface_; + const Qnn_BackendHandle_t& backend_handle_; + Qnn_GraphHandle_t graph_ = nullptr; + std::string graph_name_ = ""; + + std::vector<std::string> model_input_names_; + std::vector<std::string> model_output_names_; + std::vector<QnnTensorWrapper> model_input_tensor_wrappers_; + std::vector<QnnTensorWrapper> model_output_tensor_wrappers_; + // All QnnTensorWrapper for the graph + std::unordered_map<std::string, QnnTensorWrapper> model_tensors_map_; + // All QnnParamWrapper for the graph + std::unordered_map<std::string, QnnParamWrapper> model_params_map_; + std::vector<QnnOpProperty> qnn_op_property_list_; + // <tensor_name> -> true means the qnn tensor was created in the qnn graph; + // it includes normal qnn_tensors and qnn_tensors inside param_tensors + std::unordered_map<std::string, bool> tensor_created_map_; + const std::unordered_map<std::string, size_t>& input_index_map_; + const std::unordered_map<std::string, size_t>& output_index_map_; + const std::unordered_set<std::string>& initializer_lookup_; + QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; +}; // QnnModelWrapper + +} // namespace qnn +} // namespace onnxruntime diff --git a/samples/qnnEp/builder/qnn_quant_params_wrapper.cc b/samples/qnnEp/builder/qnn_quant_params_wrapper.cc new file mode 100644 index 0000000000000..65d748f3b6a0e --- /dev/null +++ b/samples/qnnEp/builder/qnn_quant_params_wrapper.cc @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "qnn_quant_params_wrapper.h" +#include +#include +#include +#include +#include "qnn_model_wrapper.h" + +// Rounds ptr up to the next multiple of align (align must be a power of two), e.g. ptr = 0x1003, align = 4 -> 0x1004. +#define ALIGN_PTR_UP(ptr, align, type) \ + reinterpret_cast<type>((reinterpret_cast<std::uintptr_t>(ptr) + (align)-1) & ~((align)-1)) + +namespace onnxruntime { +namespace qnn { + +QnnQuantParamsWrapper::QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other) + : params_(QNN_QUANTIZE_PARAMS_INIT) { + Status status = Init(other.params_); + assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding. +} + +QnnQuantParamsWrapper& QnnQuantParamsWrapper::operator=(const QnnQuantParamsWrapper& other) { + if (this != &other) { + Status status = Init(other.params_); + assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding.
+ } + + return *this; +} + +QnnQuantParamsWrapper::QnnQuantParamsWrapper(float scale, int32_t offset) { + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + params_.scaleOffsetEncoding.scale = scale; + params_.scaleOffsetEncoding.offset = offset; +} + +QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { + return QnnQuantParamsWrapper(*this); +} + +// Initializes by copying from a Qnn_QuantizeParams_t. +Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { + if (per_channel_data_) { + per_channel_data_.reset(nullptr); + params_ = QNN_QUANTIZE_PARAMS_INIT; + } + + if (params.encodingDefinition != QNN_DEFINITION_DEFINED) { + params_ = params; + return Status::OK(); + } + + switch (params.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: + params_ = params; + break; + case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + params_.axisScaleOffsetEncoding.axis = params.axisScaleOffsetEncoding.axis; + params_.axisScaleOffsetEncoding.numScaleOffsets = params.axisScaleOffsetEncoding.numScaleOffsets; + + // Deep copy the scaleOffset data. + const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets; + + if (num_elems > 0) { + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + + std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes); + params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst; + } else { + params_.axisScaleOffsetEncoding.scaleOffset = nullptr; + } + break; + } + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { + const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements; + + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis; + params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth; + params_.bwAxisScaleOffsetEncoding.numElements = num_elems; + + // Deep copy the scales[] and offsets[] arrays + if (num_elems > 0) { + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + static_assert(alignof(float) == alignof(int32_t)); + + per_channel_data_ = std::make_unique(num_bytes + align); + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + + std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes); + std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes); + params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast(scales_begin); + params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast(zps_begin); + } else { + params_.bwAxisScaleOffsetEncoding.scales = nullptr; + params_.bwAxisScaleOffsetEncoding.offsets = nullptr; + } + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", 
params.quantizationEncoding); + } + + return Status::OK(); +} +// +//// Initialize this object from a (potentially) quantized ONNX tensor. +//// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. +//Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) { +// const std::optional& ort_quant_params = io_def.quant_param; +// +// if (per_channel_data_) { +// per_channel_data_.reset(nullptr); +// params_ = QNN_QUANTIZE_PARAMS_INIT; +// } +// +// if (!ort_quant_params.has_value()) { +// params_.encodingDefinition = QNN_DEFINITION_UNDEFINED; +// params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; +// return Status::OK(); +// } +// +// std::vector scales; +// std::vector zero_points; +// +// ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales)); +// +// bool is_int4_type = false; +// +// if (ort_quant_params->zero_point != nullptr) { +// int32_t onnx_tp_type = 0; +// ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points, +// onnx_tp_type)); +// +// is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || +// (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); +// } +// +// const bool is_per_tensor = scales.size() == 1; +// +// // QNN uses different structs to represent quantization parameters depending on +// // - per-tensor vs per-channel +// // - int4 vs not int4 +// if (is_per_tensor && !is_int4_type) { +// params_.encodingDefinition = QNN_DEFINITION_DEFINED; +// params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; +// params_.scaleOffsetEncoding.scale = scales[0]; +// +// if (ort_quant_params->zero_point != nullptr) { +// ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value"); +// params_.scaleOffsetEncoding.offset = zero_points[0]; +// } else { +// params_.scaleOffsetEncoding.offset = 0; +// } +// } else if (is_per_tensor && is_int4_type) { +// params_.encodingDefinition = QNN_DEFINITION_DEFINED; +// params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET; +// params_.bwScaleOffsetEncoding.bitwidth = 4; +// params_.bwScaleOffsetEncoding.scale = scales[0]; +// +// if (ort_quant_params->zero_point != nullptr) { +// ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value"); +// params_.bwScaleOffsetEncoding.offset = zero_points[0]; +// } else { +// params_.bwScaleOffsetEncoding.offset = 0; +// } +// } else if (!is_per_tensor && is_int4_type) { +// const auto* io_shape = io_def.node_arg.Shape(); +// ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); +// const int32_t io_rank = io_shape->dim_size(); +// +// constexpr int64_t DEFAULT_QDQ_AXIS = 1; +// int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS); +// if (axis < 0) { +// axis += io_rank; +// } +// ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank, +// "Quantization axis must be within the range [0, rank - 1]"); +// +// const size_t num_elems = scales.size(); +// const bool no_zero_points = zero_points.empty(); +// ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value"); +// ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, +// "Expected the same number of zero-points and scales for per-channel quantization"); +// +// params_.encodingDefinition = QNN_DEFINITION_DEFINED; +// params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; +// 
params_.bwAxisScaleOffsetEncoding.axis = static_cast(*(ort_quant_params->axis)); +// params_.bwAxisScaleOffsetEncoding.bitwidth = 4; +// params_.bwAxisScaleOffsetEncoding.numElements = static_cast(num_elems); +// +// const size_t num_scale_bytes = num_elems * sizeof(float); +// const size_t num_zp_bytes = num_elems * sizeof(int32_t); +// const size_t num_bytes = num_scale_bytes + num_zp_bytes; +// constexpr std::uintptr_t align = alignof(float); +// per_channel_data_ = std::make_unique(num_bytes + align); +// +// char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); +// char* zps_begin = scales_begin + num_scale_bytes; +// gsl::span scales_span(reinterpret_cast(scales_begin), num_elems); +// gsl::span zps_span(reinterpret_cast(zps_begin), num_elems); +// +// for (size_t i = 0; i < num_elems; i++) { +// scales_span[i] = scales[i]; +// zps_span[i] = no_zero_points ? 0 : zero_points[i]; +// } +// +// params_.bwAxisScaleOffsetEncoding.scales = scales_span.data(); +// params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data(); +// } else if (!is_per_tensor && !is_int4_type) { +// const auto* io_shape = io_def.node_arg.Shape(); +// ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); +// const int32_t io_rank = io_shape->dim_size(); +// +// constexpr int64_t DEFAULT_QDQ_AXIS = 1; +// int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS); +// if (axis < 0) { +// axis += io_rank; +// } +// ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank, +// "Quantization axis must be within the range [0, rank - 1]"); +// +// params_.encodingDefinition = QNN_DEFINITION_DEFINED; +// params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; +// +// const size_t num_elems = scales.size(); +// const bool no_zero_points = zero_points.empty(); +// ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value"); +// ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, +// "Expected the same number of zero-points and scales for per-channel quantization"); +// +// const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); +// constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); +// per_channel_data_ = std::make_unique(num_bytes + align); +// Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); +// gsl::span data_span(aligned_dst, num_elems); +// +// for (size_t i = 0; i < num_elems; i++) { +// data_span[i].scale = scales[i]; +// data_span[i].offset = no_zero_points ? 0 : zero_points[i]; +// } +// +// params_.axisScaleOffsetEncoding.axis = static_cast(axis); +// params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast(num_elems); +// params_.axisScaleOffsetEncoding.scaleOffset = data_span.data(); +// } else { +// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()"); +// } +// +// return Status::OK(); +//} +} // namespace qnn +} // namespace onnxruntime diff --git a/samples/qnnEp/builder/qnn_quant_params_wrapper.h b/samples/qnnEp/builder/qnn_quant_params_wrapper.h new file mode 100644 index 0000000000000..d785fb98d1f83 --- /dev/null +++ b/samples/qnnEp/builder/qnn_quant_params_wrapper.h @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
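// Illustrative usage of the wrapper defined below (sketch only):
//
//   QnnQuantParamsWrapper qp(0.05f, /*offset*/ -12);  // per-tensor scale/offset
//   assert(qp.IsQuantized() && qp.IsPerTensor());
//   Qnn_QuantizeParams_t raw = qp.Get();      // plain QNN struct for tensor setup
//   QnnQuantParamsWrapper copy = qp.Copy();   // deep-copies any per-channel arrays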
+ +#pragma once +#include +#include "QnnTypes.h" +#include "core/common/common.h" +#include +//#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; // Forward-declare + +class QnnQuantParamsWrapper { + public: + QnnQuantParamsWrapper() : params_(QNN_QUANTIZE_PARAMS_INIT) {} + + QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other); + QnnQuantParamsWrapper& operator=(const QnnQuantParamsWrapper& other); + + QnnQuantParamsWrapper(QnnQuantParamsWrapper&& other) = default; + QnnQuantParamsWrapper& operator=(QnnQuantParamsWrapper&& other) = default; + + // Construct a per-tensor quantization param (SCALE_OFFSET) + QnnQuantParamsWrapper(float scale, int32_t offset); + + Qnn_QuantizeParams_t& Get() { return params_; } + const Qnn_QuantizeParams_t& Get() const { return params_; } + + // Initialize this object from a raw Qnn_QuantizeParam_t object. + Status Init(const Qnn_QuantizeParams_t& params); + + // Initialize this object from a (potentially) quantized ONNX tensor. + // QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. +// Status Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def); + + QnnQuantParamsWrapper Copy() const; + + bool IsQuantized() const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED; + } + + bool IsPerTensor(bool include_bw = false) const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED && + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || + (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET)); + } + + bool IsPerChannel() const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED && + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET || + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); + } + + // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis + // must be transposed using the inverse permutation of the Transpose. + template + Status HandleTranspose(gsl::span perm) { + if (!IsPerChannel()) { + return Status::OK(); + } + + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + ORT_RETURN_IF_NOT(static_cast(params_.axisScaleOffsetEncoding.axis) < perm.size(), + "Axis value is out of range of the provided permutation"); + const int32_t new_axis = static_cast(perm[params_.axisScaleOffsetEncoding.axis]); + params_.axisScaleOffsetEncoding.axis = new_axis; + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + ORT_RETURN_IF_NOT(static_cast(params_.bwAxisScaleOffsetEncoding.axis) < perm.size(), + "Axis value is out of range of the provided permutation"); + const int32_t new_axis = static_cast(perm[params_.bwAxisScaleOffsetEncoding.axis]); + params_.bwAxisScaleOffsetEncoding.axis = new_axis; + } + + return Status::OK(); + } + + // Handle "unsqueeze" of a per-channel quantized tensor. The quantization parameter's axis + // may need to be shifted if the unsqueeze inserted 1s before the quantization axis. + template + Status HandleUnsqueeze(gsl::span orig_shape, + gsl::span new_shape) { + if (!IsPerChannel()) { + return Status::OK(); + } + + ORT_RETURN_IF_NOT(orig_shape.size() < new_shape.size(), "Expected unsqueezed shape to have a greater rank."); + + // Get the axis value. 
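// Worked example: orig_shape = [3, 4], new_shape = [1, 3, 1, 4], axis = 1.
// The scan below skips the inserted 1-dims, matching 3 -> new_shape[1] and
// 4 -> new_shape[3], so the per-channel axis is remapped from 1 to 3.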
+ int32_t axis = 0; + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + axis = params_.axisScaleOffsetEncoding.axis; + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + axis = params_.bwAxisScaleOffsetEncoding.axis; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unhandled quantization encoding: ", params_.quantizationEncoding); + } + + // Find where the axis was moved to after unsqueeze. + size_t num_found = 0; + size_t j = 0; + for (size_t i = 0; i < orig_shape.size() && j < new_shape.size(); i++) { + // Check the bound on j before indexing into new_shape to avoid reading past its end. + while (j < new_shape.size() && orig_shape[i] != new_shape[j]) { + assert(new_shape[j] == 1); + j++; + } + assert(orig_shape[i] == new_shape[j]); + if (num_found == static_cast<size_t>(axis)) { + break; + } + num_found += 1; + j++; + } + + if (j == static_cast<size_t>(axis)) { + return Status::OK(); + } + + // Set new axis. + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + params_.axisScaleOffsetEncoding.axis = static_cast<int32_t>(j); + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + params_.bwAxisScaleOffsetEncoding.axis = static_cast<int32_t>(j); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unhandled quantization encoding: ", params_.quantizationEncoding); + } + + return Status::OK(); + } + + private: + Qnn_QuantizeParams_t params_; + + // Stores arrays of per-channel scales and offsets. Fields in params_ point to this data. + // + // Use an opaque array of bytes because QNN uses different data layouts depending on the quantization encoding: + // - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: array of scale/zp pairs [{scale0, zp0}, {scale1, zp1}, ...] + // - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays for scales and zps [scale0, ...] [zp0, zp1, ...] + std::unique_ptr<char[]> per_channel_data_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/samples/qnnEp/builder/qnn_utils.cc b/samples/qnnEp/builder/qnn_utils.cc new file mode 100644 index 0000000000000..22b260ec65904 --- /dev/null +++ b/samples/qnnEp/builder/qnn_utils.cc @@ -0,0 +1,557 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
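// Example use of the type utilities defined in this file (illustrative only):
//
//   Qnn_DataType_t qnn_type = QNN_DATATYPE_UNDEFINED;
//   if (OnnxDataTypeToQnnDataType(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
//                                 qnn_type, /*is_quantized*/ true)) {
//     // qnn_type is now QNN_DATATYPE_UFIXED_POINT_8, i.e. one byte per element:
//     size_t elem_size = GetElementSizeByType(qnn_type);
//   }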
+ +#include +#include +#include +#include +#include + +#include "core/common/common.h" +//#include "core/framework/data_types.h" +#include "qnn_utils.h" +#include "qnn_def.h" +#include "core/framework/int4.h" + +namespace onnxruntime { +namespace qnn { +namespace utils { + +size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { + const static std::unordered_map<Qnn_DataType_t, size_t> data_type_to_size = { + {QNN_DATATYPE_INT_8, 1}, + {QNN_DATATYPE_INT_16, 2}, + {QNN_DATATYPE_INT_32, 4}, + {QNN_DATATYPE_INT_64, 8}, + {QNN_DATATYPE_UINT_8, 1}, + {QNN_DATATYPE_UINT_16, 2}, + {QNN_DATATYPE_UINT_32, 4}, + {QNN_DATATYPE_UINT_64, 8}, + {QNN_DATATYPE_FLOAT_16, 2}, + {QNN_DATATYPE_FLOAT_32, 4}, + {QNN_DATATYPE_BOOL_8, 1}, + {QNN_DATATYPE_SFIXED_POINT_8, 1}, + {QNN_DATATYPE_SFIXED_POINT_16, 2}, + {QNN_DATATYPE_SFIXED_POINT_32, 4}, + {QNN_DATATYPE_UFIXED_POINT_8, 1}, + {QNN_DATATYPE_UFIXED_POINT_16, 2}, + {QNN_DATATYPE_UFIXED_POINT_32, 4}, + }; + + auto pos = data_type_to_size.find(data_type); + ORT_ENFORCE(pos != data_type_to_size.end(), "Unknown QNN data type", data_type); + return pos->second; +} +size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { + const static std::unordered_map<ONNXTensorElementDataType, size_t> elem_type_to_size = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, sizeof(Int4x2)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, sizeof(UInt4x2)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, sizeof(double)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, sizeof(bool)}}; + + auto pos = elem_type_to_size.find(elem_type); + ORT_ENFORCE(pos != elem_type_to_size.end(), "Unknown element type", elem_type); + return pos->second; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { + switch (scalar.dataType) { + case QNN_DATATYPE_INT_8: + out << static_cast<int32_t>(scalar.int8Value); + break; + case QNN_DATATYPE_INT_16: + out << scalar.int16Value; + break; + case QNN_DATATYPE_INT_32: + out << scalar.int32Value; + break; + case QNN_DATATYPE_INT_64: + out << "int64_t is not supported"; + break; + case QNN_DATATYPE_UINT_8: + out << static_cast<uint32_t>(scalar.uint8Value); + break; + case QNN_DATATYPE_UINT_16: + out << scalar.uint16Value; + break; + case QNN_DATATYPE_UINT_32: + out << scalar.uint32Value; + break; + case QNN_DATATYPE_UINT_64: + out << "uint64_t is not supported"; + break; + case QNN_DATATYPE_FLOAT_16: + // FLOAT_16 values are not printed. + break; + case QNN_DATATYPE_FLOAT_32: + out << scalar.floatValue; + break; + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_32: + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_32: + out << "fixed-point data is not supported"; + break; + case QNN_DATATYPE_BOOL_8: + out << static_cast<uint32_t>(scalar.bool8Value); + break; + default: + ORT_THROW("Unknown Qnn Data type"); + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { + switch (data_type) { + case QNN_DATATYPE_INT_8: + out << "QNN_DATATYPE_INT_8"; + break; +
case QNN_DATATYPE_INT_16: + out << "QNN_DATATYPE_INT_16"; + break; + case QNN_DATATYPE_INT_32: + out << "QNN_DATATYPE_INT_32"; + break; + case QNN_DATATYPE_INT_64: + out << "QNN_DATATYPE_INT_64"; + break; + case QNN_DATATYPE_UINT_8: + out << "QNN_DATATYPE_UINT_8"; + break; + case QNN_DATATYPE_UINT_16: + out << "QNN_DATATYPE_UINT_16"; + break; + case QNN_DATATYPE_UINT_32: + out << "QNN_DATATYPE_UINT_32"; + break; + case QNN_DATATYPE_UINT_64: + out << "QNN_DATATYPE_UINT_64"; + break; + case QNN_DATATYPE_FLOAT_16: + out << "QNN_DATATYPE_FLOAT_16"; + break; + case QNN_DATATYPE_FLOAT_32: + out << "QNN_DATATYPE_FLOAT_32"; + break; + case QNN_DATATYPE_SFIXED_POINT_8: + out << "QNN_DATATYPE_SFIXED_POINT_8"; + break; + case QNN_DATATYPE_SFIXED_POINT_16: + out << "QNN_DATATYPE_SFIXED_POINT_16"; + break; + case QNN_DATATYPE_SFIXED_POINT_32: + out << "QNN_DATATYPE_SFIXED_POINT_32"; + break; + case QNN_DATATYPE_UFIXED_POINT_8: + out << "QNN_DATATYPE_UFIXED_POINT_8"; + break; + case QNN_DATATYPE_UFIXED_POINT_16: + out << "QNN_DATATYPE_UFIXED_POINT_16"; + break; + case QNN_DATATYPE_UFIXED_POINT_32: + out << "QNN_DATATYPE_UFIXED_POINT_32"; + break; + case QNN_DATATYPE_BOOL_8: + out << "QNN_DATATYPE_BOOL_8"; + break; + case QNN_DATATYPE_SFIXED_POINT_4: + out << "QNN_DATATYPE_SFIXED_POINT_4"; + break; + case QNN_DATATYPE_UFIXED_POINT_4: + out << "QNN_DATATYPE_UFIXED_POINT_4"; + break; + default: + ORT_THROW("Unknown Qnn Data type"); + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_Definition_t& definition) { + switch (definition) { + case QNN_DEFINITION_IMPL_GENERATED: + out << "QNN_DEFINITION_IMPL_GENERATED"; + break; + case QNN_DEFINITION_DEFINED: + out << "QNN_DEFINITION_DEFINED"; + break; + case QNN_DEFINITION_UNDEFINED: + out << "QNN_DEFINITION_UNDEFINED"; + break; + default: + out << "Undefined"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_QuantizationEncoding_t& encoding) { + switch (encoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + out << "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET"; + break; + case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: + out << "QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET"; + break; + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: + out << "QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET"; + break; + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: + out << "QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET"; + break; + case QNN_QUANTIZATION_ENCODING_UNDEFINED: + out << "QNN_QUANTIZATION_ENCODING_UNDEFINED"; + break; + default: + out << "Unknown quantization encoding"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_QuantizeParams_t& quantize_params) { + out << " encodingDefinition=" << quantize_params.encodingDefinition; + out << " quantizationEncoding=" << quantize_params.quantizationEncoding; + if (quantize_params.encodingDefinition == QNN_DEFINITION_IMPL_GENERATED || + quantize_params.encodingDefinition == QNN_DEFINITION_DEFINED) { + if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + out << " scale=" << quantize_params.scaleOffsetEncoding.scale; + out << " offset=" << quantize_params.scaleOffsetEncoding.offset; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + out << " bitwidth=" << quantize_params.bwScaleOffsetEncoding.bitwidth; + out << " scale=" << quantize_params.bwScaleOffsetEncoding.scale; + out << " offset=" << quantize_params.bwScaleOffsetEncoding.offset; + } else if 
(quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + out << " axis=" << quantize_params.axisScaleOffsetEncoding.axis; + size_t num_elems = quantize_params.axisScaleOffsetEncoding.numScaleOffsets; + out << " scales=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.axisScaleOffsetEncoding.scaleOffset[i].scale << (i == num_elems - 1 ? "" : " "); + } + out << ") offsets=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.axisScaleOffsetEncoding.scaleOffset[i].offset << (i == num_elems - 1 ? "" : " "); + } + out << ")"; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + out << " axis=" << quantize_params.bwAxisScaleOffsetEncoding.axis; + out << " bw=" << quantize_params.bwAxisScaleOffsetEncoding.bitwidth; + size_t num_elems = quantize_params.bwAxisScaleOffsetEncoding.numElements; + out << " scales=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.bwAxisScaleOffsetEncoding.scales[i] << (i == num_elems - 1 ? "" : " "); + } + out << ") offsets=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.bwAxisScaleOffsetEncoding.offsets[i] << (i == num_elems - 1 ? "" : " "); + } + out << ")"; + } else { + out << " encoding not supported."; + } + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_TensorType_t& tensor_type) { + switch (tensor_type) { + case QNN_TENSOR_TYPE_APP_WRITE: + out << "QNN_TENSOR_TYPE_APP_WRITE"; + break; + case QNN_TENSOR_TYPE_APP_READ: + out << "QNN_TENSOR_TYPE_APP_READ"; + break; + case QNN_TENSOR_TYPE_APP_READWRITE: + out << "QNN_TENSOR_TYPE_APP_READWRITE"; + break; + case QNN_TENSOR_TYPE_NATIVE: + out << "QNN_TENSOR_TYPE_NATIVE"; + break; + case QNN_TENSOR_TYPE_STATIC: + out << "QNN_TENSOR_TYPE_STATIC"; + break; + case QNN_TENSOR_TYPE_NULL: + out << "QNN_TENSOR_TYPE_NULL"; + break; + default: + out << "Unsupported type"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_TensorMemType_t& mem_type) { + switch (mem_type) { + case QNN_TENSORMEMTYPE_RAW: + out << "QNN_TENSORMEMTYPE_RAW"; + break; + case QNN_TENSORMEMTYPE_MEMHANDLE: + out << "QNN_TENSORMEMTYPE_MEMHANDLE"; + break; + default: + out << "Unsupported mem type"; + } + return out; +} +template <typename T> +std::ostream& operator<<(std::ostream& out, const Qnn_ClientBuffer_t& client_buffer) { + T* data = reinterpret_cast<T*>(client_buffer.data); + out << " dataSize=" << client_buffer.dataSize; + uint32_t count = client_buffer.dataSize / sizeof(T); + const bool truncate = count > 100; + + count = truncate ? 100 : count; // limit output to the first 100 elements + out << " clientBuf=("; + for (uint32_t i = 0; i < count; i++) { + if constexpr (sizeof(T) == 1) { + out << static_cast<int32_t>(data[i]) << " "; + } else { + out << data[i] << " "; + } + } + out << (truncate ? "..." 
: "") << ")"; + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor) { + out << " name=" << GetQnnTensorName(tensor); + out << " id=" << GetQnnTensorID(tensor); + out << " version=" << tensor.version; + out << " type=" << GetQnnTensorType(tensor); + out << " dataFormat=" << GetQnnTensorDataFormat(tensor); + out << " dataType=" << GetQnnTensorDataType(tensor); + out << " rank=" << GetQnnTensorRank(tensor); + out << " dimensions=("; + for (uint32_t i = 0; i < GetQnnTensorRank(tensor); i++) { + out << GetQnnTensorDims(tensor)[i] << " "; + } + out << ")"; + out << " memType=" << GetQnnTensorMemType(tensor); +// TODO: the code below has compilation errors with the latest ABSL +#if 0 + if (GetQnnTensorMemType(tensor) == QNN_TENSORMEMTYPE_RAW) { + if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_FLOAT_32) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_UINT_32 || + GetQnnTensorDataType(tensor) == QNN_DATATYPE_UFIXED_POINT_32) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_INT_32 || + GetQnnTensorDataType(tensor) == QNN_DATATYPE_SFIXED_POINT_32) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_UINT_16 || + GetQnnTensorDataType(tensor) == QNN_DATATYPE_UFIXED_POINT_16) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_INT_16 || + GetQnnTensorDataType(tensor) == QNN_DATATYPE_SFIXED_POINT_16) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_UINT_8 || + GetQnnTensorDataType(tensor) == QNN_DATATYPE_UFIXED_POINT_8) { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } else { + operator<< (out, GetQnnTensorClientBuf(tensor)); + } + } +#endif + out << " quantizeParams:" << GetQnnTensorQParams(tensor); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_ParamType_t& param_type) { + switch (param_type) { + case QNN_PARAMTYPE_SCALAR: + out << "QNN_PARAMTYPE_SCALAR"; + break; + case QNN_PARAMTYPE_TENSOR: + out << "QNN_PARAMTYPE_TENSOR"; + break; + default: + out << "Unknown type"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param) { + out << " type=" << qnn_param.paramType; + out << " name=" << qnn_param.name; + if (qnn_param.paramType == QNN_PARAMTYPE_TENSOR) { + out << qnn_param.tensorParam; + } else { + out << " value=" << qnn_param.scalarParam; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const QnnOpConfigWrapper& op_conf_wrapper) { + out << "Qnn_OpConfig node name: " << op_conf_wrapper.GetOpName() + << " package_name: " << op_conf_wrapper.GetPackageName() + << " QNN_op_type: " << op_conf_wrapper.GetTypeName() + << " num_of_inputs: " << op_conf_wrapper.GetInputsNum() + << " num_of_outputs: " << op_conf_wrapper.GetOutputsNum() + << " num_of_params: " << op_conf_wrapper.GetParamsNum(); + + out << std::endl + << " node_inputs:" << std::endl; + for (uint32_t i = 0; i < op_conf_wrapper.GetInputsNum(); i++) { + out << op_conf_wrapper.GetInputTensors()[i] << std::endl; + } + out << " node_outputs:" << std::endl; + for (uint32_t i = 0; i < op_conf_wrapper.GetOutputsNum(); i++) { + out << op_conf_wrapper.GetOutputTensors()[i] << std::endl; + } + out << " node_params:" << std::endl; + for (uint32_t i = 0; i < op_conf_wrapper.GetParamsNum(); i++) { + out << 
+
+Status GetQnnDataType(const bool is_quantized_tensor, const ONNX_NAMESPACE::TypeProto* type_proto,
+                      Qnn_DataType_t& tensor_data_type) {
+  if (!type_proto || !type_proto->tensor_type().has_elem_type()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The tensor doesn't have elem_type.");
+  }
+
+  int32_t onnx_data_type = type_proto->tensor_type().elem_type();
+  ORT_RETURN_IF_NOT(OnnxDataTypeToQnnDataType(onnx_data_type, tensor_data_type, is_quantized_tensor),
+                    "Failed to map Onnx data type to Qnn data type!");
+
+  return Status::OK();
+}
+
+//const std::string& GetNodeName(const NodeUnit& node_unit) {
+//  const std::string& node_name = node_unit.Name();
+//  if (node_name.empty()) {
+//    return node_unit.Outputs()[0].node_arg.Name();
+//  }
+//
+//  return node_name;
+//}
+
+bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized) {
+  const std::unordered_map<int32_t, Qnn_DataType_t> onnx_to_qnn_data_type = {
+      {ONNX_NAMESPACE::TensorProto_DataType_INT8, QNN_DATATYPE_INT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT16, QNN_DATATYPE_INT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT32, QNN_DATATYPE_INT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT64, QNN_DATATYPE_INT_64},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT8, QNN_DATATYPE_UINT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT16, QNN_DATATYPE_UINT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT32, QNN_DATATYPE_UINT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT64, QNN_DATATYPE_UINT_64},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, QNN_DATATYPE_FLOAT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, QNN_DATATYPE_FLOAT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_BOOL, QNN_DATATYPE_BOOL_8},
+  };
+
+  const std::unordered_map<int32_t, Qnn_DataType_t> onnx_to_qnn_data_type_quantized = {
+      {ONNX_NAMESPACE::TensorProto_DataType_INT4, QNN_DATATYPE_SFIXED_POINT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT8, QNN_DATATYPE_SFIXED_POINT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT16, QNN_DATATYPE_SFIXED_POINT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT32, QNN_DATATYPE_SFIXED_POINT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_INT64, QNN_DATATYPE_INT_64},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT4, QNN_DATATYPE_UFIXED_POINT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT8, QNN_DATATYPE_UFIXED_POINT_8},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT16, QNN_DATATYPE_UFIXED_POINT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT32, QNN_DATATYPE_UFIXED_POINT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_UINT64, QNN_DATATYPE_UINT_64},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, QNN_DATATYPE_FLOAT_16},
+      {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, QNN_DATATYPE_FLOAT_32},
+      {ONNX_NAMESPACE::TensorProto_DataType_BOOL, QNN_DATATYPE_BOOL_8},
+  };
+
+  const auto do_type_mapping = [](const std::unordered_map<int32_t, Qnn_DataType_t>& mapping_table,
+                                  const int32_t onnx_data_type,
+                                  Qnn_DataType_t& qnn_data_type) -> bool {
+    auto pos = mapping_table.find(onnx_data_type);
+    if (pos == mapping_table.end()) {
+      return false;
+    }
+    qnn_data_type = pos->second;
+    return true;
+  };
+
+  if (is_quantized) {
+    return do_type_mapping(onnx_to_qnn_data_type_quantized, onnx_data_type, qnn_data_type);
+  } else {
+    return do_type_mapping(onnx_to_qnn_data_type, onnx_data_type, qnn_data_type);
+  }
+}
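OnnxDataTypeToQnnDataType resolves the same ONNX element type to a different QNN type depending on whether the tensor carries quantization parameters. A minimal sketch of both lookup paths (illustrative only, not part of the patch):

    Qnn_DataType_t qnn_type = QNN_DATATYPE_UNDEFINED;
    // float path: ONNX UINT8 maps to the plain integer type -> QNN_DATATYPE_UINT_8
    utils::OnnxDataTypeToQnnDataType(ONNX_NAMESPACE::TensorProto_DataType_UINT8, qnn_type);
    // quantized path: the same ONNX type maps to a fixed-point type -> QNN_DATATYPE_UFIXED_POINT_8
    utils::OnnxDataTypeToQnnDataType(ONNX_NAMESPACE::TensorProto_DataType_UINT8, qnn_type, /*is_quantized*/ true);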
+
+std::pair<float, float> CheckMinMax(float rmin, float rmax) {
+  // Ensure a minimum range of 0.0001 (required by QNN)
+  rmax = std::max(rmax, rmin + 0.0001f);
+
+  // Both QNN and ORT require the range to include 0.0f
+  rmin = std::min(rmin, 0.0f);
+  rmax = std::max(rmax, 0.0f);
+
+  return std::make_pair(rmin, rmax);
+}
+
+template <typename T>
+Status GetQminQmax(const Qnn_DataType_t qnn_data_type,
+                   T& qmin,
+                   T& qmax) {
+  if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) {
+    qmin = static_cast<T>(std::numeric_limits<int8_t>::min());
+    qmax = static_cast<T>(std::numeric_limits<int8_t>::max());
+  } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) {
+    qmin = static_cast<T>(std::numeric_limits<uint8_t>::min());
+    qmax = static_cast<T>(std::numeric_limits<uint8_t>::max());
+  } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) {
+    qmin = static_cast<T>(std::numeric_limits<int16_t>::min());
+    qmax = static_cast<T>(std::numeric_limits<int16_t>::max());
+  } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
+    qmin = static_cast<T>(std::numeric_limits<uint16_t>::min());
+    qmax = static_cast<T>(std::numeric_limits<uint16_t>::max());
+  } else {
+    ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type);
+  }
+  return Status::OK();
+}
+
+Status GetQuantParams(float rmin,
+                      float rmax,
+                      const Qnn_DataType_t qnn_data_type,
+                      float& scale,
+                      int& zero_point) {
+  std::tie(rmin, rmax) = CheckMinMax(rmin, rmax);
+  float qmin = 0.0f;
+  float qmax = 255.0f;
+  ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax));
+
+  scale = (rmax - rmin) / (qmax - qmin);
+  const float initial_zero_point = qmin - (rmin / scale);
+  // NOTE: RoundHalfToEven from core/util/qmath.h is unavailable in this standalone sample
+  // (its include is commented out above); std::round stands in here so zero_point is not
+  // read uninitialized below.
+  zero_point = static_cast<int>(std::round(Saturate(qmax, qmin, initial_zero_point)));
+  // To match QNN quantization definition
+  zero_point = 0 - zero_point;
+  return Status::OK();
+}
+
+double Dequantize(int32_t offset, float scale, const double quant_value) {
+  double offset_d = static_cast<double>(offset);
+  double scale_d = static_cast<double>(scale);
+  return (quant_value + offset_d) * scale_d;
+}
+
+Status Quantize(const double double_value,
+                const float scale,
+                const int zero_point,
+                const Qnn_DataType_t qnn_data_type,
+                int& quant_value) {
+  int qmin = 0;
+  int qmax = 255;
+  ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax));
+  quant_value = Saturate(qmax, qmin, static_cast<int>(std::round((double_value / scale) - zero_point)));
+  return Status::OK();
+}
+
+} // namespace utils
+} // namespace qnn
+} // namespace onnxruntime
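The helpers above follow QNN's convention of storing the zero point negated: GetQuantParams flips its sign on the way out, and Dequantize therefore adds the offset back rather than subtracting it. A sketch of a round trip through them, assuming 8-bit unsigned quantization (values are illustrative):

    float scale = 0.0f;
    int zero_point = 0;
    // derive scale/offset for activations observed in [-1.0f, 3.0f]
    utils::GetQuantParams(-1.0f, 3.0f, QNN_DATATYPE_UFIXED_POINT_8, scale, zero_point);
    int q = 0;
    utils::Quantize(0.5, scale, zero_point, QNN_DATATYPE_UFIXED_POINT_8, q);
    // (q + zero_point) * scale recovers approximately 0.5
    double restored = utils::Dequantize(zero_point, scale, static_cast<double>(q));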
diff --git a/samples/qnnEp/builder/qnn_utils.h b/samples/qnnEp/builder/qnn_utils.h
new file mode 100644
index 0000000000000..f8404d4856748
--- /dev/null
+++ b/samples/qnnEp/builder/qnn_utils.h
@@ -0,0 +1,110 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "QnnTypes.h"
+#include "core/session/onnxruntime_cxx_api.h"
+//#include "core/framework/node_unit.h"
+//#include "core/util/qmath.h"
+#include
+#include "onnx/onnx_pb.h"
+
+namespace onnxruntime {
+namespace qnn {
+class QnnOpConfigWrapper;
+
+namespace utils {
+size_t GetElementSizeByType(const Qnn_DataType_t& data_type);
+
+size_t GetElementSizeByType(ONNXTensorElementDataType elem_type);
+
+// TODO: make these work with Wrappers?
+std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param);
+std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor);
+std::ostream& operator<<(std::ostream& out, const QnnOpConfigWrapper& op_conf_wrapper);
+
+Status GetQnnDataType(const bool is_quantized_tensor, const ONNX_NAMESPACE::TypeProto* type_proto,
+                      Qnn_DataType_t& tensor_data_type);
+
+//const std::string& GetNodeName(const NodeUnit& node_unit);
+
+bool OnnxDataTypeToQnnDataType(const int32_t data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized = false);
+
+//inline Status GetOnnxTensorElemDataType(const NodeArg& node_arg, /*out*/ int32_t& onnx_data_type) {
+//  auto type_proto = node_arg.TypeAsProto();
+//  ORT_RETURN_IF_NOT(type_proto != nullptr && type_proto->has_tensor_type() && type_proto->tensor_type().has_elem_type(),
+//                    "NodeArg must have a tensor TypeProto");
+//  onnx_data_type = type_proto->tensor_type().elem_type();
+//  return Status::OK();
+//}
+
+template <typename T>
+static Status InvertPerm(gsl::span<const T> perm, /*out*/ gsl::span<T> perm_inv) {
+  static_assert(std::is_integral<T>::value, "permutation arrays must contain integer elements");
+
+  size_t rank = perm.size();
+  ORT_RETURN_IF_NOT(perm_inv.size() == rank, "perm.size() != perm_inv.size()");
+
+  for (size_t i = 0; i < rank; ++i) {
+    size_t j = static_cast<size_t>(perm[i]);
+    ORT_RETURN_IF_NOT(j < rank, "perm element out of range [0, rank - 1]");
+    perm_inv[j] = static_cast<T>(i);
+  }
+
+  return Status::OK();
+}
+
+// Utility function that checks if an array of strings contains a specific string.
+// Used to validate ONNX operator attributes.
+template <size_t N>
+static bool ArrayHasString(const std::array<std::string_view, N>& strings, std::string_view str) {
+  for (auto s : strings) {
+    if (s == str) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::pair<float, float> CheckMinMax(float rmin, float rmax);
+
+template <typename T>
+Status GetQminQmax(const Qnn_DataType_t qnn_data_type, T& qmin, T& qmax);
+
+template <typename T>
+inline T Saturate(const T qmax,
+                  const T qmin,
+                  const T quant_value) {
+  if (quant_value > qmax) {
+    return qmax;
+  } else if (quant_value < qmin) {
+    return qmin;
+  } else {
+    return quant_value;
+  }
+}
+
+Status GetQuantParams(float rmin,
+                      float rmax,
+                      const Qnn_DataType_t qnn_data_type,
+                      float& scale,
+                      int& zero_point);
+
+double Dequantize(int32_t offset, float scale, const double quant_value);
+
+Status Quantize(const double double_value,
+                const float scale,
+                const int zero_point,
+                const Qnn_DataType_t qnn_data_type,
+                int& quant_value);
+
+} // namespace utils
+} // namespace qnn
+} // namespace onnxruntime
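InvertPerm computes the inverse of a transpose permutation: perm_inv[perm[i]] = i, so applying perm_inv after perm restores the original axis order. A small worked example (illustrative only):

    std::array<int64_t, 3> perm = {2, 0, 1};
    std::array<int64_t, 3> perm_inv{};
    utils::InvertPerm<int64_t>(gsl::make_span(perm), gsl::make_span(perm_inv));
    // perm_inv is now {1, 2, 0}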
diff --git a/samples/qnnEp/qnn_execution_provider.cc b/samples/qnnEp/qnn_execution_provider.cc
new file mode 100644
index 0000000000000..68707252cbf82
--- /dev/null
+++ b/samples/qnnEp/qnn_execution_provider.cc
@@ -0,0 +1,54 @@
+#include "qnn_execution_provider.h"
+#include "builder/qnn_def.h"
+#include
+#include
+namespace onnxruntime {
+
+static void ParseProfilingLevel(std::string profiling_level_string,
+                                qnn::ProfilingLevel& profiling_level) {
+  std::transform(profiling_level_string.begin(),
+                 profiling_level_string.end(),
+                 profiling_level_string.begin(),
+                 [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
+//  LOGS_DEFAULT(INFO) << "profiling_level: " << profiling_level_string;
+  if (profiling_level_string == "off") {
+    profiling_level = qnn::ProfilingLevel::OFF;
+  } else if (profiling_level_string == "basic") {
+    profiling_level = qnn::ProfilingLevel::BASIC;
+  } else if (profiling_level_string == "detailed") {
+    profiling_level = qnn::ProfilingLevel::DETAILED;
+  } else {
+//  LOGS_DEFAULT(WARNING) << "Profiling level not valid.";
+  }
+}
+
+QNNExecutionProvider::QNNExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
+  type = ep_type;
+  OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
+  };
+
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  };
+}
+
+QNNExecutionProviderFactory::QNNExecutionProviderFactory() {
+  OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* {
+    ProviderOptions options;
+    for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i];
+    std::unique_ptr<QNNExecutionProvider> ret = std::make_unique<QNNExecutionProvider>("QNNExecutionProvider", std::move(options));
+    return ret.release();
+  };
+}
+
+} // namespace onnxruntime
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+OrtExecutionProviderFactory* RegisterCustomEp() {
+  std::unique_ptr<QNNExecutionProviderFactory> ret = std::make_unique<QNNExecutionProviderFactory>();
+  return ret.release();
+}
+#ifdef __cplusplus
+}
+#endif
diff --git a/samples/qnnEp/qnn_execution_provider.h b/samples/qnnEp/qnn_execution_provider.h
new file mode 100644
index 0000000000000..91fa97158de0b
--- /dev/null
+++ b/samples/qnnEp/qnn_execution_provider.h
@@ -0,0 +1,33 @@
+#pragma once
+#include "core/session/onnxruntime_c_api.h"
+#include "core/framework/provider_options.h"
+#include <string>
+
+#ifdef _WIN32
+#define EXPORT_API __declspec(dllexport)
+#else
+#define EXPORT_API
+#endif
+
+namespace onnxruntime {
+
+struct QNNExecutionProvider : public OrtExecutionProvider {
+  QNNExecutionProvider(const char* ep_type, const ProviderOptions& provider_options);
+private:
+  std::string context_cache_path_cfg_ = "";
+};
+
+struct QNNExecutionProviderFactory : public OrtExecutionProviderFactory {
+  QNNExecutionProviderFactory();
+};
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp();
+
+#ifdef __cplusplus
+}
+#endif
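With the factory and the exported RegisterCustomEp entry point in place, an application loads this sample through the environment-level registration API from earlier in the series. A sketch (library path, option key, and the trailing key/value-array arguments of SessionOptionsAppendOrtExecutionProvider are assumptions for illustration):

    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    OrtEnv* env = nullptr;
    api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "qnn_sample", &env);
    // loads the shared library and calls its exported RegisterCustomEp()
    api->RegisterOrtExecutionProviderLibrary("./libQnnEp.so", env, "QNNExecutionProvider");
    OrtSessionOptions* so = nullptr;
    api->CreateSessionOptions(&so);
    const char* keys[] = {"profiling_level"};   // parsed by ParseProfilingLevel above
    const char* values[] = {"basic"};
    api->SessionOptionsAppendOrtExecutionProvider(so, "QNNExecutionProvider", env, keys, values, 1);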
From 740a6877c7663544ebe1f4bb783b99c34a881bbe Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Tue, 13 Aug 2024 21:15:22 +0000
Subject: [PATCH 14/81] more graph/node C API

---
 .../core/session/onnxruntime_c_api.h          |  52 +++++--
 onnxruntime/core/session/onnxruntime_c_api.cc | 144 ++++++++++++++++++
 onnxruntime/core/session/ort_apis.h           |  38 +++++
 3 files changed, 220 insertions(+), 14 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 8db7a5401f53d..769a3b17b67ce 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -310,7 +310,6 @@ ORT_RUNTIME_CLASS(Node);
 ORT_RUNTIME_CLASS(GraphViewer);
 ORT_RUNTIME_CLASS(KernelRegistry);
 ORT_RUNTIME_CLASS(TypeConstraints);
-ORT_RUNTIME_CLASS(NodeUnit);
 
 #ifdef _WIN32
 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -744,19 +743,6 @@ typedef struct OrtExecutionProviderFactory {
   OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size);
 } OrtExecutionProviderFactory;
 
-typedef struct OrtNodeUnit {
-  enum Type {
-    SingleNode,
-    QDQGroup,
-  } type;
-  OrtNode** dq_nodes;
-  size_t dq_nodes_len;
-  OrtNode** q_nodes;
-  size_t q_nodes_len;
-  OrtNode* target_node;
-  size_t input_edge_count;
-} OrtNodeUnit;
-
 /** \brief Thread work loop function
  *
  * Onnxruntime will provide the working loop on custom thread creation
@@ -4745,8 +4731,26 @@ struct OrtApi {
   ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node);
 
+  ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers);  // TODO(leca): ValueConsumers::comprehensive ?
+
+  ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer);
+
+  ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);
+
+  ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description);
+
+  ORT_API2_STATUS(OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain);
+
+  ORT_API2_STATUS(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version);
+
+  ORT_API2_STATUS(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type);
+
   ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type);
 
+  ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size);
+
+  ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name);
+
   ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size);
 
   ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name);
@@ -4755,6 +4759,26 @@ struct OrtApi {
 
   ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name);
 
+  ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats);
+
+  ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs);
+
+  ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs);
+
   ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
 
   ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints);
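The producer/consumer pair added here lets an EP traverse the graph by tensor name without ever touching onnxruntime::Node directly. A sketch of walking from an output name to its producing and consuming nodes (error handling elided; "conv_out" is an illustrative tensor name):

    const OrtNode* producer = nullptr;
    api->OrtGraph_GetNodeProducingOutput(graph, "conv_out", &producer);
    size_t consumer_count = 0;
    const OrtNode** consumers = nullptr;
    api->OrtGraph_GetNodesConsumingInput(graph, "conv_out", &consumer_count, &consumers);
    for (size_t i = 0; i < consumer_count; i++) {
      const char* op_type = nullptr;
      api->OrtNode_GetOpType(consumers[i], &op_type);
    }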
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 90af21f1337c1..ee158c4bb4d24 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2404,12 +2404,71 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, s
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  std::vector<const ::onnxruntime::Node*> consumer_nodes = graph_viewer->GetConsumerNodes(input_name);
+  *len = consumer_nodes.size();
+  *consumers = new const OrtNode*[*len];
+  for (size_t i = 0; i < consumer_nodes.size(); i++) (*consumers)[i] = reinterpret_cast<const OrtNode*>(consumer_nodes[i]);
+
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *producer = reinterpret_cast<const OrtNode*>(graph_viewer->GetProducerNode(output_name));
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *name = n->Name().c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *description = n->Description().c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *domain = n->Domain().c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *since_version = n->SinceVersion();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *ep_type = n->GetExecutionProviderType().c_str();
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *op_type = n->OpType().c_str();
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *input_size = n->ImplicitInputDefs().size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->ImplicitInputDefs().size());
+  *ith_input_name = n->ImplicitInputDefs()[i]->Name().c_str();
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *input_size = n->InputDefs().size();
@@ -2436,6 +2495,72 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *index = n->Index();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *attr_size = n->GetAttributes().size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *count = n->GetAttributes().count(key);
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *int_size = n->GetAttributes().at(key).ints_size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *float_size = n->GetAttributes().at(key).floats_size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *str_size = n->GetAttributes().at(key).strings_size();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *ints = n->GetAttributes().at(key).ints(i);
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *floats = n->GetAttributes().at(key).floats(i);
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  *strs = n->GetAttributes().at(key).strings(i).c_str();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  std::vector<gsl::not_null<const Graph*>> subg = n->GetSubgraphs();
+  *len = subg.size();
+  *subgraphs = new const OrtGraphViewer*[*len];
+  for (size_t i = 0; i < subg.size(); i++) {
+    const ::onnxruntime::GraphViewer* graph_viewer = new const ::onnxruntime::GraphViewer(*subg[i]);
+    (*subgraphs)[i] = reinterpret_cast<const OrtGraphViewer*>(graph_viewer);
+  }
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints) {
   KernelRegistry* kr = reinterpret_cast<KernelRegistry*>(kernel_registry);
   KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op, type_constraints);
@@ -2842,11 +2967,30 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtGraph_IsConstantInitializer,
     &OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder,
     &OrtApis::OrtGraph_GetOrtNode,
+    &OrtApis::OrtGraph_GetNodesConsumingInput,
+    &OrtApis::OrtGraph_GetNodeProducingOutput,
+    &OrtApis::OrtNode_GetName,
+    &OrtApis::OrtNode_GetDescription,
+    &OrtApis::OrtNode_GetDomain,
+    &OrtApis::OrtNode_SinceVersion,
+    &OrtApis::OrtNode_GetExecutionProviderType,
     &OrtApis::OrtNode_GetOpType,
+    &OrtApis::OrtNode_GetImplicitInputSize,
+    &OrtApis::OrtNode_GetIthImplicitInputName,
     &OrtApis::OrtNode_GetInputSize,
     &OrtApis::OrtNode_GetIthInputName,
     &OrtApis::OrtNode_GetOutputSize,
     &OrtApis::OrtNode_GetIthOutputName,
+    &OrtApis::OrtNode_GetIndex,
+    &OrtApis::OrtNode_GetAttributeSize,
+    &OrtApis::OrtNode_GetAttributeKeyCount,
+    &OrtApis::OrtNode_GetAttributeIntSize,
&OrtApis::OrtNode_GetAttributeFloatSize, + &OrtApis::OrtNode_GetAttributeStringSize, + &OrtApis::OrtNode_GetAttributeIthInt, + &OrtApis::OrtNode_GetAttributeIthFloat, + &OrtApis::OrtNode_GetAttributeIthStr, + &OrtApis::OrtNode_GetSubgraphs, &OrtApis::OrtKernelRegistry_RegisterKernel, &OrtApis::CreateOrtTypeConstraints, &OrtApis::AddTypeConstraint, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 10d42b49c4124..3df5e8b7df8a7 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -535,8 +535,26 @@ ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphView ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); +ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers); + +ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer); + +ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); + +ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); + +ORT_API_STATUS_IMPL(OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain); + +ORT_API_STATUS_IMPL(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version); + +ORT_API_STATUS_IMPL(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type); + ORT_API_STATUS_IMPL(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); +ORT_API_STATUS_IMPL(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size); + +ORT_API_STATUS_IMPL(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); + ORT_API_STATUS_IMPL(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); @@ -545,6 +563,26 @@ ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* ou ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); +ORT_API_STATUS_IMPL(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats); + +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs); + +ORT_API_STATUS_IMPL(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs); + ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, 
                    OrtTypeConstraints* type_constraints);
 
 ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints);
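Taken together, the attribute getters follow a count-then-index pattern: check that a key exists, query how many elements of a given kind it holds, then read them one by one. A sketch of reading a repeated-int attribute through this surface (attribute name "axes" is illustrative; all calls return OrtStatusPtr, ignored here):

    const OrtNode* node = nullptr;
    api->OrtGraph_GetOrtNode(graph, node_index, &node);
    size_t key_count = 0;
    api->OrtNode_GetAttributeKeyCount(node, "axes", &key_count);
    if (key_count > 0) {
      int int_size = 0;
      api->OrtNode_GetAttributeIntSize(node, "axes", &int_size);
      for (int i = 0; i < int_size; i++) {
        int64_t v = 0;
        api->OrtNode_GetAttributeIthInt(node, "axes", i, &v);
      }
    }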
From dad6397e7e602a80f71a0542364eceb83c0bb0bd Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Thu, 15 Aug 2024 21:12:32 +0000
Subject: [PATCH 15/81] stream support

---
 .../core/session/onnxruntime_c_api.h          |  7 ++++
 onnxruntime/core/framework/provider_adapter.h | 10 +++++
 samples/tensorRTEp/CMakeLists.txt             | 28 +++++++++++++
 .../tensorRTEp/tensorrt_execution_provider.cc | 41 +++++++++++++++++++
 .../tensorRTEp/tensorrt_execution_provider.h  | 33 +++++++++++++++
 5 files changed, 119 insertions(+)
 create mode 100644 samples/tensorRTEp/CMakeLists.txt
 create mode 100644 samples/tensorRTEp/tensorrt_execution_provider.cc
 create mode 100644 samples/tensorRTEp/tensorrt_execution_provider.h

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 769a3b17b67ce..1a73473f8cf9f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -310,6 +310,7 @@ ORT_RUNTIME_CLASS(Node);
 ORT_RUNTIME_CLASS(GraphViewer);
 ORT_RUNTIME_CLASS(KernelRegistry);
 ORT_RUNTIME_CLASS(TypeConstraints);
+ORT_RUNTIME_CLASS(Device);
 
 #ifdef _WIN32
 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -695,6 +696,11 @@ typedef struct OrtApiBase OrtApiBase;
 */
 ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION;
 
+typedef struct OrtCreateStream {
+  int device_type;
+  void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*);
+} OrtCreateStream;
+
 typedef struct OrtMetaDef {
   const char* name;
   const char* domain;
@@ -737,6 +743,7 @@ typedef struct OrtExecutionProvider {
   void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info);
   void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry);
   const char* type;
+  OrtCreateStream* create_stream;
 } OrtExecutionProvider;
 
 typedef struct OrtExecutionProviderFactory {
diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h
index 2b2ffe8101f89..d287c8dbcfad2 100644
--- a/onnxruntime/core/framework/provider_adapter.h
+++ b/onnxruntime/core/framework/provider_adapter.h
@@ -85,6 +85,16 @@ class ExecutionProviderAdapter : public IExecutionProvider {
     return Status::OK();
   }
 
+  virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap&) const override {
+    if (ep_impl_->create_stream) {
+      CreateStreamFn csf = [&](const OrtDevice& device) -> std::unique_ptr<Stream> {
+        void* stream = ep_impl_->create_stream->CreateStreamFunc(&device);
+        return std::make_unique(stream, device);
+      };
+      stream_handle_registry.RegisterCreateStreamFn(static_cast<OrtDevice::DeviceType>(ep_impl_->create_stream->device_type), csf);
+    }
+  }
+
   virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override { return kernel_registry_; }
 private:
   OrtExecutionProvider* ep_impl_;
diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt
new file mode 100644
index 0000000000000..d6fbc41cf89e6
--- /dev/null
+++ b/samples/tensorRTEp/CMakeLists.txt
@@ -0,0 +1,28 @@
+# usage:
+# cd build/
+# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=90 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
+# cmake --build ./
+cmake_minimum_required(VERSION 3.26)
+project(TensorRTEp VERSION 1.0)
+set(CMAKE_CXX_STANDARD 17)
+enable_language(CUDA)
+file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT)
+find_package(CUDAToolkit REQUIRED)
+
+#add_definitions(-DONNX_NAMESPACE=onnx)
+#add_definitions(-DONNX_ML)
+add_library(TensorRTEp SHARED tensorrt_execution_provider.cc)
+target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime"
+                                             "/usr/local/cuda/include")
+#                                            "/home/leca/qnn-v2.25.0.240728/include/QNN"
+#                                            "../../build/Linux/Debug/_deps/gsl-src/include"
+#                                            "../../build/Linux/Debug/_deps/onnx-src"
+#                                            "../../build/Linux/Debug/_deps/onnx-build"
+#                                            "../../build/Linux/Debug/_deps/protobuf-src/src")
+
+## Note: we link against libonnxruntime.so because on Windows a shared library cannot be built with undefined symbols
+#target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so"
+#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a"
+#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a"
+#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a"
+#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotocd.a")
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
new file mode 100644
index 0000000000000..7444ea47cfedc
--- /dev/null
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -0,0 +1,41 @@
+#include "tensorrt_execution_provider.h"
+#include
+#include
+namespace onnxruntime {
+
+TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
+  type = ep_type;
+  create_stream = new OrtCreateStream();
+  create_stream->CreateStreamFunc = [](const OrtDevice* device) -> void* {
+    cudaStream_t stream = nullptr;
+    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
+    return stream;
+  };
+  OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
+  };
+
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  };
+}
+
+TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() {
+  OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* {
+    ProviderOptions options;
+    for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i];
+    std::unique_ptr<TensorrtExecutionProvider> ret = std::make_unique<TensorrtExecutionProvider>("TensorrtExecutionProvider", std::move(options));
+    return ret.release();
+  };
+}
+
+} // namespace onnxruntime
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+OrtExecutionProviderFactory* RegisterCustomEp() {
+  std::unique_ptr<TensorrtExecutionProviderFactory> ret = std::make_unique<TensorrtExecutionProviderFactory>();
+  return ret.release();
+}
+#ifdef __cplusplus
+}
+#endif
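Two details of the stream hookup are easy to miss: CreateStreamFunc hands ORT a raw cudaStream_t that provider_adapter.h wraps into an ORT Stream, and OrtCreateStream::device_type selects which OrtDevice type the creation function is registered for. `new OrtCreateStream()` value-initializes device_type to 0, i.e. the default/CPU device type; an EP whose streams belong on GPU would set it explicitly, e.g. (minimal sketch; the value 1 assumes it matches OrtDevice's GPU device type):

    create_stream = new OrtCreateStream();
    create_stream->device_type = 1;  // assumption: OrtDevice GPU device type
    create_stream->CreateStreamFunc = [](const OrtDevice*) -> void* {
      cudaStream_t s = nullptr;
      cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
      return s;
    };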
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h
new file mode 100644
index 0000000000000..7fec2da170b6d
--- /dev/null
+++ b/samples/tensorRTEp/tensorrt_execution_provider.h
@@ -0,0 +1,33 @@
+#pragma once
+#include "core/session/onnxruntime_c_api.h"
+#include "core/framework/provider_options.h"
+#include
+
+#ifdef _WIN32
+#define EXPORT_API __declspec(dllexport)
+#else
+#define EXPORT_API
+#endif
+
+namespace onnxruntime {
+
+struct TensorrtExecutionProvider : public OrtExecutionProvider {
+  TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& provider_options);
+private:
+  bool external_stream_ = false;
+};
+
+struct TensorrtExecutionProviderFactory : public OrtExecutionProviderFactory {
+  TensorrtExecutionProviderFactory();
+};
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp();
+
+#ifdef __cplusplus
+}
+#endif

From 94e9cf76ff6ddf635aea33aa5665135e93e897e5 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Fri, 16 Aug 2024 23:13:25 +0000
Subject: [PATCH 16/81] support data transfer and OrtDevice in out tree EP API

---
 .../core/session/onnxruntime_c_api.h          | 24 +++++++-
 onnxruntime/core/framework/provider_adapter.h | 31 +++++++++-
 onnxruntime/core/session/onnxruntime_c_api.cc | 37 +++++++++++-
 onnxruntime/core/session/ort_apis.h           | 12 +++-
 samples/outTreeEp_kernel/kernel_ep.cc         |  2 +-
 .../tensorRTEp/tensorrt_execution_provider.cc | 58 +++++++++++++++++--
 6 files changed, 152 insertions(+), 12 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 1a73473f8cf9f..e968fe4b8c2e9 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -402,6 +402,13 @@ typedef enum OrtMemoryInfoDeviceType {
   OrtMemoryInfoDeviceType_FPGA = 2
 } OrtMemoryInfoDeviceType;
 
+typedef enum OrtMemoryType {
+  OrtMemoryType_Default = 0,
+  OrtMemoryType_CUDA_PINNED = 1,
+  OrtMemoryType_HIP_PINNED = 2,
+  OrtMemoryType_CANN_PINNED = 3,
+} OrtMemoryType;
+
 /** \brief Algorithm to use for cuDNN Convolution Op
 */
 typedef enum OrtCudnnConvAlgoSearch {
@@ -744,13 +744,16 @@ typedef struct OrtNodeComputeInfo {
 
 typedef struct OrtExecutionProvider {
 #ifdef __cplusplus
-  OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr} {}
+  OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr} {}
 #endif
   void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***);
   void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info);
   void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry);
+  bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target);
+  OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream);
   const char* type;
   OrtCreateStream* create_stream;
+  const OrtDevice* default_device;
 } OrtExecutionProvider;
 
 typedef struct OrtExecutionProviderFactory {
@@ -4727,6 +4737,16 @@ struct OrtApi {
                   _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
                   size_t num_external_initializer_files);
 
+  ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out);
+
+  ORT_API2_STATUS(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out);
+
+  ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out);
+
+  ORT_API2_STATUS(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out);
+
+  ORT_CLASS_RELEASE(Device);
+
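The OrtDevice accessors mirror the opaque-handle style of the rest of the C API: create, query, release. A minimal sketch of their use (device id 0 is illustrative; the const_cast is needed because CreateDevice yields a const pointer while ReleaseDevice takes a mutable one):

    const OrtDevice* dev = nullptr;
    api->CreateDevice(OrtMemoryInfoDeviceType_GPU, OrtMemoryType_Default, 0, &dev);
    OrtMemoryInfoDeviceType dt;
    api->DeviceGetDeviceType(dev, &dt);   // OrtMemoryInfoDeviceType_GPU
    int16_t id = 0;
    api->DeviceGetDeviceId(dev, &id);     // 0
    api->ReleaseDevice(const_cast<OrtDevice*>(dev));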
   ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);
 
   ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env,
@@ -4812,7 +4832,7 @@ struct OrtApi {
 
   ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);
 
-  ORT_API2_STATUS(ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints);
+  ORT_CLASS_RELEASE(TypeConstraints);
 };  // struct OrtApi
 
 /*
diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h
index 2b2ffe8101f89..991cb1ec3e869 100644
--- a/onnxruntime/core/framework/provider_adapter.h
+++ b/onnxruntime/core/framework/provider_adapter.h
@@ -6,9 +6,34 @@
 #include "core/framework/compute_capability.h"
 
 namespace onnxruntime {
+
+class DataTransferAdapter : public IDataTransfer {
+public:
+  DataTransferAdapter(OrtExecutionProvider* ep) : ep_impl_(ep) {}
+  virtual bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override {
+    return ep_impl_->CanCopy(&src_device, &dst_device);
+  }
+
+  virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const override {
+    OrtMemoryInfoDeviceType source_device_type = static_cast<OrtMemoryInfoDeviceType>(src.Location().device.Type());
+    OrtMemoryInfoDeviceType target_device_type = static_cast<OrtMemoryInfoDeviceType>(dst.Location().device.Type());
+    OrtMemoryType source_mem_type = static_cast<OrtMemoryType>(src.Location().device.MemType());
+    return ToStatus(ep_impl_->CopyTensor(src.DataRaw(), source_device_type, source_mem_type, dst.MutableDataRaw(), target_device_type, src.SizeInBytes(), nullptr));
+  }
+
+  virtual common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override {
+    OrtMemoryInfoDeviceType source_device_type = static_cast<OrtMemoryInfoDeviceType>(src.Location().device.Type());
+    OrtMemoryInfoDeviceType target_device_type = static_cast<OrtMemoryInfoDeviceType>(dst.Location().device.Type());
+    OrtMemoryType source_mem_type = static_cast<OrtMemoryType>(src.Location().device.MemType());
+    return ToStatus(ep_impl_->CopyTensor(src.DataRaw(), source_device_type, source_mem_type, dst.MutableDataRaw(), target_device_type, src.SizeInBytes(), stream.GetHandle()));
+  }
+private:
+  OrtExecutionProvider* ep_impl_;
+};
+
 class ExecutionProviderAdapter : public IExecutionProvider {
 public:
-  ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type), ep_impl_(ep) {
+  ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type, ep->default_device ? *(ep->default_device) : OrtDevice()), ep_impl_(ep) {
     if (ep_impl_->RegisterKernels) {
       kernel_registry_ = std::make_shared<KernelRegistry>();
       ep_impl_->RegisterKernels(reinterpret_cast<OrtKernelRegistry*>(kernel_registry_.get()));
@@ -95,6 +120,10 @@ class ExecutionProviderAdapter : public IExecutionProvider {
     }
   }
 
+  virtual std::unique_ptr<IDataTransfer> GetDataTransfer() const override {
+    return std::make_unique<DataTransferAdapter>(ep_impl_);
+  }
+
   virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override { return kernel_registry_; }
 private:
   OrtExecutionProvider* ep_impl_;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index ee158c4bb4d24..4e16bde7f7b0d 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2360,6 +2360,33 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) {
 #endif
 }
 
+ORT_API_STATUS_IMPL(OrtApis::CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out) {
+  OrtDevice::DeviceType dt = static_cast<OrtDevice::DeviceType>(device_type);
+  OrtDevice::MemoryType mt = static_cast<OrtDevice::MemoryType>(memory_type);
+  std::unique_ptr<OrtDevice> device = std::make_unique<OrtDevice>(dt, mt, device_id);
+  *out = device.release();
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out) {
+  *out = static_cast<OrtMemoryInfoDeviceType>(device->Type());
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out) {
+  *out = static_cast<OrtMemoryType>(device->MemType());
+  return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out) {
+  *out = device->Id();
+  return nullptr;
+}
+
+ORT_API(void, OrtApis::ReleaseDevice, OrtDevice* device) {
+  delete device;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) {
   API_IMPL_BEGIN
   void* handle = nullptr;
@@ -2578,9 +2605,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddTypeConstraint, _In_ OrtTypeConstraints* type_co
   return nullptr;
 }
 
-ORT_API_STATUS_IMPL(OrtApis::ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints) {
+ORT_API(void, OrtApis::ReleaseTypeConstraints, OrtTypeConstraints* type_constraints) {
   delete type_constraints;
-  return nullptr;
 }
 
 static constexpr OrtApiBase ort_api_base = {
@@ -2961,6 +2987,11 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::AddExternalInitializersFromFilesInMemory,
     // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information)
 
+    &OrtApis::CreateDevice,
+    &OrtApis::DeviceGetDeviceType,
+    &OrtApis::DeviceGetMemoryType,
+    &OrtApis::DeviceGetDeviceId,
+    &OrtApis::ReleaseDevice,
     &OrtApis::RegisterOrtExecutionProviderLibrary,
     &OrtApis::SessionOptionsAppendOrtExecutionProvider,
@@ -2994,7 +3025,7 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtKernelRegistry_RegisterKernel,
     &OrtApis::CreateOrtTypeConstraints,
     &OrtApis::AddTypeConstraint,
-    &OrtApis::ReleaseOrtTypeConstraints,
+    &OrtApis::ReleaseTypeConstraints,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 3df5e8b7df8a7..4b466f17fd20a 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -524,6 +524,16 @@ ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext*
 
 ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
 
+ORT_API_STATUS_IMPL(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out);
+
+ORT_API_STATUS_IMPL(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out);
+
+ORT_API_STATUS_IMPL(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out);
+
+ORT_API_STATUS_IMPL(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out);
+
+ORT_API(void, ReleaseDevice, _Frees_ptr_opt_ OrtDevice*);
+
 ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);
 
 ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env,
@@ -589,5 +599,5 @@ ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type
 
 ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);
 
-ORT_API_STATUS_IMPL(ReleaseOrtTypeConstraints, _In_ OrtTypeConstraints* type_constraints);
+ORT_API(void, ReleaseTypeConstraints, _In_ OrtTypeConstraints* type_constraints);
 }  // namespace OrtApis
diff --git a/samples/outTreeEp_kernel/kernel_ep.cc b/samples/outTreeEp_kernel/kernel_ep.cc
index f50536756aa6d..560bec085d341 100644
--- a/samples/outTreeEp_kernel/kernel_ep.cc
+++ b/samples/outTreeEp_kernel/kernel_ep.cc
@@ -63,7 +63,7 @@ KernelEp::KernelEp(const char* ep_type, const KernelEpInfo& ep_info) : info(ep_i
     api->AddTypeConstraint(type_constraints, "T", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
     OrtCustomOp* op = new MyRelu();
     api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints);
-    api->ReleaseOrtTypeConstraints(type_constraints);
+    api->ReleaseTypeConstraints(type_constraints);
   };
 }
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index 7444ea47cfedc..caf2ed0f5dbb0 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -4,6 +4,58 @@
 namespace onnxruntime {
 
 TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
+  OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
+  };
+
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  };
+
+  OrtExecutionProvider::CanCopy = [](const OrtDevice* source, const OrtDevice* target) {
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    OrtMemoryInfoDeviceType source_device_type, target_device_type;
+    api->DeviceGetDeviceType(source, &source_device_type);
+    api->DeviceGetDeviceType(target, &target_device_type);
+    OrtMemoryType source_mem_type, target_mem_type;
+    api->DeviceGetMemoryType(source, &source_mem_type);
+    api->DeviceGetMemoryType(target, &target_mem_type);
+
+    return source_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU ||
+           source_mem_type == OrtMemoryType::OrtMemoryType_CUDA_PINNED ||
+           target_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU ||
+           target_mem_type == OrtMemoryType::OrtMemoryType_CUDA_PINNED;
+  };
+
+  OrtExecutionProvider::CopyTensor = [](const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream) -> OrtStatusPtr {
+    // TODO(leca): convert cudaError_t to OrtStatusPtr
+    if (source_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU && target_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) {
+      if (src != dst) {
+        stream ? cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, static_cast<cudaStream_t>(stream)) : cudaMemcpy(dst, src, count, cudaMemcpyDeviceToDevice);
+      }
+      return nullptr;
+    }
+    if (source_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU && target_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) {
+      if (stream) cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, static_cast<cudaStream_t>(stream));
+      else {
+        cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice);
+        cudaStreamSynchronize(nullptr);
+      }
+      return nullptr;
+    }
+    if (source_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU && target_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) {
+      if (stream) cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream));
+      else {
+        cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost);
+        cudaStreamSynchronize(nullptr);
+      }
+      return nullptr;
+    }
+    if (stream && source_mem_type == OrtMemoryType::OrtMemoryType_CUDA_PINNED) {
+      cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
+    }
+    memcpy(dst, src, count);
+    return nullptr;
+  };
+
   type = ep_type;
   create_stream = new OrtCreateStream();
   create_stream->CreateStreamFunc = [](const OrtDevice* device) -> void* {
@@ -11,11 +63,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
     cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
     return stream;
   };
-  OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
-  };
 
-  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
-  };
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  api->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device);
 }
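The CanCopy/CopyTensor pair above is the whole data-transfer contract: CanCopy advertises which device pairs the EP handles, and CopyTensor performs the move, with a null stream meaning a synchronous copy. A host-memory EP could satisfy the same contract with a plain memcpy, as in this minimal sketch (not from the patch):

    OrtExecutionProvider::CanCopy = [](const OrtDevice* /*src*/, const OrtDevice* /*dst*/) { return true; };
    OrtExecutionProvider::CopyTensor = [](const void* src, OrtMemoryInfoDeviceType, OrtMemoryType,
                                          void* dst, OrtMemoryInfoDeviceType, size_t count,
                                          void* /*stream*/) -> OrtStatusPtr {
      memcpy(dst, src, count);
      return nullptr;  // nullptr == success
    };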
From 8698517857990242858a1146c11da0a6ec884a1a Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Tue, 20 Aug 2024 01:00:41 +0000
Subject: [PATCH 17/81] change compile return type from void to OrtStatusPtr

---
 .../core/session/onnxruntime_c_api.h          |  4 ++-
 onnxruntime/core/framework/provider_adapter.h |  3 +-
 onnxruntime/core/session/onnxruntime_c_api.cc |  7 ++++
 onnxruntime/core/session/ort_apis.h           |  2 ++
 samples/outTreeEp/out_tree_ep.cc              |  3 +-
 samples/qnnEp/qnn_execution_provider.cc       |  3 +-
 samples/tensorRTEp/onnx_ctx_model_helper.cc   | 18 ++++++++++
 samples/tensorRTEp/onnx_ctx_model_helper.h    | 12 +++++++
 .../tensorRTEp/tensorrt_execution_provider.cc | 36 +++++++++++++++++--
 .../tensorRTEp/tensorrt_execution_provider.h  |  4 +++
 10 files changed, 86 insertions(+), 6 deletions(-)
 create mode 100644 samples/tensorRTEp/onnx_ctx_model_helper.cc
 create mode 100644 samples/tensorRTEp/onnx_ctx_model_helper.h

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index e968fe4b8c2e9..2c7008a14614a 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -747,7 +747,7 @@ typedef struct OrtExecutionProvider {
   OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr} {}
 #endif
   void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***);
-  void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info);
+  OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info);
   void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry);
   bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target);
   OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream);
@@ -4762,6 +4762,8 @@ struct OrtApi {
 
   ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer);
 
+  ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out);
+
   ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);
 
   ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description);
diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h
index 991cb1ec3e869..663a2a80dd250 100644
--- a/onnxruntime/core/framework/provider_adapter.h
+++ b/onnxruntime/core/framework/provider_adapter.h
@@ -88,7 +88,8 @@ class ExecutionProviderAdapter : public IExecutionProvider {
     std::vector<OrtNodeComputeInfo> cache;
     cache.resize(count);
     OrtNodeComputeInfo* cache_data = cache.data();
-    ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &cache_data);
+    OrtStatus* ret = ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &cache_data);
+    if (ret != nullptr) return ToStatus(ret);
     node_compute_funcs.reserve(count);
     for (size_t i = 0; i < count; i++) {
       NodeComputeInfo compute_info;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 4e16bde7f7b0d..0c375f7da0615 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2447,6 +2447,12 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodeProducingOutput, const OrtGraphView
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *out = graph_viewer->MaxNodeIndex();
+  return nullptr;
+}
+
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 4e16bde7f7b0d..0c375f7da0615 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2447,6 +2447,12 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodeProducingOutput, const OrtGraphView
   return nullptr;
 }

+ORT_API_STATUS_IMPL(OrtApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *out = graph_viewer->MaxNodeIndex();
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *name = n->Name().c_str();
@@ -3000,6 +3006,7 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtGraph_GetOrtNode,
     &OrtApis::OrtGraph_GetNodesConsumingInput,
     &OrtApis::OrtGraph_GetNodeProducingOutput,
+    &OrtApis::OrtGraph_MaxNodeIndex,
     &OrtApis::OrtNode_GetName,
    &OrtApis::OrtNode_GetDescription,
     &OrtApis::OrtNode_GetDomain,
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 4b466f17fd20a..94ac330d3aa50 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -549,6 +549,8 @@ ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph

 ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer);

+ORT_API_STATUS_IMPL(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out);
+
 ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);

 ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description);
diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc
index 8aff8072de68d..cb6eed5a465d1 100644
--- a/samples/outTreeEp/out_tree_ep.cc
+++ b/samples/outTreeEp/out_tree_ep.cc
@@ -51,7 +51,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe
     }
   };

-  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr {
     for (size_t i = 0; i < cnt; i++) {
       node_compute_info[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr {
         const OrtValue* input = nullptr;
@@ -74,6 +74,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe
         return nullptr;
       };
     }
+    return nullptr;
   };
 }
diff --git a/samples/qnnEp/qnn_execution_provider.cc b/samples/qnnEp/qnn_execution_provider.cc
index 68707252cbf82..68b59a4f73c3b 100644
--- a/samples/qnnEp/qnn_execution_provider.cc
+++ b/samples/qnnEp/qnn_execution_provider.cc
@@ -27,7 +27,8 @@ QNNExecutionProvider::QNNExecutionProvider(const char* ep_type, const ProviderOp
   OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
   };

-  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr {
+    return nullptr;
   };
 }
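[Editor's note] GraphHasCtxNode in the new file below walks node indices up to OrtGraph_MaxNodeIndex. Mirroring GraphViewer::MaxNodeIndex, that bound is one past the largest index ever assigned, and index slots can be left empty by graph transformations, so OrtGraph_GetOrtNode can legitimately yield a null node; every such loop needs a per-index null check before touching the node (the loop below applies one). The traversal pattern, as a sketch (`graph_viewer` is any valid const OrtGraphViewer*):

  int max_node_index = 0;
  api->OrtGraph_MaxNodeIndex(graph_viewer, &max_node_index);
  for (int i = 0; i < max_node_index; ++i) {
    const OrtNode* node = nullptr;
    api->OrtGraph_GetOrtNode(graph_viewer, i, &node);
    if (node == nullptr) continue;  // hole left by a removed node
    // ... inspect node ...
  }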
diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc
new file mode 100644
index 0000000000000..71cab7968e52e
--- /dev/null
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc
@@ -0,0 +1,19 @@
+#include "onnx_ctx_model_helper.h"
+namespace onnxruntime {
+bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) {
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  int maxNodeIndex = 0;
+  api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex);
+  for (int i = 0; i < maxNodeIndex; ++i) {
+    const OrtNode* node = nullptr;
+    api->OrtGraph_GetOrtNode(graph_viewer, i, &node);
+    if (node == nullptr) continue;
+    const char* opType = nullptr;
+    api->OrtNode_GetOpType(node, &opType);
+    if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) {
+      return true;
+    }
+  }
+  return false;
+}
+}
diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h
new file mode 100644
index 0000000000000..a32037f89f7aa
--- /dev/null
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <string>
+#include "core/session/onnxruntime_c_api.h"
+
+namespace onnxruntime {
+static const std::string EPCONTEXT_OP = "EPContext";
+
+bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer);
+}
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index caf2ed0f5dbb0..d7adfcfc59326 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -1,13 +1,38 @@
-#include "tensorrt_execution_provider.h"
 #include
 #include
+#include "tensorrt_execution_provider.h"
+#include "onnx_ctx_model_helper.h"

 namespace onnxruntime {
 TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
   OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
   };

-  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) {
+  OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr {
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    for (size_t j = 0; j < cnt; j++) {
+      std::unordered_map<std::string, size_t> input_map, output_map;
+      size_t input_size = 0;
+      api->OrtNode_GetInputSize(node[j], &input_size);
+      for (size_t i = 0; i < input_size; i++) {
+        const char* ith_input_name = nullptr;
+        api->OrtNode_GetIthInputName(node[j], i, &ith_input_name);
+        input_map[ith_input_name] = i;
+      }
+
+      size_t output_size = 0;
+      api->OrtNode_GetOutputSize(node[j], &output_size);
+      for (size_t i = 0; i < output_size; i++) {
+        const char* ith_output_name = nullptr;
+        api->OrtNode_GetIthOutputName(node[j], i, &ith_output_name);
+        output_map[ith_output_name] = i;
+      }
+
+      if (GraphHasCtxNode(graph[j])) {
+        static_cast<TensorrtExecutionProvider*>(this_)->CreateNodeComputeInfoFromPrecompiledEngine(graph[j], node[j], input_map, output_map, &node_compute_info[j]);
+      }
+    }
+    return nullptr;
   };

   OrtExecutionProvider::CanCopy = [](const OrtDevice* source, const OrtDevice* target) {
@@ -77,6 +102,13 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() {
   };
 }

+void TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
+                                                                           std::unordered_map<std::string, size_t>& input_map,
+                                                                           std::unordered_map<std::string, size_t>& output_map,
+                                                                           OrtNodeComputeInfo** node_compute_funcs) {
+
+}
+
 }  // namespace onnxruntime

 #ifdef __cplusplus
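[Editor's note] A tiny illustration of the two maps the Compile lambda above builds (the tensor names are hypothetical, not from the patch). For a fused node with inputs ("x", "w") and output ("y"):

  // input_map  == { {"x", 0}, {"w", 1} }   -- ORT kernel-context input slot per name
  // output_map == { {"y", 0} }             -- ORT kernel-context output slot per name
  //
  // CreateNodeComputeInfoFromPrecompiledEngine later walks the TRT engine's I/O
  // tensors (trt_engine->getIOTensorName(i)) and uses these maps to translate each
  // TRT tensor name into the matching kernel-context index at inference time.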
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h
index 7fec2da170b6d..6f9086d37fb68 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.h
+++ b/samples/tensorRTEp/tensorrt_execution_provider.h
@@ -13,6 +13,10 @@ namespace onnxruntime {

 struct TensorrtExecutionProvider : public OrtExecutionProvider {
   TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& provider_options);
+  void CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
+                                                  std::unordered_map<std::string, size_t>& input_map,
+                                                  std::unordered_map<std::string, size_t>& output_map,
+                                                  OrtNodeComputeInfo** node_compute_funcs);
 private:
   bool external_stream_ = false;
 };

From 3d5d2bf76748862184306df64f75d338162f436d Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Tue, 20 Aug 2024 23:23:06 +0000
Subject: [PATCH 18/81] add TensorRT dependency in tensorRT EP's CMakeLists.txt

---
 .../core/session/onnxruntime_c_api.h          |  12 +
 onnxruntime/core/session/onnxruntime_c_api.cc |  36 ++
 onnxruntime/core/session/ort_apis.h           |  12 +
 samples/tensorRTEp/CMakeLists.txt             |  17 +-
 samples/tensorRTEp/nv_includes.h              |  19 +
 samples/tensorRTEp/onnx_ctx_model_helper.cc   | 201 +++++++++
 samples/tensorRTEp/onnx_ctx_model_helper.h    |  48 +++
 .../tensorRTEp/tensorrt_execution_provider.cc | 394 +++++++++++++++++-
 .../tensorRTEp/tensorrt_execution_provider.h  | 200 ++++++++-
 9 files changed, 926 insertions(+), 13 deletions(-)
 create mode 100644 samples/tensorRTEp/nv_includes.h

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 2c7008a14614a..be6ea77859098 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4762,8 +4762,16 @@ struct OrtApi {

   ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer);

+  int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
   ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out);

+  size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
+  const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
+  int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
   ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);

   ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description);
@@ -4806,6 +4814,10 @@ struct OrtApi {

   ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs);

+  const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
+  int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
   ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs);

   ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 0c375f7da0615..b9ea2e8bac797 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2447,12 +2447,32 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodeProducingOutput, const OrtGraphView
   return nullptr;
 }

+ORT_API(int, OrtApis::OrtGraph_NumberOfNodes, const OrtGraphViewer* graph) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return graph_viewer->NumberOfNodes();
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out) {
   const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
   *out = graph_viewer->MaxNodeIndex();
   return nullptr;
 }

+ORT_API(size_t, OrtApis::OrtGraph_GetOutputSize, const OrtGraphViewer* graph) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return graph_viewer->GetOutputs().size();
+}
+
+ORT_API(const char*, OrtApis::OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return graph_viewer->GetOutputs()[i]->Name().c_str();
+}
+
+ORT_API(int32_t, OrtApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* graph, size_t i) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type();
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *name = n->Name().c_str();
@@ -2582,6 +2602,16 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthStr, const OrtNode* node, co
   return nullptr;
 }

+ORT_API(const char*, OrtApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).s().c_str();
+}
+
+ORT_API(int64_t, OrtApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).i();
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   std::vector<gsl::not_null<const ::onnxruntime::Graph*>> subg = n->GetSubgraphs();
@@ -3006,7 +3036,11 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtGraph_GetOrtNode,
     &OrtApis::OrtGraph_GetNodesConsumingInput,
     &OrtApis::OrtGraph_GetNodeProducingOutput,
+    &OrtApis::OrtGraph_NumberOfNodes,
     &OrtApis::OrtGraph_MaxNodeIndex,
+    &OrtApis::OrtGraph_GetOutputSize,
+    &OrtApis::OrtGraph_GetIthOutputName,
+    &OrtApis::OrtGraph_GetIthOutputElemType,
     &OrtApis::OrtNode_GetName,
     &OrtApis::OrtNode_GetDescription,
     &OrtApis::OrtNode_GetDomain,
@@ -3028,6 +3062,8 @@ static constexpr OrtApi ort_api_1_to_19 = {
     &OrtApis::OrtNode_GetAttributeIthInt,
     &OrtApis::OrtNode_GetAttributeIthFloat,
     &OrtApis::OrtNode_GetAttributeIthStr,
+    &OrtApis::OrtNode_GetAttributeStr,
+    &OrtApis::OrtNode_GetAttributeInt,
     &OrtApis::OrtNode_GetSubgraphs,
     &OrtApis::OrtKernelRegistry_RegisterKernel,
     &OrtApis::CreateOrtTypeConstraints,
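[Editor's note] Unlike the ORT_API2_STATUS entries, the accessors added above return their result directly instead of an OrtStatusPtr, which keeps hot-path graph inspection free of status handling. A hedged usage sketch (`graph` is any valid const OrtGraphViewer*):

  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  int node_count = api->OrtGraph_NumberOfNodes(graph);   // value, not status
  size_t output_count = api->OrtGraph_GetOutputSize(graph);
  for (size_t i = 0; i < output_count; ++i) {
    const char* name = api->OrtGraph_GetIthOutputName(graph, i);
    int32_t elem_type = api->OrtGraph_GetIthOutputElemType(graph, i);  // ONNX TensorProto_DataType value
    // ...
  }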
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 94ac330d3aa50..949079766811b 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -549,8 +549,16 @@ ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph

 ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer);

+ORT_API(int, OrtGraph_NumberOfNodes, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL;
+
 ORT_API_STATUS_IMPL(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out);

+ORT_API(size_t, OrtGraph_GetOutputSize, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL;
+
+ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL;
+
+ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL;
+
 ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);

 ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description);
@@ -593,6 +601,10 @@ ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const cha

 ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs);

+ORT_API(const char*, OrtNode_GetAttributeStr, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL;
+
+ORT_API(int64_t, OrtNode_GetAttributeInt, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL;
+
 ORT_API_STATUS_IMPL(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs);

 ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt
index d6fbc41cf89e6..33ec1d1937096 100644
--- a/samples/tensorRTEp/CMakeLists.txt
+++ b/samples/tensorRTEp/CMakeLists.txt
@@ -9,19 +9,24 @@ enable_language(CUDA)
 file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda")
 find_package(CUDAToolkit REQUIRED)

-#add_definitions(-DONNX_NAMESPACE=onnx)
-#add_definitions(-DONNX_ML)
-add_library(TensorRTEp SHARED tensorrt_execution_provider.cc)
+add_definitions(-DONNX_NAMESPACE=onnx)
+add_definitions(-DONNX_ML)
+add_definitions(-DNV_TENSORRT_MAJOR=10)
+file(GLOB tensorrt_src "./*.cc")
+add_library(TensorRTEp SHARED ${tensorrt_src})
 target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime"
-                           "/usr/local/cuda/include")
-#                          "/home/leca/qnn-v2.25.0.240728/include/QNN"
+                           "/usr/local/cuda/include"
+                           "/home/leca/TensorRT-10.0.1.6/include")
#                           "../../build/Linux/Debug/_deps/gsl-src/include"
#                           "../../build/Linux/Debug/_deps/onnx-src"
#                           "../../build/Linux/Debug/_deps/onnx-build"
#                           "../../build/Linux/Debug/_deps/protobuf-src/src")
#
## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol
-#target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so"
+target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so"
+                      "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer.so"
+                      "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer_plugin.so"
+                      "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so")
#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a"
#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a"
#                      "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a"
diff --git a/samples/tensorRTEp/nv_includes.h b/samples/tensorRTEp/nv_includes.h
new file mode 100644
index 0000000000000..047f325f49b70
--- /dev/null
+++ b/samples/tensorRTEp/nv_includes.h
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+// File to include the required TRT headers with workarounds for warnings we can't fix or haven't fixed yet.
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)  // Ignore warning C4100: unreferenced formal parameter
+#pragma warning(disable : 4996)  // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated
+#endif
+
+#include
+#include
+#include
+#include
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc
index 71cab7968e52e..fe022c6c0e85f 100644
--- a/samples/tensorRTEp/onnx_ctx_model_helper.cc
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc
@@ -1,4 +1,9 @@
+#include
+#include
+#include
 #include "onnx_ctx_model_helper.h"
+#include "tensorrt_execution_provider.h"
+
 namespace onnxruntime {
 bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) {
   const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
@@ -15,4 +20,200 @@ bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) {
   }
   return false;
 }
+
+/*
+ * Return the directory where the EP context model is located
+ */
+std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) {
+  if (ep_context_file_path.empty()) {
+    return std::filesystem::path();
+  }
+  std::filesystem::path ctx_path(ep_context_file_path);
+  if (std::filesystem::is_directory(ep_context_file_path)) {
+    return ctx_path;
+  } else {
+    return ctx_path.parent_path();
+  }
+}
+
+bool IsAbsolutePath(const std::string& path_string) {
+#ifdef _WIN32
+  onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
+  auto path = std::filesystem::path(ort_path_string.c_str());
+  return path.is_absolute();
+#else
+  if (!path_string.empty() && path_string[0] == '/') {
+    return true;
+  }
+  return false;
+#endif
+}
+
+// Like "../file_path"
+bool IsRelativePathToParentPath(const std::string& path_string) {
+#ifdef _WIN32
+  onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
+  auto path = std::filesystem::path(ort_path_string.c_str());
+  auto relative_path = path.lexically_normal().make_preferred().wstring();
+  if (relative_path.find(L"..", 0) != std::string::npos) {
+    return true;
+  }
+  return false;
+#else
+  if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) {
+    return true;
+  }
+  return false;
+#endif
+}
+
+/*
+ * Get the weight-refitted engine cache path from a weight-stripped engine cache path
+ *
+ * Weight-stripped engine:
+ * An engine with weights stripped, whose size is smaller than a regular engine.
+ * The cache name of a weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine
+ *
+ * Weight-refitted engine:
+ * An engine whose weights have been refitted; it's simply a regular engine.
+ * The cache name of a weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine
+ */
+std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) {
+  std::filesystem::path stripped_engine_cache_path(stripped_engine_cache);
+  std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine";
+  return refitted_engine_cache_path;
+}
+
+bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
+  // The weight-stripped engine cache has the naming of xxx.stripped.engine
+  return engine_cache_path.stem().extension().string() == ".stripped";
+}
+
+OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphViewer* graph_viewer) {
+  if (!ValidateEPCtxNode(graph_viewer)) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node");
+  }
+  const OrtNode* node = nullptr;
+  api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node);
+
+  const int64_t embed_mode = api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str());
+  if (embed_mode) {
+    // Get engine from byte stream.
+    const std::string& context_binary(api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str()));
+    *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
+                                                                                                static_cast<size_t>(context_binary.length())));
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
+    if (!(*trt_engine_)) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data");
+    }
+  } else {
+    // Get engine from cache file.
+    std::string cache_path(api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str()));
+
+    // For security purpose, in the case of running context model, TRT EP won't allow
+    // engine cache path to be the relative path like "../file_path" or the absolute path.
+    // It only allows the engine cache to be in the same directory or sub directory of the context model.
+    if (IsAbsolutePath(cache_path)) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path).c_str());
+    }
+    if (IsRelativePathToParentPath(cache_path)) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");
+    }
+
+    // The engine cache and context model (current model) should be in the same directory
+    std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
+    auto engine_cache_path = ctx_model_dir.append(cache_path);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();
+
+    // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled
+    if (!weight_stripped_engine_refit_) {
+      weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path);
+    }
+
+    // If the serialized refitted engine is present, use it directly without refitting the engine again
+    if (weight_stripped_engine_refit_) {
+      const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string());
+      if (std::filesystem::exists(refitted_engine_cache_path)) {
+//        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists.";
+        engine_cache_path = refitted_engine_cache_path.string();
+        weight_stripped_engine_refit_ = false;
+      }
+    }
+
+    if (!std::filesystem::exists(engine_cache_path)) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                                std::string("TensorRT EP can't find engine cache: " + engine_cache_path.string() +
+                                            ". Please make sure engine cache is in the same directory or sub-directory of context model.").c_str());
+    }
+
+    std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in);
+    engine_file.seekg(0, std::ios::end);
+    size_t engine_size = engine_file.tellg();
+    engine_file.seekg(0, std::ios::beg);
+    std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+    engine_file.read((char*)engine_buf.get(), engine_size);
+    *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+    if (!(*trt_engine_)) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                                std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()).c_str());
+    }
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
+
+    if (weight_stripped_engine_refit_) {
+      const std::string onnx_model_filename(api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str()));
+      std::string weight_stripped_engine_cache = engine_cache_path.string();
+      auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
+                                                           onnx_model_folder_path_,
+                                                           weight_stripped_engine_cache,
+                                                           true /* path check for security */,
+                                                           (*trt_engine_).get(),
+                                                           true /* serialize refitted engine to disk */,
+                                                           detailed_build_log_);
+      if (status != nullptr) {
+        return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status));
+      }
+    }
+  }
+  return nullptr;
+}
+
+bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) {
+  assert(api_->OrtGraph_NumberOfNodes(graph_viewer) == 1);
+  const OrtNode* node = nullptr;
+  api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node);
+  const char* opType = nullptr;
+  api_->OrtNode_GetOpType(node, &opType);
+  assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0);
+
+  size_t key_count = 0;
+  api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count);
+  // Show the warning if compute capability is not matched
+  if (key_count > 0) {
+    const char* model_compute_capability = api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str());
+    // Verify if engine was compiled with ampere+ hardware compatibility enabled
+    if (strcmp(model_compute_capability, "80+") == 0) {
+//      if (std::stoi(compute_capability_) < 80) {
+//        LOGS_DEFAULT(WARNING) << "[TensorRT EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_;
+//      }
+    } else if (strcmp(model_compute_capability, compute_capability_.c_str()) != 0) {
+//      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or might perform suboptimally";
+//      LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability;
+//      LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_;
+    }
+  }
+
+  // "embed_mode" attr and "ep_cache_context" attr should be present
+  api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count);
+  assert(key_count > 0);
+  api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count);
+  assert(key_count > 0);
+
+  const int64_t embed_mode = api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str());
+  if (embed_mode == 1) {
+    // engine binary data
+//    LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
+  }
+
+  return true;
+}
 }
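[Editor's note] A hedged usage sketch of the TensorRTCacheModelHandler declared in the header diff below; the runtime pointer, model path and compute-capability values are placeholders, not from the patch:

  std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
  nvinfer1::IRuntime* trt_runtime = /* created elsewhere via nvinfer1::createInferRuntime */;
  TensorRTCacheModelHandler handler(&trt_engine,
                                    trt_runtime,
                                    "/models/ctx_model.onnx",   // ep_context_model_path
                                    "86",                        // compute capability of this GPU
                                    /*weight_stripped_engine_refit*/ false,
                                    /*onnx_model_folder_path*/ "",
                                    /*detailed_build_log*/ false);
  if (handler.ValidateEPCtxNode(graph_viewer)) {
    OrtStatusPtr status = handler.GetEpContextFromGraph(graph_viewer);
    // On success, trt_engine owns the engine deserialized from the node's
    // byte stream (embed_mode=1) or loaded from the cache file (embed_mode=0).
  }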
diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h
index a32037f89f7aa..a7604bcbd5839 100644
--- a/samples/tensorRTEp/onnx_ctx_model_helper.h
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.h
@@ -2,11 +2,59 @@
 // Licensed under the MIT License.

 #pragma once
+#include
 #include <string>
+#include
 #include "core/session/onnxruntime_c_api.h"
+#include "nv_includes.h"

 namespace onnxruntime {
 static const std::string EPCONTEXT_OP = "EPContext";
+static const std::string EMBED_MODE = "embed_mode";
+static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
+static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
+static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename";
+static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
+static const std::string EPCONTEXT_WARNING =
+    "It's suggested to set the ORT graph optimization level to 0 and \
+    make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\
+    for the best model loading time";

 bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer);
+std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
+bool IsAbsolutePath(const std::string& path_string);
+bool IsRelativePathToParentPath(const std::string& path_string);
+
+class TensorRTCacheModelHandler {
+ public:
+  TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
+                            nvinfer1::IRuntime* trt_runtime,
+                            std::string ep_context_model_path,
+                            std::string compute_capability,
+                            bool weight_stripped_engine_refit,
+                            std::string onnx_model_folder_path,
+                            bool detailed_build_log)
+      : trt_engine_(trt_engine),
+        trt_runtime_(trt_runtime),
+        ep_context_model_path_(ep_context_model_path),
+        compute_capability_(compute_capability),
+        weight_stripped_engine_refit_(weight_stripped_engine_refit),
+        onnx_model_folder_path_(onnx_model_folder_path),
+        detailed_build_log_(detailed_build_log) {
+    api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  }
+  bool ValidateEPCtxNode(const OrtGraphViewer* graph_viewer);
+
+  OrtStatusPtr GetEpContextFromGraph(const OrtGraphViewer* graph_viewer);
+
+ private:
+  std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
+  nvinfer1::IRuntime* trt_runtime_;
+  std::string ep_context_model_path_;  // If using context model, it implies context model and engine cache are in the same directory
+  std::string compute_capability_;
+  bool weight_stripped_engine_refit_;
+  std::string onnx_model_folder_path_;
+  bool detailed_build_log_;
+  const OrtApi* api_;
+};  // TRTCacheModelHandler
 }
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index d7adfcfc59326..caae17d3cb8ed 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -1,9 +1,21 @@
 #include
+#include
 #include
 #include "tensorrt_execution_provider.h"
 #include "onnx_ctx_model_helper.h"

 namespace onnxruntime {
+TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
+  const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
+  static TensorrtLogger trt_logger(log_level);
+  if (log_level != trt_logger.get_level()) {
+    trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING);
+  }
+  return trt_logger;
+}
+
+const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+
 TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
   OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
   };
@@ -28,9 +40,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
         output_map[ith_output_name] = i;
       }

+      OrtStatusPtr ret = nullptr;
+      TensorrtExecutionProvider* p = static_cast<TensorrtExecutionProvider*>(this_);
       if (GraphHasCtxNode(graph[j])) {
-        static_cast<TensorrtExecutionProvider*>(this_)->CreateNodeComputeInfoFromPrecompiledEngine(graph[j], node[j], input_map, output_map, &node_compute_info[j]);
+        ret = p->CreateNodeComputeInfoFromPrecompiledEngine(graph[j], node[j], input_map, output_map, &node_compute_info[j]);
+      } else {
+
       }
+      if (ret != nullptr) return api->CreateStatus(api->GetErrorCode(ret), api->GetErrorMessage(ret));
     }
     return nullptr;
   };
@@ -89,8 +106,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
     return stream;
   };

-  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
-  api->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device);
+  api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device);
 }

 TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() {
@@ -102,11 +118,381 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() {
   };
 }

-void TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
+OrtStatusPtr TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename,
+                                                    std::string& onnx_model_folder_path,
+                                                    std::string& weight_stripped_engine_cache_path,
+                                                    bool path_check,
+                                                    nvinfer1::ICudaEngine* trt_engine,
+                                                    bool serialize_refitted_engine,
+                                                    bool detailed_build_log) {
+#if NV_TENSORRT_MAJOR >= 10
+  std::filesystem::path onnx_model_path{onnx_model_folder_path};
+  onnx_model_path.append(onnx_model_filename);
+  if (path_check && IsAbsolutePath(onnx_model_path.string())) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              std::string("For security purpose, the ONNX model path should be set with "
+                                          "a relative path, but it is an absolute path: " +
+                                          onnx_model_path.string()).c_str());
+  }
+  if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              "The ONNX model path has '..'. For security purpose, it's not "
+                              "allowed to point outside the directory.");
+  }
+
+  if (!std::filesystem::exists(onnx_model_path)) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              std::string("The ONNX model " + onnx_model_path.string() +
+                                          " does not exist.").c_str());
+  }
+
+  // weight-stripped engine refit logic
+  TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log);
+  auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
+  auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
+      nvonnxparser::createParserRefitter(*refitter, trt_logger));
+  if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              std::string("TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str());
+  }
+  if (refitter->refitCudaEngine()) {
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine.";
+  } else {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              std::string("TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()).c_str());
+  }
+
+  // serialize the refitted engine to disk
+  if (serialize_refitted_engine) {
+    std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cache_path);
+    nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize();
+    std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out);
+    engine_file.write(reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache;
+  }
+  return nullptr;
+#else
+  return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards.");
+#endif
+}
+
+OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
                                                                            std::unordered_map<std::string, size_t>& input_map,
                                                                            std::unordered_map<std::string, size_t>& output_map,
                                                                            OrtNodeComputeInfo** node_compute_funcs) {
+  std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+  std::unordered_map<std::string, size_t> input_indexes;   // TRT engine input name -> ORT kernel context input index
+  std::unordered_map<std::string, size_t> output_indexes;  // TRT engine output name -> ORT kernel context output index
+  std::unordered_map<std::string, size_t> output_types;    // TRT engine output name -> ORT output tensor type
+
+  // Get engine binary data and deserialize it
+  auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine,
+                                                           runtime_.get(),
+                                                           model_path_,
+                                                           compute_capability_,
+                                                           weight_stripped_engine_enable_,
+                                                           onnx_model_folder_path_,
+                                                           detailed_build_log_);
+  auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
+  if (status != nullptr) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status));
+  }
+
+  // Build context
+  //
+  // Note: Creating an execution context from an engine is thread safe per TRT doc
+  // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+  if (context_memory_sharing_enable_) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+    size_t mem_size = trt_engine->getDeviceMemorySize();
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+    if (mem_size > max_ctx_mem_size_) {
+      max_ctx_mem_size_ = mem_size;
+    }
+#if NV_TENSORRT_MAJOR < 10
+    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContextWithoutDeviceMemory());
+#else
+    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+#endif
+  } else {
+    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
+  }
+
+  const char* fused_node_name = nullptr;
+  api_->OrtNode_GetName(fused_node, &fused_node_name);
+  if (!trt_context) {
+    return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
+                              std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str());
+  }
+
+  // Create input/output to index maps
+  for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) {
+    auto const& name = trt_engine->getIOTensorName(i);
+    auto const& mode = trt_engine->getTensorIOMode(name);
+    if (mode == nvinfer1::TensorIOMode::kINPUT) {
+      const auto& iter = input_map.find(name);
+      if (iter != input_map.end()) {
+        input_indexes[name] = iter->second;
+      }
+    } else {
+      const auto& iter = output_map.find(name);
+      if (iter != output_map.end()) {
+        output_indexes[name] = iter->second;
+      }
+    }
+  }
+
+  // Create output to type map
+  size_t graph_output_size = api_->OrtGraph_GetOutputSize(graph_body_viewer);
+  for (size_t i = 0; i < graph_output_size; i++) {
+    output_types[api_->OrtGraph_GetIthOutputName(graph_body_viewer, i)] = api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i);
+  }
+
+  // Save TRT engine, TRT context and input/output info to map
+  engines_.emplace(fused_node_name, std::move(trt_engine));
+  contexts_.emplace(fused_node_name, std::move(trt_context));
+  input_info_[fused_node_name].push_back(input_indexes);
+  output_info_[fused_node_name].push_back(output_indexes);
+  output_info_[fused_node_name].push_back(output_types);
+
+  // Create function state
+  // TODO: remove default capture
+//  NodeComputeInfo compute_info;
+//  compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
+//    std::unique_ptr<TensorrtShortFuncState> p = std::make_unique<TensorrtShortFuncState>();
+//    *p = {context->allocate_func,
+//          context->release_func,
+//          context->allocator_handle,
+//          context->node_name,
+//          &engines_[context->node_name],
+//          &contexts_[context->node_name],
+//          input_info_[context->node_name],
+//          output_info_[context->node_name],
+//          context_memory_sharing_enable_,
+//          &max_ctx_mem_size_,
+//          &tensorrt_mu_};
+//    *state = p.release();
+//    return 0;
+//  };
+//
+//  // Release function state
+//  compute_info.release_state_func = [](FunctionState state) {
+//    delete static_cast<TensorrtShortFuncState*>(state);
+//  };
+//
+//  // Create compute function
+//  compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+//    Ort::KernelContext ctx(context);
+//
+//    TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
+//
+//    // The whole compute_function should be considered the critical section.
+// // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +// std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); +// +// const std::unordered_map& input_indexes = (trt_state->input_info)[0]; +// const std::unordered_map& output_indexes = (trt_state->output_info)[0]; +// const std::unordered_map& output_types = (trt_state->output_info)[1]; +// auto fused_node_name = trt_state->fused_node_name; +// auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; +// auto trt_engine = trt_state->engine->get(); +// auto trt_context = trt_state->context->get(); +// auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; +// int num_outputs = static_cast(output_indexes.size()); +// std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run +// std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input +// +// OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); +// OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); +// if (alloc_ == nullptr) { +// Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); +// } +// OrtAllocator* alloc = alloc_; +// +// void* cuda_stream; +// Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); +// cudaStream_t stream = static_cast(cuda_stream); +// +// // Get input and output binding names +// int total_bindings = trt_engine->getNbIOTensors(); +// std::vector input_binding_names, output_binding_names; +// for (int i = 0, end = total_bindings; i < end; ++i) { +// auto const& name = trt_engine->getIOTensorName(i); +// auto const& mode = trt_engine->getTensorIOMode(name); +// if (mode == nvinfer1::TensorIOMode::kINPUT) { +// input_binding_names.push_back(name); +// } else { +// output_binding_names.push_back(name); +// } +// } +// +// /* +// * Set input shapes and bind input buffers +// */ +// std::vector> scratch_buffers; +// for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { +// char const* input_name = input_binding_names[i]; +// +// size_t input_index = 0; +// const auto iter = input_indexes.find(input_name); +// if (iter != input_indexes.end()) { +// input_index = iter->second; +// } +// +// Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); +// if (status != Status::OK()) { +// return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, status.ErrorMessage()); +// } +// } +// +// /* +// * Set output shapes and bind output buffers +// */ +// std::unordered_map buffers; +// buffers.reserve(num_outputs); +// using OutputOrtValue = Ort::UnownedValue; +// std::unordered_map output_tensors; +// output_tensors.reserve(num_outputs); +// std::unordered_map output_dim_sizes; +// output_dim_sizes.reserve(num_outputs); +// +// for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { +// char const* output_name = output_binding_names[i]; +// +// size_t output_index = 0; +// const auto& index_iter = output_indexes.find(output_name); +// if (index_iter != output_indexes.end()) { +// output_index = index_iter->second; +// } +// +// size_t output_type = 0; +// const auto type_iter = output_types.find(output_name); +// if (type_iter != output_types.end()) { +// output_type = type_iter->second; +// } +// +// Status status 
= BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+//                                       dds_output_allocator_map, scratch_buffers, alloc, buffers);
+//      if (status != Status::OK()) {
+//        return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, status.ErrorMessage());
+//      }
+//    }
+//
+//    // Set execution context memory
+//    if (trt_state->context_memory_sharing_enable) {
+//#if defined(_MSC_VER)
+//#pragma warning(push)
+//#pragma warning(disable : 4996)
+//#endif
+//      size_t mem_size = trt_engine->getDeviceMemorySize();
+//#if defined(_MSC_VER)
+//#pragma warning(pop)
+//#endif
+//      if (mem_size > *max_context_mem_size_ptr) {
+//        *max_context_mem_size_ptr = mem_size;
+//      }
+//      trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
+//    }
+//
+//    // Start CUDA graph capture.
+//    // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
+//    // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
+//    if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
+//      LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+//      cuda_graph_.SetStream(stream);
+//      CaptureBegin(0);
+//    }
+//
+//    // Run TRT inference
+//    if (!trt_context->enqueueV3(stream)) {
+//      return api_->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed.");
+//    }
+//
+//    /*
+//     * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently,
+//     * TRT EP needs to carefully take care of concurrency here, if not, the following concurrency issue might happen:
+//     *
+//     * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
+//     * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently,
+//     * the trt execution context instance is shared by all the threads and each thread acquires a different stream from ORT.
+//     * So TRT EP will end up having one trt execution context using multiple streams which is not suggested.
+//     * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream
+//     * is guaranteed.
+//     *
+//     * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above.
+//     * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
+//     */
+//    if (sync_stream_after_enqueue_) {
+//      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+//    }
+//
+//    // Assign TRT output back to ORT output
+//    // (1) Bind TRT DDS output to ORT kernel context output.
(It needs to wait until enqueueV3 is finished) +// // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output +// for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { +// char const* output_name = output_binding_names[i]; +// +// size_t output_type = 0; +// const auto& iter = output_types.find(output_name); +// if (iter != output_types.end()) { +// output_type = iter->second; +// } +// +// if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { +// size_t output_index = 0; +// const auto& index_iter = output_indexes.find(output_name); +// if (index_iter != output_indexes.end()) { +// output_index = index_iter->second; +// } +// auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); +// if (status != Status::OK()) { +// return api_->CreateStatus(OrtErrorCode::ORT_FAIL, status.ErrorMessage()); +// } +// } else { +// auto& output_tensor = output_tensors[i]; +//#if NV_TENSORRT_MAJOR < 10 +// if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { +// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); +// if (output_tensor_ptr != nullptr) { +// cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); +// } +// } +//#endif +// if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { +// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); +// if (output_tensor_ptr != nullptr) { +// cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); +// } +// } +// } +// } +// +// // End CUDA graph capture. +// // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture +// // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. +// // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. +// if (cuda_graph_enable_ && !IsGraphCaptured(0)) { +// if (IsGraphCaptureAllowed()) { +// CaptureEnd(0); +// // CUDA work issued to a capturing stream doesn’t actually run on the GPU, +// // so run the captured graph here to actually execute the work. 
+// ORT_RETURN_IF_ERROR(ReplayGraph(0)); +// } else { +// IncrementRegularRunCountBeforeGraphCapture(); +// } +// } +// +// return Status::OK(); +// }; +// +// node_compute_funcs.push_back(compute_info); + return nullptr; } } // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 6f9086d37fb68..50f4a21de4230 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -1,7 +1,10 @@ #pragma once +#include +#include +#include #include "core/session/onnxruntime_c_api.h" #include "core/framework/provider_options.h" -#include +#include "nv_includes.h" #ifdef _WIN32 #define EXPORT_API __declspec(dllexport) @@ -11,14 +14,205 @@ namespace onnxruntime { +class TensorrtLogger : public nvinfer1::ILogger { + nvinfer1::ILogger::Severity verbosity_; + + public: + TensorrtLogger(Severity verbosity = Severity::kWARNING) + : verbosity_(verbosity) {} + void log(Severity severity, const char* msg) noexcept override { + if (severity <= verbosity_) { + time_t rawtime = std::time(0); + struct tm stm; +#ifdef _MSC_VER + gmtime_s(&stm, &rawtime); +#else + gmtime_r(&rawtime, &stm); +#endif + char buf[256]; + strftime(&buf[0], 256, + "%Y-%m-%d %H:%M:%S", + &stm); + const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : severity == Severity::kERROR ? " ERROR" + : severity == Severity::kWARNING ? "WARNING" + : severity == Severity::kINFO ? " INFO" + : "UNKNOWN"); + if (severity <= Severity::kERROR) { +// LOGS_DEFAULT(ERROR) << "[" << buf << " " << sevstr << "] " << msg; + } else { +// LOGS_DEFAULT(WARNING) << "[" << buf << " " << sevstr << "] " << msg; + } + } + } + void set_level(Severity verbosity) { + verbosity_ = verbosity; + } + Severity get_level() const { + return verbosity_; + } +}; + +namespace tensorrt_ptr { + +struct TensorrtInferDeleter { + template + void operator()(T* obj) const { + if (obj) { + delete obj; + } + } +}; + +template +using unique_pointer = std::unique_ptr; +}; // namespace tensorrt_ptr + +class OutputAllocator : public nvinfer1::IOutputAllocator { + public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else + void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; +#endif + void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; + + void* getBuffer() { + return outputPtr; + } + + std::vector& getOutputShape() { + return output_shapes; + } + + uint64_t getSize() { + return allocated_size; + } + + ~OutputAllocator() override { + cudaFree(outputPtr); + } + + private: + void* outputPtr{nullptr}; + uint64_t allocated_size = 0; + std::vector output_shapes; +}; + +using ShapeRangesMap = std::unordered_map>>>; + +using DDSOutputAllocatorMap = std::unordered_map>; +std::string GetWeightRefittedEnginePath(std::string engine_cache_path); + struct TensorrtExecutionProvider : public OrtExecutionProvider { TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& provider_options); - void CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, + OrtStatusPtr CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, 
OrtNodeComputeInfo** node_compute_funcs);
+  static OrtStatusPtr RefitEngine(std::string onnx_model_filename,
+                                  std::string& onnx_model_folder_path,
+                                  std::string& weight_stripped_engine_cache_path,
+                                  bool path_check,
+                                  nvinfer1::ICudaEngine* trt_engine,
+                                  bool serialize_refitted_engine,
+                                  bool detailed_build_log);
 private:
-  bool external_stream_ = false;
+  static const OrtApi* api_;
+//  mutable TensorrtExecutionProviderInfo info_;
+  bool external_stream_ = false;
+  cudaStream_t stream_ = nullptr;
+  int max_partition_iterations_ = 1000;
+  size_t min_subgraph_size_ = 1;
+  size_t max_workspace_size_ = 1 << 30;  // 1GB
+  bool fp16_enable_ = false;
+  bool int8_enable_ = false;
+  bool dla_enable_ = false;
+  int dla_core_ = 0;
+  bool force_sequential_engine_build_ = false;
+  std::string int8_calibration_cache_name_;
+  bool int8_calibration_cache_available_ = false;
+  bool int8_use_native_tensorrt_calibration_table_ = false;
+  bool dump_subgraphs_ = false;
+  bool engine_cache_enable_ = false;
+  bool weight_stripped_engine_enable_ = false;
+  bool weight_stripped_engine_refit_ = false;
+  std::string onnx_model_folder_path_;
+  bool build_heuristics_enable_ = false;
+  bool sparsity_enable_ = false;
+  int builder_optimization_level_ = 3;
+  int auxiliary_streams_ = -1;
+  std::string tactic_sources_;
+  std::string global_cache_path_, cache_path_, engine_decryption_lib_path_;
+  std::unique_ptr<nvinfer1::IRuntime> runtime_ = nullptr;
+//  OrtMutex tensorrt_mu_;
+  int device_id_;
+  std::string compute_capability_;
+  bool context_memory_sharing_enable_ = false;
+  bool layer_norm_fp32_fallback_ = false;
+  size_t max_ctx_mem_size_ = 0;
+//  IAllocatorUniquePtr<void> context_memory_ = nullptr;
+  mutable char model_path_[4096] = {};  // Reserved for max path length
+  bool engine_decryption_enable_ = false;
+  int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
+  int (*engine_encryption_)(const char*, char*, size_t) = nullptr;
+  bool timing_cache_enable_ = false;
+  bool force_timing_cache_match_ = false;
+  bool detailed_build_log_ = false;
+  bool cuda_graph_enable_ = false;
+  std::string cache_prefix_;
+  bool engine_hw_compatible_ = false;
+
+  // The OrtAllocator object is obtained during EP compute time
+  // and should be kept for the lifetime of the TRT EP object.
+  OrtAllocator* alloc_ = nullptr;
+
+  // For create/dump EP context node model
+  bool dump_ep_context_model_ = false;
+  std::string ep_context_file_path_;
+  int ep_context_embed_mode_ = 0;
+  std::string ctx_model_path_;
+  std::string ep_cache_context_attr_;
+  std::string engine_cache_relative_path_to_context_model_dir;
+//  std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();
+
+  std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
+//  mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
+
+  mutable std::unique_ptr<nvinfer1::IBuilder> builder_;
+
+  // The following maps, which hold TRT objects, will be accessible by different threads if ORT is using multithreading.
+  // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client.
+  // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+  // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to ensure synchronization.
+  std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
+  std::unordered_map<std::string, std::unique_ptr<nvinfer1::ICudaEngine>> engines_;
+  std::unordered_map<std::string, std::unique_ptr<nvinfer1::IExecutionContext>> contexts_;
+  std::unordered_map<std::string, std::unique_ptr<nvinfer1::IBuilder>> builders_;
+  std::unordered_map<std::string, std::unique_ptr<nvinfer1::INetworkDefinition>> networks_;
+  std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> input_info_;
+  std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> output_info_;
+  std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>> profile_min_shapes_;
+  std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>> profile_max_shapes_;
+  std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>> profile_opt_shapes_;
+  std::unordered_map<std::string, ShapeRangesMap> input_shape_ranges_;  // The profile shape ranges that the engine is built with
+  std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
+  std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps_;
+
+  // For an external stream, we need to create its cudnn/cublas handles before the CUDA EP enables cuda graph capture
+//  cudnnHandle_t external_cudnn_handle_ = nullptr;
+//  cublasHandle_t external_cublas_handle_ = nullptr;
+
+  // Call cudaStreamSynchronize() after TRT enqueueV3()
+  mutable bool sync_stream_after_enqueue_ = true;
+
+//  CUDAGraph cuda_graph_;
+//  bool is_graph_captured_ = false;
+  int regular_run_count_before_graph_capture_ = 0;
+  // There is a chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like:
+  // (1) memory pattern is enabled. (2) arena allocation for stream.
+  // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
+  // to allocate enough memory in Arena before graph capturing.
+  const int min_num_runs_before_cuda_graph_capture_ = 1;  // required min regular runs before graph capture for the necessary memory allocations.
 };

 struct TensorrtExecutionProviderFactory : public OrtExecutionProviderFactory {
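[Editor's note] The following patch exists because OrtNodeComputeInfo stores plain C function pointers, and only capture-less lambdas convert to those. A sketch of the constraint and the workaround the patch introduces:

  // Capturing 'this' makes the lambda a closure type; it no longer converts to
  // OrtStatusPtr(ORT_API_CALL*)(...), so this does NOT compile:
  //   node_compute_info[i]->ComputeFunc = [this](void* state, const OrtApi* api, OrtKernelContext* ctx) -> OrtStatusPtr { ... };
  //
  // The new extra_param_for_create_state_func / extra_param_for_compute_func fields
  // carry the provider pointer instead, so a capture-less lambda can recover it:
  //   this_->extra_param_for_compute_func = static_cast<TensorrtExecutionProvider*>(this_);
  //   node_compute_info[i]->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* ctx) -> OrtStatusPtr {
  //     auto* ep = reinterpret_cast<TensorrtExecutionProvider*>(extra_param);  // recover state
  //     // ...
  //     return nullptr;
  //   };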
OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info); @@ -754,6 +755,8 @@ typedef struct OrtExecutionProvider { const char* type; OrtCreateStream* create_stream; const OrtDevice* default_device; + void* extra_param_for_create_state_func; + void* extra_param_for_compute_func; } OrtExecutionProvider; typedef struct OrtExecutionProviderFactory { diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 663a2a80dd250..ce3792cd94fe6 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -94,11 +94,11 @@ class ExecutionProviderAdapter : public IExecutionProvider { for (size_t i = 0; i < count; i++) { NodeComputeInfo compute_info; compute_info.create_state_func = [&, cache, i](ComputeContext* context, void** state) { - if (cache[i].CreateFunctionStateFunc) return cache[i].CreateFunctionStateFunc(reinterpret_cast(context), state); + if (cache[i].CreateFunctionStateFunc) return cache[i].CreateFunctionStateFunc(reinterpret_cast(context), ep_impl_->extra_param_for_create_state_func, state); return 0; }; compute_info.compute_func = [&, cache, i](void* state, const OrtApi* api, OrtKernelContext* context) { - return ToStatus(cache[i].ComputeFunc(state, api, context)); + return ToStatus(cache[i].ComputeFunc(state, ep_impl_->extra_param_for_compute_func, api, context)); }; compute_info.release_state_func = [&, cache, i](void* state) { if (cache[i].DestroyFunctionStateFunc) { diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index cb6eed5a465d1..f2602201efe96 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -1,6 +1,7 @@ #include "out_tree_ep.h" #include #include +#include namespace onnxruntime { OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExecutionProvider(), info(ep_info) { @@ -52,8 +53,10 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { + OutTreeEp* p = static_cast(this_); + this_->extra_param_for_compute_func = p; for (size_t i = 0; i < cnt; i++) { - node_compute_info[i]->ComputeFunc = [](void* state, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + node_compute_info[i]->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { const OrtValue* input = nullptr; api->KernelContext_GetInput(context, 0, &input); std::vector dim(1,4); @@ -71,6 +74,8 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe output_raw[i] = 1.0; } + OutTreeEp* this_ = reinterpret_cast(extra_param); + std::cout<<"int_property: "<info.int_property<<"\nstr_property: "<info.str_property<<"\n"; return nullptr; }; } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index caae17d3cb8ed..21c9577aed071 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -22,6 +22,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const 
OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + TensorrtExecutionProvider* p = static_cast(this_); + this_->extra_param_for_create_state_func = p; + this_->extra_param_for_compute_func = p; for (size_t j = 0; j < cnt; j++) { std::unordered_map input_map, output_map; size_t input_size = 0; @@ -41,7 +44,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } OrtStatusPtr ret = nullptr; - TensorrtExecutionProvider* p = static_cast(this_); if (GraphHasCtxNode(graph[j])) { ret = p->CreateNodeComputeInfoFromPrecompiledEngine(graph[j], node[j], input_map, output_map, &node_compute_info[j]); } else { @@ -262,45 +264,42 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi output_info_[fused_node_name].push_back(output_types); // Create function state - // TODO: remove default capture -// NodeComputeInfo compute_info; -// compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { -// std::unique_ptr p = std::make_unique(); -// *p = {context->allocate_func, -// context->release_func, -// context->allocator_handle, -// context->node_name, -// &engines_[context->node_name], -// &contexts_[context->node_name], -// input_info_[context->node_name], -// output_info_[context->node_name], -// context_memory_sharing_enable_, -// &max_ctx_mem_size_, -// &tensorrt_mu_}; -// *state = p.release(); -// return 0; -// }; -// -// // Release function state -// compute_info.release_state_func = [](FunctionState state) { -// delete static_cast(state); -// }; -// -// // Create compute function -// compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { -// Ort::KernelContext ctx(context); -// -// TensorrtShortFuncState* trt_state = reinterpret_cast(state); -// -// // The whole compute_function should be considered the critical section. -// // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + (*node_compute_funcs)->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + std::unique_ptr p = std::make_unique(); + *p = { context->AllocateFunc, + context->DestroyFunc, + context->allocator_handle, + context->node_name, + &(this_->engines_[context->node_name]), + &(this_->contexts_[context->node_name]), + this_->input_info_[context->node_name], + this_->output_info_[context->node_name], + this_->context_memory_sharing_enable_, + &this_->max_ctx_mem_size_}; + *state = p.release(); + return 0; + }; + + // Release function state + (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) { + delete reinterpret_cast(state); + }; + + // Create compute function + (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + TensorrtShortFuncState* trt_state = reinterpret_cast(state); + + // The whole compute_function should be considered the critical section. 
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); -// -// const std::unordered_map& input_indexes = (trt_state->input_info)[0]; -// const std::unordered_map& output_indexes = (trt_state->output_info)[0]; -// const std::unordered_map& output_types = (trt_state->output_info)[1]; -// auto fused_node_name = trt_state->fused_node_name; -// auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; + + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; // auto trt_engine = trt_state->engine->get(); // auto trt_context = trt_state->context->get(); // auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; @@ -487,11 +486,10 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // IncrementRegularRunCountBeforeGraphCapture(); // } // } -// -// return Status::OK(); -// }; -// -// node_compute_funcs.push_back(compute_info); + + return nullptr; + }; + return nullptr; } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 50f4a21de4230..70465bcdec876 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -14,6 +14,9 @@ namespace onnxruntime { +using AllocateFunc = void* (*)(void*, size_t, size_t); +using DestroyFunc = void (*)(void*, void*); + class TensorrtLogger : public nvinfer1::ILogger { nvinfer1::ILogger::Severity verbosity_; @@ -100,6 +103,68 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { using ShapeRangesMap = std::unordered_map>>>; +struct TensorrtFuncState { + AllocateFunc test_allocate_func = nullptr; + DestroyFunc test_release_func = nullptr; + void* allocator = nullptr; + std::string fused_node_name; + nvinfer1::IBuilder* builder; + tensorrt_ptr::unique_pointer* parser = nullptr; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::unique_ptr* network = nullptr; + std::vector> input_info; + std::vector> output_info; + std::unordered_map>>> input_shape_ranges; +// OrtMutex* tensorrt_mu_ptr = nullptr; + bool fp16_enable = false; + bool int8_enable = false; + bool int8_calibration_cache_available = false; + bool dla_enable = false; + int dla_core = 0; + size_t* max_workspace_size_ptr = nullptr; + std::string trt_node_name_with_precision; + bool engine_cache_enable = false; + std::string engine_cache_path; + nvinfer1::IRuntime* runtime = nullptr; + std::vector profiles; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + std::unordered_map dynamic_range_map; + bool engine_decryption_enable = false; + int (*engine_decryption)(const char*, char*, size_t*) = nullptr; + int (*engine_encryption)(const char*, char*, size_t) = nullptr; + bool timing_cache_enable = true; + std::string timing_cache_path; + bool force_timing_cache = false; + bool detailed_build_log = false; + bool build_heuristics_enable = false; + bool sparsity_enable = false; + int builder_optimization_level = 3; + int auxiliary_streams = -1; + bool filter_tactic_sources = false; + nvinfer1::TacticSources tactic_sources; + bool 
cuda_graph_enable = 0; + std::string cache_prefix; + std::string cache_suffix; + bool engine_hw_compatible = false; +}; + +// Minimum information to construct kernel function state for direct engine load code path +struct TensorrtShortFuncState { + AllocateFunc test_allocate_func = nullptr; + DestroyFunc test_release_func = nullptr; + void* allocator = nullptr; + std::string fused_node_name; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::vector> input_info; + std::vector> output_info; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; +// OrtMutex* tensorrt_mu_ptr = nullptr; +}; + using DDSOutputAllocatorMap = std::unordered_map>; std::string GetWeightRefittedEnginePath(std::string engine_cache_path); From 5e46d0ff5d19f30bdb4091d9158685212a65c21b Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Fri, 23 Aug 2024 00:17:31 +0000 Subject: [PATCH 20/81] add OrtGraph_SerializeToArray --- .../core/session/onnxruntime_c_api.h | 2 + onnxruntime/core/session/onnxruntime_c_api.cc | 20 + onnxruntime/core/session/ort_apis.h | 2 + samples/tensorRTEp/CMakeLists.txt | 6 +- .../tensorRTEp/ort_trt_int8_cal_table.fbs.h | 144 ++ .../tensorRTEp/tensorrt_execution_provider.cc | 1911 +++++++++++++++-- .../tensorRTEp/tensorrt_execution_provider.h | 22 +- .../tensorrt_execution_provider_utils.h | 93 + 8 files changed, 2072 insertions(+), 128 deletions(-) create mode 100644 samples/tensorRTEp/ort_trt_int8_cal_table.fbs.h create mode 100644 samples/tensorRTEp/tensorrt_execution_provider_utils.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index ddee3ce9746e5..c7de6f3538499 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4775,6 +4775,8 @@ struct OrtApi { int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; + ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b9ea2e8bac797..26f3df9f647c2 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -21,7 +21,9 @@ #include "core/common/status.h" #include "core/common/safeint.h" #include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/graph.h" +#include "core/graph/graph_proto_serializer.h" #include "core/graph/graph_viewer.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" @@ -2473,6 +2475,23 @@ ORT_API(int32_t, OrtApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* g return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type(); } +ORT_API(size_t, OrtApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), +#if defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaRegistryList(), +#else + IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), +#endif + graph_viewer->DomainToVersionMap(), std::vector(), 
graph_viewer->GetGraph().GetLogger()); + onnx::ModelProto model_proto = model.ToProto(); + GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true); + size_t ret = model_proto.ByteSizeLong(); + *data = malloc(ret); + model_proto.SerializeToArray(*data, ret); + return ret; +} + ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) { const ::onnxruntime::Node* n = reinterpret_cast(node); *name = n->Name().c_str(); @@ -3041,6 +3060,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtGraph_GetOutputSize, &OrtApis::OrtGraph_GetIthOutputName, &OrtApis::OrtGraph_GetIthOutputElemType, + &OrtApis::OrtGraph_SerializeToArray, &OrtApis::OrtNode_GetName, &OrtApis::OrtNode_GetDescription, &OrtApis::OrtNode_GetDomain, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 949079766811b..c92b33d3e91ce 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -559,6 +559,8 @@ ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; +ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** data); + ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index 33ec1d1937096..759c43060de4b 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -16,7 +16,8 @@ file(GLOB tensorrt_src "./*.cc") add_library(TensorRTEp SHARED ${tensorrt_src}) target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "/usr/local/cuda/include" - "/home/leca/TensorRT-10.0.1.6/include") + "/home/leca/TensorRT-10.0.1.6/include" + "../../build/Linux/Debug/_deps/flatbuffers-src/include") # "../../build/Linux/Debug/_deps/gsl-src/include" # "../../build/Linux/Debug/_deps/onnx-src" # "../../build/Linux/Debug/_deps/onnx-build" @@ -26,7 +27,8 @@ target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer_plugin.so" - "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so") + "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/flatbuffers-build/libflatbuffers.a") # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" diff --git a/samples/tensorRTEp/ort_trt_int8_cal_table.fbs.h b/samples/tensorRTEp/ort_trt_int8_cal_table.fbs.h new file mode 100644 index 0000000000000..9e4324fb9f516 --- /dev/null +++ b/samples/tensorRTEp/ort_trt_int8_cal_table.fbs.h @@ -0,0 +1,144 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ +#define FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace CalTableFlatBuffers { + +struct KeyValue; +struct KeyValueBuilder; + +struct TrtTable; +struct TrtTableBuilder; + +struct 
KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KeyValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String* key() const { + return GetPointer(VT_KEY); + } + bool KeyCompareLessThan(const KeyValue* o) const { + return *key() < *o->key(); + } + int KeyCompareWithValue(const char* val) const { + return strcmp(key()->c_str(), val); + } + const flatbuffers::String* value() const { + return GetPointer(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KEY) && + verifier.VerifyString(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyString(value()) && + verifier.EndTable(); + } +}; + +struct KeyValueBuilder { + typedef KeyValue Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(KeyValue::VT_KEY, key); + } + void add_value(flatbuffers::Offset value) { + fbb_.AddOffset(KeyValue::VT_VALUE, value); + } + explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KeyValueBuilder& operator=(const KeyValueBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KeyValue::VT_KEY); + return o; + } +}; + +inline flatbuffers::Offset CreateKeyValue( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset key = 0, + flatbuffers::Offset value = 0) { + KeyValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKeyValueDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const char* key = nullptr, + const char* value = nullptr) { + auto key__ = key ? _fbb.CreateString(key) : 0; + auto value__ = value ? _fbb.CreateString(value) : 0; + return CalTableFlatBuffers::CreateKeyValue( + _fbb, + key__, + value__); +} + +struct TrtTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TrtTableBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DICT = 4 + }; + const flatbuffers::Vector>* dict() const { + return GetPointer>*>(VT_DICT); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DICT) && + verifier.VerifyVector(dict()) && + verifier.VerifyVectorOfTables(dict()) && + verifier.EndTable(); + } +}; + +struct TrtTableBuilder { + typedef TrtTable Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_dict(flatbuffers::Offset>> dict) { + fbb_.AddOffset(TrtTable::VT_DICT, dict); + } + explicit TrtTableBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TrtTableBuilder& operator=(const TrtTableBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTrtTable( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset>> dict = 0) { + TrtTableBuilder builder_(_fbb); + builder_.add_dict(dict); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTrtTableDirect( + flatbuffers::FlatBufferBuilder& _fbb, + std::vector>* dict = nullptr) { + auto dict__ = dict ? 
_fbb.CreateVectorOfSortedTables(dict) : 0; + return CalTableFlatBuffers::CreateTrtTable( + _fbb, + dict__); +} + +} // namespace CalTableFlatBuffers + +#endif // FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 21c9577aed071..c9259f5fa5d86 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1,10 +1,50 @@ #include #include +#include #include +#include "core/session/onnxruntime_cxx_api.h" // TODO(leca): we should be able to use cxx APIs which are built upon C API #include "tensorrt_execution_provider.h" +#include "tensorrt_execution_provider_utils.h" #include "onnx_ctx_model_helper.h" + +void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } + namespace onnxruntime { +template +using IAllocatorUniquePtr = std::unique_ptr>; +const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); + +bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { + size_t alloc_size = size; + if (alignment == 0) { + *out = alloc_size * nmemb; + } else { + size_t alignment_mask = alignment - 1; + *out = (alloc_size * nmemb + alignment_mask) & ~static_cast(alignment_mask); + } + return true; +} + +template +IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { + size_t alloc_size = count_or_bytes; + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size); + } + + T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + + return IAllocatorUniquePtr{p, + [ort_allocator](T* p) { + ort_allocator->Free(ort_allocator, p); + }}; +} + TensorrtLogger& GetTensorrtLogger(bool verbose_log) { const auto log_level = verbose_log ? 
nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING; static TensorrtLogger trt_logger(log_level); @@ -14,7 +54,439 @@ TensorrtLogger& GetTensorrtLogger(bool verbose_log) { return trt_logger; } -const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +template +void GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, + void* shape_values, + int shape_size, + cudaStream_t stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values, + input_tensor.GetTensorData(), + shape_size * sizeof(T), + cudaMemcpyDeviceToHost, + stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); +} + +bool ApplyProfileShapesFromProviderOptions(std::vector& trt_profiles, + nvinfer1::ITensor* input, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes, + ShapeRangesMap& input_explicit_shape_ranges) { + if (trt_profiles.size() == 0) { +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0."; + return false; + } + + const std::string& input_name = input->getName(); + if (profile_min_shapes.find(input_name) == profile_min_shapes.end()) { + return false; + } + + if (input_explicit_shape_ranges.find(input_name) == input_explicit_shape_ranges.end()) { + std::unordered_map>> inner_map; + input_explicit_shape_ranges[input_name] = inner_map; + } + +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Begin to apply profile shapes ..."; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Input tensor name is '" << input_name << "', number of profiles found is " << trt_profiles.size(); + + for (size_t i = 0; i < trt_profiles.size(); i++) { + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + auto trt_profile = trt_profiles[i]; + + // Shape tensor + if (input->isShapeTensor()) { + int shape_size = nb_dims == 0 ? 
1 : static_cast(profile_min_shapes[input_name][i].size()); + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size; + + for (int j = 0; j < shape_size; j++) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + shapes_min[j] = static_cast(min_value); + shapes_max[j] = static_cast(max_value); + shapes_opt[j] = static_cast(opt_value); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_min.d[" << j << "] is " << shapes_min[j]; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_max.d[" << j << "] is " << shapes_max[j]; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_opt.d[" << j << "] is " << shapes_opt[j]; + + if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } + + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } + // Execution tensor + else { + nvinfer1::Dims dims_min, dims_opt, dims_max; + dims_min.nbDims = nb_dims; + dims_max.nbDims = nb_dims; + dims_opt.nbDims = nb_dims; + +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] number of dimension of this execution tensor is " << nb_dims; + + for (int j = 0; j < nb_dims; j++) { + if (dims.d[j] == -1) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + dims_min.d[j] = static_cast(min_value); + dims_max.d[j] = static_cast(max_value); + dims_opt.d[j] = static_cast(opt_value); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_min.d[" << j << "] is " << dims_min.d[j]; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_max.d[" << j << "] is " << dims_max.d[j]; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_opt.d[" << j << "] is " << dims_opt.d[j]; + + if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } else { + dims_min.d[j] = dims.d[j]; + dims_max.d[j] = dims.d[j]; + dims_opt.d[j] = dims.d[j]; + } + } + + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + } + return true; +} + +#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \ + 
case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +//#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ +// case DATA_TYPE: { \ +// auto input_tensor_ptr = input_tensor.GetTensorData(); \ +// if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ +// scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ +// data = scratch_buffers.back().get(); \ +// cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ +// } else { \ +// scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ +// data = scratch_buffers.back().get(); \ +// } \ +// break; \ +// } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +//#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ +// case DATA_TYPE: { \ +// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ +// if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ +// cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ +// } \ +// break; \ +// } + +OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, + nvinfer1::ICudaEngine* trt_engine, + nvinfer1::IExecutionContext* trt_context, + const char* input_name, + size_t input_index, + std::unordered_map>& shape_tensor_values, + std::unordered_map>& shape_tensor_values_int64, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + cudaStream_t stream) { + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+   *   [] = 1
+   *   [1,3,4] = 12
+   *   [2,0,4] = 0
+   *   [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); + + if (trt_engine->isShapeInferenceIO(input_name)) { + // Bind "shape tensor" input buffer + + // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + // get shape tensor value if not present + if (shape_tensor_values.find(input_name) == shape_tensor_values.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // get shape tensor value if not present + if (shape_tensor_values_int64.find(input_name) == shape_tensor_values_int64.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values_int64[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values_int64[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + default: { + std::string error_input_name = input_name; + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name).c_str()); + } + } + } else { + // Set shape for input tensor which is execution tensor + nvinfer1::Dims dims = trt_context->getTensorShape(input_name); + int nb_dims = dims.nbDims; + for (int j = 0, end = nb_dims; j < end; ++j) { + dims.d[j] = static_cast(tensor_shapes[j]); + } + if (!trt_context->setInputShape(input_name, dims)) { + std::string error_input_name = input_name; + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'").c_str()); + } + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. 
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors + void* data = nullptr; + switch (tensor_type) { + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64 +// CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Cast double input to float because TensorRT doesn't support double +// CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(input_name, data); + } + + return nullptr; +} + +OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, + nvinfer1::IExecutionContext* trt_context, + const char* output_name, + size_t output_index, + size_t output_type, + size_t i, + std::unordered_map& output_tensors, + std::unordered_map& output_dim_sizes, + DDSOutputAllocatorMap& dds_output_allocator_map, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + std::unordered_map& buffers) { + // Get output shape + nvinfer1::Dims dims = trt_context->getTensorShape(output_name); + int nb_dims = dims.nbDims; + bool is_DDS = false; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + // data-dependent shape + if (dims.d[j] == -1) { + is_DDS = true; + break; + } + output_shapes[j] = dims.d[j]; + } + + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. + // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. + // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, + // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) + // + // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. 
+ if (is_DDS || known_DDS) { + if (!known_DDS) { + auto allocatorPtr = std::make_unique(); + trt_context->setOutputAllocator(output_name, allocatorPtr.get()); + dds_output_allocator_map[output_name] = std::move(allocatorPtr); + } + } else { + output_tensors[i] = ctx.GetOutput(output_index, output_shapes); + auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + switch (output_type) { + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(output_name, buffers[output_name]); + } + + return nullptr; +} + +OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, + OrtMemoryInfo* /*mem_info*/, + DDSOutputAllocatorMap& allocator_map, + char const* output_name, + size_t output_index, + size_t output_type, + cudaStream_t stream) { + auto allocator = allocator_map[output_name].get(); + auto& shape = allocator->getOutputShape(); + auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+   *   [] = 1
+   *   [1,3,4] = 12
+   *   [2,0,4] = 0
+   *   [-1,3,4] = -1
+ */ + auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ + switch (output_type) { + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. +// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) +#endif + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. +// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) + default: { + return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + return nullptr; +} TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() { OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { @@ -47,7 +519,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const if (GraphHasCtxNode(graph[j])) { ret = p->CreateNodeComputeInfoFromPrecompiledEngine(graph[j], node[j], input_map, output_map, &node_compute_info[j]); } else { - + ret = p->CreateNodeComputeInfoFromGraph(graph[j], node[j], input_map, output_map, &node_compute_info[j]); } if (ret != nullptr) return api->CreateStatus(api->GetErrorCode(ret), api->GetErrorMessage(ret)); } @@ -120,6 +592,16 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { }; } +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { + if (!builder_) { + { + // auto lock = GetApiLock(); // TODO(leca) + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + OrtStatusPtr TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, @@ -178,146 +660,1015 @@ OrtStatusPtr TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filen #endif } -OrtStatusPtr 
TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, - std::unordered_map& input_map, - std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_funcs) { - std::unique_ptr trt_engine; - std::unique_ptr trt_context; - std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index - std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index - std::unordered_map output_types; // TRT engine output name -> ORT output tensor type +OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const OrtGraphViewer* graph_body_viewer, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_funcs) { + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#endif + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + void* buf_data = nullptr; + size_t buf_size = api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data); + trt_parser->parse(buf_data, buf_size, model_path_); + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); - // Get engine binary data and deserialize it - auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, - runtime_.get(), - model_path_, - compute_capability_, - weight_stripped_engine_enable_, - onnx_model_folder_path_, - detailed_build_log_); - auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); - if (status != nullptr) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow + if (fp16_enable_ && layer_norm_fp32_fallback_) { + for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { + auto layer = trt_network->getLayer(idx); + auto next_layer = trt_network->getLayer(idx + 1); + if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + layer->setPrecision(nvinfer1::DataType::kFLOAT); + next_layer->setPrecision(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + } + } } - // Build context + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + /* + * Initialize shape range for each dynamic shape input tensor: + * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time. 
+   *    It won't adjust the profile values during EP compute time.
+   *
+   * 2) If no explicit optimization profiles are provided by the user, TRT EP will first set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN].
+   *    Later, at EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor values.
+   *
+   *
+   * Once the TRT profiles are created:
+   * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by the user, those profiles will be applied to the TRT builder config
+   *    and the engine will be built at EP compile time.
+   *
+   * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create the default shape as described above,
+   *    and none of the profiles will be applied and the engine won't be built until EP compute time.
+   */
+  bool has_dynamic_shape = false;  // True if an input tensor has dynamic shape and no explicit profile is specified, otherwise false.
+  bool has_explicit_profile = false;
+  bool apply_explicit_profile = false;
+  int num_profiles = 0;
+  std::vector<nvinfer1::IOptimizationProfile*> trt_profiles;
+
+  // The following C++ map data structure is used to help serialize/deserialize profiles; it saves the dynamic shape dimension(s) and min/max/opt values for each dynamic shape input tensor.
+  //
+  // (1) Single profile case:
+  // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
+  // has one dynamic shape dimension: dim_1. The data will be:
+  // {
+  //   tensor_a: {
+  //     dim_0: [[min_shape, max_shape, opt_shape]],
+  //     dim_2: [[min_shape, max_shape, opt_shape]]
+  //   },
+  //   tensor_b: {
+  //     dim_1: [[min_shape, max_shape, opt_shape]]
+  //   }
+  // }
+  //
+  // (2) Multiple profiles case:
+  // For example, assume tensor_a has one dynamic shape dimension: dim_0, and tensor_b has one dynamic shape dimension: dim_1,
+  // and both of the tensors have two profiles.
The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + // }, + // tensor_b: { + // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + // } + // } + ShapeRangesMap input_explicit_shape_ranges; + ShapeRangesMap input_implicit_shape_ranges; + + if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { + has_explicit_profile = true; + num_profiles = GetNumProfiles(profile_min_shapes_); + for (int i = 0; i < num_profiles; i++) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); } -#if NV_TENSORRT_MAJOR < 10 - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } - const char* fused_node_name = nullptr; - api_->OrtNode_GetName(fused_node, &fused_node_name); - if (!trt_context) { - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, - std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); + // Iterate all input tensors to check dynamic shape + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + // Apply explicit optimization profiles provided by user + if (has_explicit_profile) { + apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + } + + // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time + if (!apply_explicit_profile) { + if (input->isShapeTensor()) { + // Shape tensor + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][0] = profile_vector; + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == -1) { + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][j] = profile_vector; + has_dynamic_shape = true; + } + } + } + apply_explicit_profile = false; + } } - // Create input/output to index maps - for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { - auto const& name = trt_engine->getIOTensorName(i); - auto const& mode = trt_engine->getTensorIOMode(name); - if (mode == nvinfer1::TensorIOMode::kINPUT) { - const auto& iter = input_map.find(name); - if (iter != input_map.end()) { - input_indexes[name] = iter->second; + // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user + if (has_explicit_profile) { + // TRT EP has a constraint here. + // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options. 
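+    // For reference, explicit profiles arrive through provider options; illustrative values
+    // only (the stock TRT EP spells the keys trt_profile_min_shapes / trt_profile_max_shapes /
+    // trt_profile_opt_shapes; the exact keys here depend on this sample's option parser):
+    //   "trt_profile_min_shapes": "input_a:1x3x224x224"
+    //   "trt_profile_max_shapes": "input_a:32x3x224x224"
+    //   "trt_profile_opt_shapes": "input_a:8x3x224x224"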
+    if (has_dynamic_shape) {
+      std::ostringstream msg;
+      msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n";
+      msg << "Please note that the main graph could be partitioned into TRT/CUDA/CPU subgraphs; in this case, the user also needs to provide shape profiles for the TRT subgraph's input if it's a dynamic shape input.\n";
+      msg << "The following input(s) have no associated shape profiles provided: ";
+      auto begin = input_implicit_shape_ranges.begin();
+      auto end = input_implicit_shape_ranges.end();
+      auto it = begin;
+      if (it != end) {
+        msg << it->first;
+        ++it;
+      }
+      for (; it != end; ++it) {
+        msg << "," << it->first;
+      }
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, msg.str().c_str());
+    } else {
+      for (auto trt_profile : trt_profiles) {
+        trt_config->addOptimizationProfile(trt_profile);
+      }
+    }
+  }
+  // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default.
+  // It will later set proper min/max/opt shape values during EP compute time.
+  else if (!has_explicit_profile && has_dynamic_shape) {
+    trt_profiles.push_back(trt_builder->createOptimizationProfile());
+  }
+
+  // Check platform availability for low precision
+  if (fp16_enable_) {
+    if (!trt_builder->platformHasFastFp16()) {
+      fp16_enable_ = false;
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16";
+    }
+  }
+
+  if (int8_enable_) {
+    if (!trt_builder->platformHasFastInt8()) {
+      int8_enable_ = false;
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
+    }
+  }
+
+  // Load INT8 calibration table
+  std::unordered_map<std::string, float> dynamic_range_map;
+  if (int8_enable_ && int8_calibration_cache_available_) {
+    const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
+    if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
+      throw std::runtime_error("Failed to read INT8 calibration table "
+ calibration_cache_path); + } + } - // Release function state - (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) { - delete reinterpret_cast(state); - }; + // Set precision flags + const char* node_name = nullptr; + api_->OrtNode_GetName(fused_node, &node_name); + std::string trt_node_name_with_precision(node_name); + if (fp16_enable_ && int8_enable_) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + trt_node_name_with_precision += "_fp16_int8"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; + } else if (fp16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } else if (int8_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + trt_node_name_with_precision += "_int8"; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } - // Create compute function - (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { - TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); - TensorrtShortFuncState* trt_state = reinterpret_cast(state); + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead."; + dla_core_ = 0; + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } - // The whole compute_function should be considered the critical section. - // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading -// std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + // enable sparse weights + if (sparsity_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + if (build_heuristics_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled." + // << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics."; + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 + if (build_heuristics_enable_) { + if (builder_optimization_level_ == 2) { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. 
trt_build_heuristics_enable is deprecated on TRT 8.6 onwards.";
+    } else {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics.";
+    }
+  }
+#endif
-    const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
-    const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
-    const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
-    auto fused_node_name = trt_state->fused_node_name;
-    auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name];
-//    auto trt_engine = trt_state->engine->get();
-//    auto trt_context = trt_state->context->get();
-//    auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
-//    int num_outputs = static_cast<int>(output_indexes.size());
-//    std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
-//    std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
+//#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+//  // switch optimization level
+//  if (builder_optimization_level_ != 3) {
+//    trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+//  }
 //
-//    OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
-//    OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
-//    if (alloc_ == nullptr) {
-//      Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
-//    }
-//    OrtAllocator* alloc = alloc_;
+//  // limit auxiliary streams
+//  if (auxiliary_streams_ >= 0) {
+//    trt_config->setMaxAuxStreams(auxiliary_streams_);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are set to " << auxiliary_streams_;
+//  }
+//#else
+//  if (builder_optimization_level_ != 3) {
+//    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+//  }
+//  if (auxiliary_streams_ >= 0) {
+//    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+//  }
+//#endif
 //
-//    void* cuda_stream;
-//    Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
+//  if (weight_stripped_engine_enable_) {
+//#if NV_TENSORRT_MAJOR >= 10
+//    trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+//    trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+//#else
+//    LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+//#endif
+//  }
+//
+//  // limit used tactic sources
+//  if (!tactic_sources_.empty()) {
+//    nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+//    tactics |= GetTacticSourceFromString(tactic_sources_);
+//    trt_config->setTacticSources(tactics);
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_;
+//  }
+//
+//  // Build TRT engine (if needed) and load TRT engine if:
+//  //   (1) Graph has no dynamic shape input
+//  //   (2) All the dynamic shape inputs have associated explicit profiles specified by user
+//  //
+//  // Otherwise engine will be handled at inference time.
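The commented-out block above passes the trt_tactic_sources provider option through GetTacticSourceFromString() to obtain a nvinfer1::TacticSources bitmask; that helper is not part of this hunk. Below is a minimal sketch of such a translator, assuming a plain comma-separated list of source names (the helper name, parsing syntax, and supported names are illustrative, not this patch's actual implementation):

    #include <sstream>
    #include <string>
    #include "NvInfer.h"

    // Illustrative only: map a string such as "CUBLAS,CUDNN" to a TRT tactic-source bitmask.
    static nvinfer1::TacticSources TacticSourcesFromString(const std::string& names) {
      nvinfer1::TacticSources sources = 0;
      std::stringstream ss(names);
      std::string token;
      while (std::getline(ss, token, ',')) {
        if (token == "CUBLAS") {
          sources |= 1U << static_cast<uint32_t>(nvinfer1::TacticSource::kCUBLAS);
        } else if (token == "CUBLAS_LT") {
          sources |= 1U << static_cast<uint32_t>(nvinfer1::TacticSource::kCUBLAS_LT);
        } else if (token == "CUDNN") {
          sources |= 1U << static_cast<uint32_t>(nvinfer1::TacticSource::kCUDNN);
        }
      }
      return sources;
    }

The result would be OR-ed into trt_config->getTacticSources() exactly as the disabled code does, so unknown names simply leave the default tactic set untouched.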
+//  std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+//  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+//
+//  std::string cache_path = "";
+//  std::string cache_suffix = "";
+//  // Customize cache prefix if assigned
+//  if (!cache_prefix_.empty()) {
+//    // Generate cache suffix in case user would like to customize cache prefix
+//    cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision);
+//    cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix;
+//  } else {
+//    cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
+//  }
+//
+//  std::string cache_hw_compat = "_sm" + compute_capability_;
+//  // Enable hardware compatibility mode if assigned
+//  if (engine_cache_enable_ && engine_hw_compatible_) {
+//    trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
+//    cache_hw_compat = "_sm80+";
+//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache.";
+//  }
+//
+//  // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+//  // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+//  const std::string cache_path_prefix = cache_path + cache_hw_compat;
+//  std::string engine_cache_path = cache_path_prefix + ".engine";
+//  const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+//  const std::string profile_cache_path = cache_path_prefix + ".profile";
+//
+//  // If weight-stripped engine is enabled and refitted engine cache is not present,
+//  // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
+//  const std::filesystem::path engine_cache_fs_path = engine_cache_path;
+//  if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
+//    engine_cache_path = cache_path_prefix + ".stripped.engine";
+//    weight_stripped_engine_refit_ = true;
+//  }
+//
+//  // Generate file name for dumping ep context model
+//  if (dump_ep_context_model_ && ctx_model_path_.empty()) {
+//    ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_);
+//  }
+//
+//  if (!has_dynamic_shape) {
+//    std::string timing_cache_path = "";
+//    bool engine_update = false;
+//    if (timing_cache_enable_) {
+//      timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
+//    }
+//    {
+//      // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading.
+//      auto lock = GetApiLock();
+//
+//      // If explicit profile flag is on and engine cache enable flag is on,
+//      // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine.
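For concreteness, this is how the cache names above compose; the values are illustrative, and GetCachePath() (defined in tensorrt_execution_provider_utils.h later in this patch series) simply joins root and name:

    // Assuming cache_path_ = "/tmp/trt_cache", trt_node_name_with_precision =
    // "TRTKernel_graph_main_0_fp16", and compute_capability_ = "86":
    //   cache_path        = "/tmp/trt_cache/TRTKernel_graph_main_0_fp16"
    //   cache_path_prefix = cache_path + "_sm86"            // "_sm80+" when hardware-compatible mode is on
    //   engine cache      = cache_path_prefix + ".engine"   // ".stripped.engine" for weight-stripped engines
    //   encrypted engine  = cache_path_prefix + ".engine.encrypted"
    //   profile cache     = cache_path_prefix + ".profile"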
+// if (has_explicit_profile && engine_cache_enable_) { +// engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); +// if (engine_update) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; +// } else { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; +// } +// } +// +// std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); +// if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { +// engine_file.seekg(0, std::ios::end); +// size_t engine_size = engine_file.tellg(); +// engine_file.seekg(0, std::ios::beg); +// std::unique_ptr engine_buf{new char[engine_size]}; +// engine_file.read((char*)engine_buf.get(), engine_size); +// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; +// if (trt_engine == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); +// } +// +// } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { +// // Decrypt engine +// size_t engine_size = 0; +// if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not get engine buffer size"); +// } +// std::unique_ptr engine_buf{new char[engine_size]}; +// if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not call engine decryption function decrypt"); +// } +// // Deserialize engine +// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; +// if (trt_engine == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); +// } +// } else { +// // Set INT8 per tensor dynamic range +// if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { +//#if defined(_MSC_VER) +//#pragma warning(push) +//#pragma warning(disable : 4996) +//#endif +// trt_config->setInt8Calibrator(nullptr); +//#if defined(_MSC_VER) +//#pragma warning(pop) +//#endif +// if (!SetDynamicRange(*trt_network, dynamic_range_map)) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); +// } +// } +// +// // Load timing cache from file. 
Create a fresh cache if the file doesn't exist +// std::unique_ptr timing_cache = nullptr; +// if (timing_cache_enable_) { +// std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); +// timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); +// if (timing_cache == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not create timing cache: " + timing_cache_path); +// } +// trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); +// if (detailed_build_log_) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; +// } +// } +// +// // Build engine +// std::chrono::steady_clock::time_point engine_build_start; +// if (detailed_build_log_) { +// engine_build_start = std::chrono::steady_clock::now(); +// } +// std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; +// if (serialized_engine == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name()); +// } +// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); +// if (trt_engine == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name()); +// } +// if (detailed_build_log_) { +// auto engine_build_stop = std::chrono::steady_clock::now(); +// LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; +// } +// if (engine_cache_enable_) { +// // Serialize engine profile if it has explicit profiles +// if (has_explicit_profile) { +// SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; +// } +// +// if (engine_decryption_enable_) { +// // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. +// if (engine_encryption_ != nullptr) { +// if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP call to engine encryption library failed"); +// } +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; +// } else { +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; +// } +// } else { +// std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); +// file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; +// } +// } +// // serialize and save timing cache +// if (timing_cache_enable_) { +// auto timing_cache = trt_config->getTimingCache(); +// std::unique_ptr timingCacheHostData{timing_cache->serialize()}; +// if (timingCacheHostData == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not serialize timing cache: " + timing_cache_path); +// } +// saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); +// if (detailed_build_log_) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; +// } +// } +// // dump EP context node model +// if (dump_ep_context_model_) { +// // "ep_cache_context" node attribute should be a relative path to context model directory +// if (ep_cache_context_attr_.empty()) { +// auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); +// ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); +// } +// std::string compute_capability_hw_compat = compute_capability_; +// if (engine_cache_enable_ && engine_hw_compatible_) { +// compute_capability_hw_compat = "80+"; +// } +// std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, +// ep_cache_context_attr_, +// reinterpret_cast(serialized_engine->data()), +// serialized_engine->size(), +// ep_context_embed_mode_, +// compute_capability_hw_compat, +// model_path_, +// GetLogger())}; +// DumpCtxModel(model_proto.get(), ctx_model_path_); +// } +// } +// } +// +// if (weight_stripped_engine_refit_) { +// auto status = RefitEngine(model_path_, +// onnx_model_folder_path_, +// engine_cache_path, +// false /* path check for security */, +// trt_engine.get(), +// true /* serialize refitted engine to disk */, +// detailed_build_log_); +// if (status != Status::OK()) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); +// } +// } +// +// // Build context +// // Note: Creating an execution context from an engine is thread safe per TRT doc +// // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +// if (context_memory_sharing_enable_) { +//#if defined(_MSC_VER) +//#pragma warning(push) +//#pragma warning(disable : 4996) +//#endif +// size_t mem_size = trt_engine->getDeviceMemorySize(); +//#if defined(_MSC_VER) +//#pragma warning(pop) +//#endif +// if (mem_size > max_ctx_mem_size_) { +// max_ctx_mem_size_ = mem_size; +// } +//#if NV_TENSORRT_MAJOR < 10 +// trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +//#else +// trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +//#endif +// } else { +// trt_context = std::unique_ptr(trt_engine->createExecutionContext()); +// } +// if (!trt_context) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not build execution context for fused node: " + fused_node.Name()); +// } +// } +// +// // Create input to index map +// for (int i = 0; i < num_inputs; ++i) { +// auto input = trt_network->getInput(i); +// const std::string& input_name = input->getName(); +// const auto& iter = input_map.find(input_name); +// if (iter != 
input_map.end()) { +// input_indexes[input_name] = iter->second; +// } +// } +// +// // Create output to index and type maps +// const auto& graph_output = model_proto->graph().output(); +// for (int i = 0; i < num_outputs; ++i) { +// const std::string& output_name = trt_network->getOutput(i)->getName(); +// const auto& iter = output_map.find(output_name); +// if (iter != output_map.end()) { +// output_indexes[output_name] = iter->second; +// } +// const auto& tensor_type = graph_output[i].type().tensor_type(); +// output_types[output_name] = tensor_type.elem_type(); +// } +// +// // Save TRT engine, other TRT objects and input/output info to map +// parsers_.emplace(fused_node.Name(), std::move(trt_parser)); +// engines_.emplace(fused_node.Name(), std::move(trt_engine)); +// contexts_.emplace(fused_node.Name(), std::move(trt_context)); +// networks_.emplace(fused_node.Name(), std::move(trt_network)); +// input_info_[fused_node.Name()].push_back(input_indexes); +// output_info_[fused_node.Name()].push_back(output_indexes); +// output_info_[fused_node.Name()].push_back(output_types); +// input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; +// profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); +// +// // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty engine. +// // TRT EP will serialize the model at inference time due to engine can be updated and the updated engine should be included in the model. +// // However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here. +// if (dump_ep_context_model_ && has_dynamic_shape) { +// // "ep_cache_context" node attribute should be a relative path to context model directory +// if (ep_cache_context_attr_.empty()) { +// auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); +// ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); +// } +// std::string compute_capability_hw_compat = compute_capability_; +// if (engine_cache_enable_ && engine_hw_compatible_) { +// compute_capability_hw_compat = "80+"; +// } +// model_proto_.reset(CreateCtxModel(graph_body_viewer, +// ep_cache_context_attr_, +// nullptr, +// 0, +// ep_context_embed_mode_, +// compute_capability_hw_compat, +// model_path_, +// GetLogger())); +// if (ep_context_embed_mode_ == 0) { +// DumpCtxModel(model_proto_.get(), ctx_model_path_); +// } +// } +// +// // Create function state +// // TODO: remove default capture +// NodeComputeInfo compute_info; +// compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { +// std::unique_ptr p = std::make_unique(); +// // translate tactic sources string to nvinfer1::TacticSources +// nvinfer1::TacticSources tactics = 0; +// if (!tactic_sources_.empty()) { +// tactics = GetTacticSourceFromString(tactic_sources_); +// } +// *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), +// &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], +// &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], +// input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, +// dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, +// engine_cache_enable_, cache_path_, runtime_.get(), 
profiles_[context->node_name], +// context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, +// engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, +// detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, +// auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix, engine_hw_compatible_}; +// *state = p.release(); +// return 0; +// }; +// +// // Release function state +// compute_info.release_state_func = [](FunctionState state) { +// delete static_cast(state); +// }; +// +// // Create compute function +// compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { +// Ort::KernelContext ctx(context); +// +// TensorrtFuncState* trt_state = reinterpret_cast(state); +// +// // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, +// // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. +// // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +// std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); +// const std::unordered_map& input_indexes = (trt_state->input_info)[0]; +// const std::unordered_map& output_indexes = (trt_state->output_info)[0]; +// const std::unordered_map& output_types = (trt_state->output_info)[1]; +// auto fused_node_name = trt_state->fused_node_name; +// // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. 
+//      // The info is used for both shape tensor and execution tensor:
+//      // tensor name->(dimension->[min, max, opt])
+//      auto& shape_ranges = trt_state->input_shape_ranges;
+//      std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
+//      std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
+//      auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
+//      auto trt_builder = trt_state->builder;
+//      auto trt_engine = trt_state->engine->get();
+//      auto trt_context = trt_state->context->get();
+//      auto trt_profiles = trt_state->profiles;
+//      auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
+//      int num_inputs = static_cast<int>(input_indexes.size());
+//      int num_outputs = static_cast<int>(output_indexes.size());
+//      bool engine_update = false;
+//      bool context_update = false;
+//      std::unordered_set<std::string> input_names;
+//
+//      OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
+//      OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
+//      if (alloc_ == nullptr) {
+//        Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
+//      }
+//      OrtAllocator* alloc = alloc_;
+//
+//      void* cuda_stream;
+//      Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
 //      cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
 //
+//      // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+//      // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+//      // Prepare cache name
+//      std::string cache_path = "";
+//      // Customize cache prefix if assigned
+//      if (!cache_prefix_.empty()) {
+//        cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix;
+//      } else {
+//        cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
+//      }
+//
+//      // Enable hardware compatibility mode if assigned
+//      std::string cache_hw_compat = "_sm" + compute_capability_;
+//      if (engine_cache_enable_ && engine_hw_compatible_) {
+//        cache_hw_compat = "_sm80+";
+//        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache.";
+//      }
+//
+//      // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+//      // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+//      const std::string cache_path_prefix = cache_path + cache_hw_compat;
+//      std::string engine_cache_path = cache_path_prefix + ".engine";
+//      const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+//      const std::string profile_cache_path = cache_path_prefix + ".profile";
+//      std::string timing_cache_path = "";
+//      if (timing_cache_enable_) {
+//        timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
+//      }
+//
+//      // If weight-stripped engine is enabled and refitted engine cache is not present,
+//      // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
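The shape_ranges map that DeserializeProfileV2()/SerializeProfileV2() move to and from the ".profile" cache follows the "tensor name -> (dimension -> [min, max, opt])" layout described above. A sketch of the nesting, with illustrative type aliases (this patch may spell the types differently):

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // One [min, max, opt] triple per optimization profile.
    using ProfileShapeValues = std::vector<std::vector<int64_t>>;
    // input tensor name -> dimension index (or value index for shape tensors) -> triples
    using ShapeRangesMap =
        std::unordered_map<std::string, std::unordered_map<size_t, ProfileShapeValues>>;

    // Example entry: input "images", dim 0 profiled as min=1, max=32, opt=8.
    inline ShapeRangesMap MakeExampleShapeRanges() {
      ShapeRangesMap shape_ranges;
      shape_ranges["images"][0] = {{1, 32, 8}};  // stored as [min, max, opt]
      return shape_ranges;
    }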
+// const std::filesystem::path engine_cache_fs_path = engine_cache_path; +// if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { +// engine_cache_path = cache_path_prefix + ".stripped.engine"; +// weight_stripped_engine_refit_ = true; +// } +// +// // Load serialized engine +// if (trt_state->engine_cache_enable && trt_engine == nullptr) { +// std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); +// std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); +// if (engine_file && !trt_state->engine_decryption_enable && profile_file) { +// // Deserialize profile +// shape_ranges = DeserializeProfileV2(profile_file); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; +// +// // Prepare buffer +// engine_file.seekg(0, std::ios::end); +// size_t engine_size = engine_file.tellg(); +// engine_file.seekg(0, std::ios::beg); +// std::unique_ptr engine_buf{new char[engine_size]}; +// engine_file.read((char*)engine_buf.get(), engine_size); +// +// // Deserialize engine +// // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc +// // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +// trt_state->engine->reset(); +// *(trt_state->engine) = std::unique_ptr( +// trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); +// if (!(*(trt_state->engine))) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); +// } +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; +// trt_engine = trt_state->engine->get(); +// context_update = true; +// +// } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { +// shape_ranges = DeserializeProfileV2(profile_file); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; +// // Decrypt engine +// size_t engine_size = 0; +// if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not get engine buffer size"); +// } +// std::unique_ptr engine_buf{new char[engine_size]}; +// if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not call engine decryption function decrypt"); +// } +// // Deserialize engine +// // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc +// // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +// trt_state->engine->reset(); +// *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); +// if (!(*(trt_state->engine))) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); +// } +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; +// trt_engine = trt_state->engine->get(); +// context_update = true; +// } +// } +// +// // Check and update shape ranges for dynamic shape inputs. 
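ApplyProfileShapesFromInputTensorValue(), called in the loop below, is not shown in this hunk; conceptually it widens the recorded [min, max] of each profiled dimension to cover the shape seen at run time and flags an engine rebuild when the covered range grew. A hedged sketch of that update rule (the function name and the treatment of opt are assumptions, not this patch's actual code):

    #include <cstdint>
    #include <vector>

    // Widen one dimension's recorded [min, max, opt] to cover the runtime value;
    // request an engine rebuild if the covered range changed.
    static void WidenProfileRange(std::vector<int64_t>& min_max_opt,  // [min, max, opt]
                                  int64_t runtime_dim,
                                  bool& engine_update) {
      if (runtime_dim < min_max_opt[0]) { min_max_opt[0] = runtime_dim; engine_update = true; }
      if (runtime_dim > min_max_opt[1]) { min_max_opt[1] = runtime_dim; engine_update = true; }
      min_max_opt[2] = runtime_dim;  // opt often tracks the latest runtime shape; implementations differ
    }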
+//      for (int i = 0, end = num_inputs; i < end; ++i) {
+//        auto input = trt_state->network->get()->getInput(i);
+//        const std::string& input_name = input->getName();
+//        input_names.insert(input_name);
+//
+//        // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
+//        // TRT EP will help determine the min/max/opt profile values based on current input tensor value.
+//        if (shape_ranges.find(input_name) != shape_ranges.end()) {
+//          auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update);
+//          if (status != Status::OK()) {
+//            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
+//          }
+//        }
+//      }
+//
+//      // Regenerate engine
+//      if (engine_update) {
+//        // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
+//        trt_state->context->reset();
+//        trt_state->engine->reset();
+//        auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
+//        trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr));
+//        for (auto trt_profile : trt_profiles) {
+//          trt_config->addOptimizationProfile(trt_profile);
+//        }
+//
+//        // Set INT8 Per Tensor Dynamic range
+//        if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
+//#if defined(_MSC_VER)
+//#pragma warning(push)
+//#pragma warning(disable : 4996)
+//#endif
+//          trt_config->setInt8Calibrator(nullptr);
+//#if defined(_MSC_VER)
+//#pragma warning(pop)
+//#endif
+//          if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
+//            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
+//          }
+//        }
+//
+//        // Set precision
+//        if (trt_state->fp16_enable && trt_state->int8_enable) {
+//          trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
+//        } else if (trt_state->fp16_enable) {
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+//        } else if (trt_state->int8_enable) {
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
+//        }
+//
+//        // Set DLA (DLA can only run with FP16 or INT8)
+//        if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
+//          trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
+//          trt_config->setDLACore(trt_state->dla_core);
+//        }
+//
+//        // enable sparse weights
+//        if (trt_state->sparsity_enable) {
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
+//        }
+//#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
+//        // enable builder heuristics
+//        if (trt_state->build_heuristics_enable) {
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
+//        }
+//#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+//        // switch optimization level
+//        if (trt_state->builder_optimization_level != 3) {
+//          trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+//        }
+//
+//        // limit auxiliary streams
+//        if (trt_state->auxiliary_streams >= 0) {
+//          trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are set to " << trt_state->auxiliary_streams;
+//        }
+//#else
+//        if (trt_state->builder_optimization_level != 3) {
+//          LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+//        }
+//        if (trt_state->auxiliary_streams >= 0) {
+//          LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+//        }
+//#endif
+//        if (weight_stripped_engine_enable_) {
+//#if NV_TENSORRT_MAJOR >= 10
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+//          trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+//#else
+//          LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+//#endif
+//        }
+//        // limit used tactic sources
+//        if (trt_state->filter_tactic_sources) {
+//          nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+//          tactics |= trt_state->tactic_sources;
+//          trt_config->setTacticSources(tactics);
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
+//        }
+//
+//        // Load timing cache from file. Create a fresh cache if the file doesn't exist
+//        std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
+//        if (trt_state->timing_cache_enable) {
+//          std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
+//          timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
+//          if (timing_cache == nullptr) {
+//            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+//                                   "TensorRT EP could not create timing cache: " + timing_cache_path);
+//          }
+//          trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
+//          if (detailed_build_log_) {
+//            LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
+//          }
+//        }
+//
+//        // Enable hardware compatibility mode if assigned
+//        if (trt_state->engine_hw_compatible) {
+//          trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
+//          LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled.";
+//        }
+//
+//        // Build engine
+//        std::unique_ptr<nvinfer1::IHostMemory> serialized_engine;
+//        {
+//          auto lock = GetApiLock();
+//          std::chrono::steady_clock::time_point engine_build_start;
+//          if (detailed_build_log_) {
+//            engine_build_start = std::chrono::steady_clock::now();
+//          }
+//          serialized_engine = std::unique_ptr<nvinfer1::IHostMemory>(
+//              trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config));
+//          if (!serialized_engine) {
+//            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network.");
+//          }
+//          *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
+//              trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
+//          if (!(*(trt_state->engine))) {
+//            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine.");
+//          }
+//          if (detailed_build_log_) {
+//            auto engine_build_stop =
std::chrono::steady_clock::now(); +// LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; +// } +// } +// if (!(*(trt_state->engine))) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); +// } +// trt_engine = trt_state->engine->get(); +// if (trt_state->engine_cache_enable) { +// // Serialize engine profile +// SerializeProfileV2(profile_cache_path, shape_ranges); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; +// +// // Serialize engine +// if (trt_state->engine_decryption_enable) { +// // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. +// if (trt_state->engine_encryption != nullptr) { +// if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not call engine encryption function encrypt"); +// } +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; +// } else { +// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; +// } +// } else { +// std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); +// file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; +// } +// } +// +// // serialize and save timing cache +// if (trt_state->timing_cache_enable) { +// auto timing_cache = trt_config->getTimingCache(); +// std::unique_ptr timingCacheHostData{timing_cache->serialize()}; +// if (timingCacheHostData == nullptr) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, +// "TensorRT EP could not serialize timing cache: " + timing_cache_path); +// } +// saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); +// if (detailed_build_log_) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; +// } +// } +// +// // dump ep context model +// if (dump_ep_context_model_ && ep_context_embed_mode_) { +// UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); +// DumpCtxModel(model_proto_.get(), ctx_model_path_); +// } +// context_update = true; +// +// if (weight_stripped_engine_refit_) { +// auto status = RefitEngine(model_path_, +// onnx_model_folder_path_, +// engine_cache_path, +// false /* path check for security */, +// trt_engine, +// true /* serialize refitted engine to disk */, +// detailed_build_log_); +// if (status != Status::OK()) { +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); +// } +// } +// } +// +// if (context_update) { +// if (trt_state->context_memory_sharing_enable) { +//#if NV_TENSORRT_MAJOR < 10 +// *(trt_state->context) = std::unique_ptr( +// trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +//#else +// *(trt_state->context) = std::unique_ptr( +// trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +//#endif +// } else { +// *(trt_state->context) = std::unique_ptr( +// trt_state->engine->get()->createExecutionContext()); +// } +// if (!(*(trt_state->context))) { +// return 
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); +// } +// trt_context = trt_state->context->get(); +// } +// // // Get input and output binding names // int total_bindings = trt_engine->getNbIOTensors(); // std::vector input_binding_names, output_binding_names; @@ -343,10 +1694,13 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // if (iter != input_indexes.end()) { // input_index = iter->second; // } +// auto input_tensor = ctx.GetInput(input_index); +// auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); +// const auto tensor_shapes = tensor_info.GetShape(); // -// Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); +// auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); // if (status != Status::OK()) { -// return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, status.ErrorMessage()); +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); // } // } // @@ -379,7 +1733,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, // dds_output_allocator_map, scratch_buffers, alloc, buffers); // if (status != Status::OK()) { -// return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, status.ErrorMessage()); +// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); // } // } // @@ -410,7 +1764,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // // // Run TRT inference // if (!trt_context->enqueueV3(stream)) { -// return api_->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed."); +// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); // } // // /* @@ -451,7 +1805,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // } // auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); // if (status != Status::OK()) { -// return api_->CreateStatus(OrtErrorCode::ORT_FAIL, status.ErrorMessage()); +// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); // } // } else { // auto& output_tensor = output_tensors[i]; @@ -486,6 +1840,321 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // IncrementRegularRunCountBeforeGraphCapture(); // } // } +// +// return nullptr; +// }; + + return nullptr; +} + +OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_funcs) { + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index + std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index + std::unordered_map output_types; // TRT engine output name -> ORT output tensor type + + // Get engine binary data and deserialize it + auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, + runtime_.get(), + model_path_, + compute_capability_, + 
weight_stripped_engine_enable_, + onnx_model_folder_path_, + detailed_build_log_); + auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + + // Build context + // + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + + const char* fused_node_name = nullptr; + api_->OrtNode_GetName(fused_node, &fused_node_name); + if (!trt_context) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, + std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); + } + + // Create input/output to index maps + for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + const auto& iter = input_map.find(name); + if (iter != input_map.end()) { + input_indexes[name] = iter->second; + } + } else { + const auto& iter = output_map.find(name); + if (iter != output_map.end()) { + output_indexes[name] = iter->second; + } + } + } + + // Create output to type map + size_t graph_output_size = api_->OrtGraph_GetOutputSize(graph_body_viewer); + for (size_t i = 0; i < graph_output_size; i++) { + output_types[api_->OrtGraph_GetIthOutputName(graph_body_viewer, i)] = api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); + } + + // Save TRT engine, TRT context and input/output info to map + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + + // Create function state + (*node_compute_funcs)->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + std::unique_ptr p = std::make_unique(); + *p = { context->AllocateFunc, + context->DestroyFunc, + context->allocator_handle, + context->node_name, + &(this_->engines_[context->node_name]), + &(this_->contexts_[context->node_name]), + this_->input_info_[context->node_name], + this_->output_info_[context->node_name], + this_->context_memory_sharing_enable_, + &this_->max_ctx_mem_size_}; + *state = p.release(); + return 0; + }; + + // Release function state + (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) { + delete reinterpret_cast(state); + }; + + // Create compute function + (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* 
context) -> OrtStatusPtr { + TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); + TensorrtShortFuncState* trt_state = reinterpret_cast(state); + Ort::KernelContext ctx(context); + + // The whole compute_function should be considered the critical section. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading +//TODO(leca): std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + int num_outputs = static_cast(output_indexes.size()); + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run + std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + + OrtMemoryInfo* mem_info = nullptr; + api->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); + if (this_->alloc_ == nullptr) { + Ort::ThrowOnError(api->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); + } + OrtAllocator* alloc = this_->alloc_; + + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + + OrtStatusPtr status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto 
type_iter = output_types.find(output_name);
+      if (type_iter != output_types.end()) {
+        output_type = type_iter->second;
+      }
+
+      OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+                                              dds_output_allocator_map, scratch_buffers, alloc, buffers);
+      if (status != nullptr) {
+        return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status));
+      }
+    }
+
+    // Set execution context memory
+    if (trt_state->context_memory_sharing_enable) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+      size_t mem_size = trt_engine->getDeviceMemorySize();
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+      if (mem_size > *max_context_mem_size_ptr) {
+        *max_context_mem_size_ptr = mem_size;
+      }
+      trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get());
+    }
+
+    // Start CUDA graph capture.
+    // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
+    // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
+    if (this_->cuda_graph_enable_ && this_->IsGraphCaptureAllowed() && !this_->IsGraphCaptured(0)) {
+      //LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+//      cuda_graph_.SetStream(stream);
+//      CaptureBegin(0);
+    }
+
+    // Run TRT inference
+    if (!trt_context->enqueueV3(stream)) {
+      return api->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed.");
+    }
+
+    /*
+     * Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call this function concurrently,
+     * TRT EP needs to carefully take care of concurrency here; if not, the following concurrency issue might happen:
+     *
+     * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream.
+     * In the design of TRT EP (not applying a per-thread context implementation), if multiple threads are calling InferenceSession::Run() concurrently,
+     * the trt execution context instance is shared by all the threads and each thread acquires a different stream from ORT.
+     * So TRT EP will end up having one trt execution context using multiple streams, which is not suggested.
+     * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream
+     * is guaranteed.
+     *
+     * Therefore, TRT EP needs to call cudaStreamSynchronize(), which means waiting until the stream has completed all operations, to prevent the concurrent issue mentioned above.
+     * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
+     */
+    if (this_->sync_stream_after_enqueue_) {
+      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+    }
+
+    // Assign TRT output back to ORT output
+    // (1) Bind TRT DDS output to ORT kernel context output.
(It needs to wait until enqueueV3 is finished)
+    // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
+    for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+      char const* output_name = output_binding_names[i];
+
+      size_t output_type = 0;
+      const auto& iter = output_types.find(output_name);
+      if (iter != output_types.end()) {
+        output_type = iter->second;
+      }
+
+      if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) {
+        size_t output_index = 0;
+        const auto& index_iter = output_indexes.find(output_name);
+        if (index_iter != output_indexes.end()) {
+          output_index = index_iter->second;
+        }
+        OrtStatusPtr status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);
+        if (status != nullptr) {
+          return api->CreateStatus(OrtErrorCode::ORT_FAIL, api->GetErrorMessage(status));
+        }
+      } else {
+        auto& output_tensor = output_tensors[i];
+#if NV_TENSORRT_MAJOR < 10
+        if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+          auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
+          if (output_tensor_ptr != nullptr) {
+            cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
+          }
+        }
+#endif
+        if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+//          auto output_tensor_ptr = output_tensor.GetTensorMutableData<float>();
+//          if (output_tensor_ptr != nullptr) {
+//            cuda::Impl_Cast<double, float>(stream, reinterpret_cast<double*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
+//          }
+        }
+      }
+    }
+
+    // End CUDA graph capture.
+    // Note: One reason we don't put the end of graph capture in OnRunEnd() like CUDA EP does is because of the cuda stream mentioned in graph capture
+    // above; another reason is that OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
+    // It's safe to start/end CUDA graph capture in compute_func() here since the cuda graph object is maintained on a per-thread basis.
+    if (this_->cuda_graph_enable_ && !this_->IsGraphCaptured(0)) {
+//      if (IsGraphCaptureAllowed()) {
+//        CaptureEnd(0);
+//        // CUDA work issued to a capturing stream doesn't actually run on the GPU,
+//        // so run the captured graph here to actually execute the work.
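+//        // For reference, a sketch of the full capture path across this compute function once
+//        // the disabled lines are re-enabled. It assumes cuda_graph_ is a CUDAGraph-style helper
+//        // wrapping cudaStreamBeginCapture/cudaStreamEndCapture and that ReplayGraph() ends in
+//        // cudaGraphLaunch; the helper names follow the CUDA EP, not code present in this patch:
+//        //   cuda_graph_.SetStream(stream);   // CaptureBegin(0): cudaStreamBeginCapture
+//        //   trt_context->enqueueV3(stream);  // work is recorded into the graph, not executed
+//        //   CaptureEnd(0);                   // cudaStreamEndCapture + cudaGraphInstantiate
+//        //   ReplayGraph(0);                  // cudaGraphLaunch actually executes the work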
+
+    // End CUDA graph capture.
+    // Note: One reason we don't put the end of graph capture in OnRunEnd() like CUDA EP does is the cuda stream mentioned in the graph
+    // capture note above; another reason is that OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
+    // It's safe to start/end CUDA graph capture in compute_func() here since the cuda graph object is maintained on a per-thread basis.
+    if (this_->cuda_graph_enable_ && !this_->IsGraphCaptured(0)) {
+//      if (IsGraphCaptureAllowed()) {
+//        CaptureEnd(0);
+//        // CUDA work issued to a capturing stream doesn't actually run on the GPU,
+//        // so run the captured graph here to actually execute the work.
+//        ORT_RETURN_IF_ERROR(ReplayGraph(0));
+//      } else {
+//        IncrementRegularRunCountBeforeGraphCapture();
+//      }
+    }

    return nullptr;
  };
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h
index 70465bcdec876..8b4ca2ff61ed3 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.h
+++ b/samples/tensorRTEp/tensorrt_execution_provider.h
@@ -170,10 +170,7 @@ std::string GetWeightRefittedEnginePath(std::string engine_cache_path);
 struct TensorrtExecutionProvider : public OrtExecutionProvider {
     TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& provider_options);
-    OrtStatusPtr CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
-                                                            std::unordered_map<std::string, size_t>& input_map,
-                                                            std::unordered_map<std::string, size_t>& output_map,
-                                                            OrtNodeComputeInfo** node_compute_funcs);
+    bool IsGraphCaptured(int graph_annotation_id) const { return false; }
     static OrtStatusPtr RefitEngine(std::string onnx_model_filename,
                                     std::string& onnx_model_folder_path,
                                     std::string& weight_stripped_engine_cath_path,
@@ -181,8 +178,8 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider {
                                     nvinfer1::ICudaEngine* trt_engine,
                                     bool serialize_refitted_engine,
                                     bool detailed_build_log);
-private:
     static const OrtApi* api_;
+private:
     // mutable TensorrtExecutionProviderInfo info_;
     bool external_stream_ = false;
     cudaStream_t stream_ = nullptr;
@@ -278,6 +275,21 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider {
     // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
     // to allocate enough memory in Arena before graph capturing.
     const int min_num_runs_before_cuda_graph_capture_ = 1;  // required min regular runs before graph capture for the necessary memory allocations.
+
+    OrtStatusPtr CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node,
+                                                            std::unordered_map<std::string, size_t>& input_map,
+                                                            std::unordered_map<std::string, size_t>& output_map,
+                                                            OrtNodeComputeInfo** node_compute_funcs);
+
+    OrtStatusPtr CreateNodeComputeInfoFromGraph(const OrtGraphViewer* graph_body_viewer,
+                                                const OrtNode* fused_node,
+                                                std::unordered_map<std::string, size_t>& input_map,
+                                                std::unordered_map<std::string, size_t>& output_map,
+                                                OrtNodeComputeInfo** node_compute_funcs);
+
+    bool IsGraphCaptureAllowed() const { return false; };
+
+    nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
 };

 struct TensorrtExecutionProviderFactory : public OrtExecutionProviderFactory {
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h
new file mode 100644
index 0000000000000..97b9ffd91961c
--- /dev/null
+++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h
@@ -0,0 +1,93 @@
+#pragma once
+#include <filesystem>
+#include <fstream>
+#include <string>
+#include <unordered_map>
+#include "flatbuffers/idl.h"
+#include "ort_trt_int8_cal_table.fbs.h"
+
+namespace fs = std::filesystem;
+
+namespace onnxruntime {
+
+float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) {
+  int s = (input >> 31) & 0x01;
+  int e = ((input & 0x7f800000) >> 23) - 127;
+  int p = -1;
+  double m = 0.0;
+  for (int i = 0; i < 23; ++i) {
+    m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--);
+  }
+  return static_cast<float>((s ? -1 : 1) * pow(2.0, e) * (m + 1.0));
+}
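+
+// Worked example: input = 0x3f800000 gives s = 0 and
+// e = ((0x3f800000 & 0x7f800000) >> 23) - 127 = 127 - 127 = 0; every mantissa bit is
+// zero, so m = 0.0 and the function returns (+1) * 2^0 * (1.0 + 0.0) = 1.0f, which is
+// exactly the float whose bit pattern is 0x3f800000.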
+
+bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration_table, std::unordered_map<std::string, float>& dynamic_range_map) {
+  std::ifstream infile(file_name, std::ios::binary | std::ios::in);
+  if (!infile) {
+    return false;
+  }
+
+  if (is_trt_calibration_table) {
+    // Native TensorRT generated calibration table
+    std::string line;
+    char delim = ':';
+    if (std::getline(infile, line)) {
+      std::istringstream first_line(line);
+      std::string version;
+      std::getline(first_line, version, delim);
+      std::size_t found = version.find("TRT-");
+      if (found != std::string::npos) {
+        while (std::getline(infile, line)) {
+          std::istringstream in_line(line);
+          std::string str;
+          std::getline(in_line, str, delim);
+          std::string tensor_name = str;
+          std::getline(in_line, str, delim);
+          unsigned long scale_int = std::strtoul(str.c_str(), nullptr, 16);
+          float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int);
+          float dynamic_range = scale_float * 127.0f;
+          dynamic_range_map[tensor_name] = dynamic_range;
+        }
+      } else {
+        throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name);
+      }
+    }
+  } else {
+    // ORT generated calibration table
+    infile.seekg(0, std::ios::end);
+    size_t length = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    std::unique_ptr<char[]> data{new char[length]};
+    infile.read((char*)data.get(), length);
+    infile.close();
+    auto flat_table = flatbuffers::GetRoot<CalTableFlatBuffers::TrtTable>((const uint8_t*)data.get());
+    auto flat_dict = flat_table->dict();
+    for (size_t i = 0, end = flat_dict->size(); i < end; ++i) {
+      flatbuffers::uoffset_t idx = static_cast<flatbuffers::uoffset_t>(i);
+      dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str());
+    }
+  }
+  return true;
+}
+
+int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
+  int num_profile = 0;
+  for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) {
+    num_profile = static_cast<int>(it->second.size());
+    if (num_profile > 0) {
+      break;
+    }
+  }
+  return num_profile;
+}
+
+std::string GetCachePath(const std::string& root, const std::string& name) {
+  if (root.empty()) {
+    return name;
+  } else {
+    fs::path path = root;
+    path.append(name);
+    return path.string();
+  }
+}
+}
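+
+// Format notes for ReadDynamicRange() above (illustrative, hypothetical tensor names):
+// a native TensorRT calibration table is plain text, a version header followed by one
+// "name: hex_scale" pair per line, e.g.
+//
+//   TRT-8601-EntropyCalibration2
+//   input_1: 3c010204
+//   conv1_out: 3d8e5604
+//
+// Each hex value is the IEEE754 bit pattern of the float calibration scale, and the
+// dynamic range stored in the map is scale * 127.0f.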
From 85c168d437be1f9a9d5cdb0c7e97ff4d915fe116 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Sat, 24 Aug 2024 00:17:09 +0000
Subject: [PATCH 21/81] finish Compile function

---
 samples/tensorRTEp/onnx_ctx_model_helper.cc   |   21 +
 samples/tensorRTEp/onnx_ctx_model_helper.h    |    2 +
 .../tensorRTEp/tensorrt_execution_provider.cc | 2121 ++++++++++-------
 .../tensorRTEp/tensorrt_execution_provider.h  |    3 +
 .../tensorrt_execution_provider_utils.h       |  211 ++
 5 files changed, 1465 insertions(+), 893 deletions(-)

diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc
index fe022c6c0e85f..9b8e16b0eb549 100644
--- a/samples/tensorRTEp/onnx_ctx_model_helper.cc
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc
@@ -36,6 +36,27 @@ std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_contex
   }
 }

+std::string GetCtxModelPath(const std::string& ep_context_file_path,
+                            const std::string& original_model_path) {
+  std::string ctx_model_path;
+
+  if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) {
+    ctx_model_path = ep_context_file_path;
+  } else {
+    std::filesystem::path model_path = original_model_path;
+    std::filesystem::path model_name_stem = model_path.stem();  // model_name.onnx -> model_name
+    std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx";
+
+    if (std::filesystem::is_directory(ep_context_file_path)) {
+      std::filesystem::path model_directory = ep_context_file_path;
+      ctx_model_path = model_directory.append(ctx_model_name).string();
+    } else {
+      ctx_model_path = ctx_model_name;
+    }
+  }
+  return ctx_model_path;
+}
+
 bool IsAbsolutePath(const std::string& path_string) {
 #ifdef _WIN32
   onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h
index a7604bcbd5839..3fcb809b4bded 100644
--- a/samples/tensorRTEp/onnx_ctx_model_helper.h
+++ b/samples/tensorRTEp/onnx_ctx_model_helper.h
@@ -22,6 +22,8 @@ static const std::string EPCONTEXT_WARNING =
 bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer);
 std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
+std::string GetCtxModelPath(const std::string& ep_context_file_path,
+                            const std::string& original_model_path);
 bool IsAbsolutePath(const std::string& path_string);
 bool IsRelativePathToParentPath(const std::string& path_string);
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index c9259f5fa5d86..c00cbca0fef7b 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -45,6 +45,191 @@ IAllocatorUniquePtr<T> MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator
   }};
 }

+bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<std::string, float>& dynamic_range_map) {
+  // Set dynamic range for input tensors
+  for (int i = 0; i < network.getNbInputs(); ++i) {
+    const std::string tensor_name = network.getInput(i)->getName();
+    auto dynamic_range_iter = dynamic_range_map.find(tensor_name);
+    if (dynamic_range_iter != dynamic_range_map.end()) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+      if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) {
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+        //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name;
+        return false;
+      }
+    }
+  }
+
+  // Set dynamic range for activations and weights
+  for (int i = 0; i < network.getNbLayers(); ++i) {
+    auto trt_layer = network.getLayer(i);
+    for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) {
+      const std::string tensor_name = trt_layer->getOutput(j)->getName();
+      auto dynamic_range_iter = dynamic_range_map.find(tensor_name);
+      if (dynamic_range_iter != dynamic_range_map.end()) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+        if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) {
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+          //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name;
+          return false;
+        }
+      } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) {
+        nvinfer1::IConstantLayer* const_layer = static_cast<nvinfer1::IConstantLayer*>(trt_layer);
+        const std::string const_layer_name = const_layer->getName();
+        auto trt_weights = const_layer->getWeights();
+        double max_weight = std::numeric_limits<double>::min();
+        for (int64_t k = 0, end = trt_weights.count; k < end; ++k) {
+          double weight{};
+          switch (trt_weights.type) {
+            case nvinfer1::DataType::kFLOAT:
+              weight = static_cast<const float*>(trt_weights.values)[k];
+              break;
+            case nvinfer1::DataType::kBOOL:
+              weight = static_cast<const bool*>(trt_weights.values)[k];
+              break;
+            case nvinfer1::DataType::kINT8:
+              weight = static_cast<const int8_t*>(trt_weights.values)[k];
+              break;
+            case nvinfer1::DataType::kHALF:
+              weight = static_cast<const uint16_t*>(trt_weights.values)[k];
+              break;
+            case nvinfer1::DataType::kINT32:
+              weight = static_cast<const int32_t*>(trt_weights.values)[k];
+              break;
+#if NV_TENSORRT_MAJOR >= 10
+            case nvinfer1::DataType::kINT64:
+              weight = static_cast<double>(static_cast<const int64_t*>(trt_weights.values)[k]);
+              break;
+#endif  // NV_TENSORRT_MAJOR >= 10
+            default:
+              //LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name;
+              return false;
+          }
+          max_weight = std::max(max_weight, std::abs(weight));
+        }
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+        if (!trt_layer->getOutput(j)->setDynamicRange(static_cast<float>(-max_weight), static_cast<float>(max_weight))) {
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+          //LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name;
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+std::vector<std::string> SplitToStringVec(std::string const& s, char separator) {
+  std::vector<std::string> splitted;
+
+  for (size_t start = 0; start < s.length();) {
+    size_t separatorIndex = s.find(separator, start);
+    if (separatorIndex == std::string::npos) {
+      separatorIndex = s.length();
+    }
+    splitted.emplace_back(s.substr(start, separatorIndex - start));
+    start = separatorIndex + 1;
+  }
+
+  return splitted;
+}
+
+nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
+  nvinfer1::TacticSources disabledTactics = 0;
+  nvinfer1::TacticSources enabledTactics = 0;
+  std::vector<std::string> tacticList = SplitToStringVec(tactic_string, ',');
+  for (auto& t : tacticList) {
+    bool enable{false};
+    if (t.front() == '+') {
+      enable = true;
+    } else if (t.front() != '-') {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or -, skipping: " << t;
+    }
+    t.erase(0, 1);
+
+    const auto toUpper = [](std::string& sourceName) {
+      std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(),
+                     [](char c) { return static_cast<char>(std::toupper(c)); });
+      return sourceName;
+    };
+
+    nvinfer1::TacticSource source{};
+    t = toUpper(t);
+    if (t == "CUBLAS") {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0";
+#if NV_TENSORRT_MAJOR < 10
+      source = nvinfer1::TacticSource::kCUBLAS;
+#endif
+    } else if (t == "CUBLASLT" || t == "CUBLAS_LT") {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0";
+#if NV_TENSORRT_MAJOR < 9
+      source = nvinfer1::TacticSource::kCUBLAS_LT;
+#endif
+    } else if (t == "CUDNN") {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0";
+#if NV_TENSORRT_MAJOR < 10
+      source = nvinfer1::TacticSource::kCUDNN;
+#endif
+    } else if (t == "EDGE_MASK_CONVOLUTIONS") {
+      source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS;
+    } else if (t == "JIT_CONVOLUTIONS") {
+      source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS;
+    } else {
+      //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t;
+    }
+
+    uint32_t sourceBit = 1U << static_cast<uint32_t>(source);
+
+    if (enable) {
+      enabledTactics |= sourceBit;
+    } else {
+      disabledTactics |= sourceBit;
+    }
+  }
+  return enabledTactics & ~disabledTactics;
+}
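+
+// Usage example (illustrative): tactic_string = "-CUDNN,+EDGE_MASK_CONVOLUTIONS"
+// disables the cuDNN tactics and enables edge-mask convolutions. Because the function
+// returns enabledTactics & ~disabledTactics, an explicit '-' always wins over '+'.
+// The caller later in this file ORs the result into the builder config:
+//
+//   nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+//   tactics |= GetTacticSourceFromString(tactic_sources_);
+//   trt_config->setTacticSources(tactics);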
+
+inline std::vector<char> loadTimingCacheFile(const std::string inFileName) {
+  std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);
+  if (!iFile) {
+    //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName
+    //                      << ". A new timing cache will be generated and written.";
+    return std::vector<char>();
+  }
+  iFile.seekg(0, std::ifstream::end);
+  size_t fsize = iFile.tellg();
+  iFile.seekg(0, std::ifstream::beg);
+  std::vector<char> content(fsize);
+  iFile.read(content.data(), fsize);
+  iFile.close();
+  return content;
+}
+
+inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) {
+  std::ofstream oFile(outFileName, std::ios::out | std::ios::binary);
+  if (!oFile) {
+    //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName;
+    return;
+  }
+  oFile.write((char*)blob->data(), blob->size());
+  oFile.close();
+}
+
 TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
   const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
   static TensorrtLogger trt_logger(log_level);
@@ -171,6 +356,177 @@ bool ApplyProfileShapesFromProviderOptions(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
+OrtStatusPtr ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
+                                                    Ort::KernelContext ctx,
+                                                    nvinfer1::ITensor* input,
+                                                    ShapeRangesMap& shape_ranges,
+                                                    const std::unordered_map<std::string, size_t>& input_indexes,
+                                                    std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,
+                                                    std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,
+                                                    cudaStream_t stream,
+                                                    bool* engine_update) {
+  for (size_t i = 0; i < trt_profiles.size(); i++) {
+    const std::string& input_name = input->getName();
+    nvinfer1::Dims dims = input->getDimensions();
+    int nb_dims = dims.nbDims;
+
+    size_t input_index = 0;
+    const auto& iter = input_indexes.find(input_name);
+    if (iter != input_indexes.end()) {
+      input_index = iter->second;
+    }
+
+    auto input_tensor = ctx.GetInput(input_index);
+    auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+    const auto tensor_shapes = tensor_info.GetShape();
+    auto& shape_ranges_per_input = shape_ranges[input_name];
+
+    auto trt_profile = trt_profiles[i];
+
+    // If there are multiple profiles, for the second and remaining profiles, simply copy the min/max/opt profile values from the first profile.
+    // The following "if statement" won't be executed, since TRT EP currently only allows a single profile for the non-explicit profiles case.
+    if (i > 0) {
+      if (input->isShapeTensor()) {
+        // shape tensor
+        int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
+        std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
+        for (int j = 0; j < shape_size; j++) {
+          shapes_min[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN));
+          shapes_max[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX));
+          shapes_opt[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT));
+        }
+        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
+        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
+        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
+      } else {
+        // execution tensor
+        nvinfer1::Dims dims_min, dims_opt, dims_max;
+        dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN);
+        dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX);
+        dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT);
+        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
+        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
+        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
+      }
+      continue;
+    }
+
+    // Create shape profile
+    if (input->isShapeTensor()) {
+      // Get shape values for shape tensor input
+      const auto tensor_type = tensor_info.GetElementType();
+      // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
+      int shape_size = dims.nbDims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
+      // For setting TRT optimization profile. (Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10)
+      std::vector<int32_t> values(shape_size);
+
+      switch (tensor_type) {
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+          auto buffer = std::make_unique<int32_t[]>(shape_size);
+          GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream);
+          shape_tensor_values[input_name].resize(shape_size);
+          for (int j = 0; j < shape_size; ++j) {
+            shape_tensor_values[input_name][j] = buffer[j];
+            values[j] = buffer[j];
+          }
+          break;
+        }
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+          auto buffer = std::make_unique<int64_t[]>(shape_size);
+          GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream);
+          shape_tensor_values_int64[input_name].resize(shape_size);
+          for (int j = 0; j < shape_size; ++j) {
+            shape_tensor_values_int64[input_name][j] = buffer[j];
+            values[j] = static_cast<int32_t>(buffer[j]);
+          }
+          break;
+        }
+        default: {
+          return TensorrtExecutionProvider::api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str());
+        }
+      }
+
+      // Update shape ranges
+      std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
+      int shape_range_size = static_cast<int>(shape_ranges_per_input.size());
+      if (shape_size == shape_range_size) {
+        // If shape size matches, check/update shape range
+        for (int j = 0; j < shape_size; ++j) {
+          auto& shape_range = shape_ranges_per_input[j][0];  // only has one profile
+          shapes_min[j] = static_cast<int32_t>(shape_range[0]);
+          shapes_max[j] = static_cast<int32_t>(shape_range[1]);
+          shapes_opt[j] = static_cast<int32_t>(shape_range[2]);
+
+          const auto& tensor_shape_value = values[j];
+          // Update shape range lower bound
+          if (tensor_shape_value < shape_range[0]) {
+            shape_range[0] = tensor_shape_value;
+            shapes_min[j] = tensor_shape_value;
+            *engine_update = true;
+          }
+          // Update shape range upper bound
+          if (tensor_shape_value > shape_range[1]) {
+            shape_range[1] = tensor_shape_value;
+            shape_range[2] = tensor_shape_value;
+            shapes_max[j] = tensor_shape_value;
+            shapes_opt[j] = tensor_shape_value;
+            *engine_update = true;
+          }
+        }
+      } else {
+        // If shape size doesn't match, initialize shape_range with the new shape value
+        shape_ranges_per_input.clear();
+        for (int j = 0; j < shape_size; ++j) {
+          const auto& tensor_shape_value = values[j];
+          std::vector<std::vector<int64_t>> profile_vector;
+          std::vector<int64_t> shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value};
+          profile_vector.push_back(shape_vector);  // only one profile needed
+          shape_ranges_per_input[j] = profile_vector;
+          shapes_min[j] = tensor_shape_value;
+          shapes_opt[j] = tensor_shape_value;
+          shapes_max[j] = tensor_shape_value;
+        }
+        *engine_update = true;
+      }
+
+      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
+      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
+      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
+    } else {  // Execution tensor
+      nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
+      for (int j = 0, end = nb_dims; j < end; ++j) {
+        const auto& tensor_shape = tensor_shapes[j];
+        if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) {
+          auto& shape_range = shape_ranges_per_input[j][0];  // only has one profile
+          dims_min.d[j] = static_cast<int32_t>(shape_range[0]);
+          dims_max.d[j] = static_cast<int32_t>(shape_range[1]);
+          dims_opt.d[j] = static_cast<int32_t>(shape_range[2]);

+          // Update minimum dimension
+          if (tensor_shape < shape_range[0]) {
+            shape_range[0] = tensor_shape;
+            dims_min.d[j] = static_cast<int32_t>(tensor_shape);
+            *engine_update = true;
+          }
+          // Update maximum dimension
+          if (tensor_shape > shape_range[1]) {
+            shape_range[1] = tensor_shape;
+            shape_range[2] = tensor_shape;
+            dims_max.d[j] = static_cast<int32_t>(tensor_shape);
+            dims_opt.d[j] = static_cast<int32_t>(tensor_shape);
+            *engine_update = true;
+          }
+        }
+      }
+
+      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
+      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
+      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
+    }
+  }
+  return nullptr;
+}
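+
+// Worked example (illustrative): suppose a shape tensor input previously recorded
+// shape_range = [min, max, opt] = [2, 8, 8] and the current run supplies the value 16.
+// The update above widens the range to [2, 16, 16], mirrors the new bounds into
+// shapes_max/shapes_opt, and sets *engine_update = true so the caller rebuilds the
+// engine with the widened optimization profile. A value already inside [2, 8] leaves
+// both the profile and the engine untouched. Execution tensors follow the same rule,
+// per dimension, via dims_min/dims_opt/dims_max.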
+
 #define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT)                     \
   case DATA_TYPE: {                                                \
     auto input_tensor_ptr = input_tensor.GetTensorData<SrcT>();    \
@@ -845,10 +1201,9 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
   }

   // Load INT8 calibration table
-  std::unordered_map<std::string, float> dynamic_range_map;
   if (int8_enable_ && int8_calibration_cache_available_) {
     const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
-    if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
+    if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map_)) {
       throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
     }
   }
@@ -856,18 +1211,18 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
   // Set precision flags
   const char* node_name = nullptr;
   api_->OrtNode_GetName(fused_node, &node_name);
-  std::string trt_node_name_with_precision(node_name);
+  trt_node_name_with_precision_ = node_name;
   if (fp16_enable_ && int8_enable_) {
     trt_config->setFlags(1U << static_cast<int>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<int>(nvinfer1::BuilderFlag::kINT8));
-    trt_node_name_with_precision += "_fp16_int8";
+    trt_node_name_with_precision_ += "_fp16_int8";
     //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled";
   } else if (fp16_enable_) {
     trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
-    trt_node_name_with_precision += "_fp16";
+    trt_node_name_with_precision_ += "_fp16";
     //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled";
   } else if (int8_enable_) {
     trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
-    trt_node_name_with_precision += "_int8";
+    trt_node_name_with_precision_ += "_int8";
     //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled";
   }
@@ -887,7 +1242,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
       trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
       trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
       trt_config->setDLACore(dla_core_);
-      trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_);
+      trt_node_name_with_precision_ += "_dlacore" + std::to_string(dla_core_);
     }
   }
 }
@@ -914,246 +1269,235 @@
 }
 #endif

-//#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
-//  // switch optimizaion level
-//  if (builder_optimization_level_ != 3) {
-//    trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
-//    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
-//  }
-// -// // limit auxiliary streams -// if (auxiliary_streams_ >= 0) { -// trt_config->setMaxAuxStreams(auxiliary_streams_); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_; -// } -//#else -// if (builder_optimization_level_ != 3) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; -// } -// if (auxiliary_streams_ >= 0) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; -// } -//#endif -// -// if (weight_stripped_engine_enable_) { -//#if NV_TENSORRT_MAJOR >= 10 -// trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; -// trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; -//#else -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; -//#endif -// } -// -// // limit used tactic sources -// if (!tactic_sources_.empty()) { -// nvinfer1::TacticSources tactics = trt_config->getTacticSources(); -// tactics |= GetTacticSourceFromString(tactic_sources_); -// trt_config->setTacticSources(tactics); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_; -// } -// -// // Build TRT engine (if needed) and load TRT engine if: -// // (1) Graph has no dynamic shape input -// // (2) All the dynamic shape inputs have associated explicit profiles specified by user -// // -// // Otherwise engine will be handled at inference time. -// std::unique_ptr trt_engine; -// std::unique_ptr trt_context; -// -// std::string cache_path = ""; -// std::string cache_suffix = ""; -// // Customize cache prefix if assigned -// if (!cache_prefix_.empty()) { -// // Generate cache suffix in case user would like to customize cache prefix -// cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision); -// cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; -// } else { -// cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); -// } -// -// std::string cache_hw_compat = "_sm" + compute_capability_; -// // Enable hardware compatility mode if assigned -// if (engine_cache_enable_ && engine_hw_compatible_) { -// trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); -// cache_hw_compat = "_sm80+"; -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; -// } -// -// // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache -// // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity -// const std::string cache_path_prefix = cache_path + cache_hw_compat; -// std::string engine_cache_path = cache_path_prefix + ".engine"; -// const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; -// const std::string profile_cache_path = cache_path_prefix + ".profile"; -// -// // If weight-stripped engine is enabled and refitted engine cache is not present, -// // TRT EP will use the engine cache with ".stripped.engine" appended to the end. 
-// const std::filesystem::path engine_cache_fs_path = engine_cache_path; -// if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { -// engine_cache_path = cache_path_prefix + ".stripped.engine"; -// weight_stripped_engine_refit_ = true; -// } -// -// // Generate file name for dumping ep context model -// if (dump_ep_context_model_ && ctx_model_path_.empty()) { -// ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); -// } -// -// if (!has_dynamic_shape) { -// std::string timing_cache_path = ""; -// bool engine_update = false; -// if (timing_cache_enable_) { -// timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); -// } -// { -// // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading. -// auto lock = GetApiLock(); -// -// // If explicit profile flag is on and engine cache enable flag is on, -// // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine. -// if (has_explicit_profile && engine_cache_enable_) { -// engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); -// if (engine_update) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built"; -// } else { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt"; -// } -// } -// -// std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); -// if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { -// engine_file.seekg(0, std::ios::end); -// size_t engine_size = engine_file.tellg(); -// engine_file.seekg(0, std::ios::beg); -// std::unique_ptr engine_buf{new char[engine_size]}; -// engine_file.read((char*)engine_buf.get(), engine_size); -// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; -// if (trt_engine == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); -// } -// -// } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { -// // Decrypt engine -// size_t engine_size = 0; -// if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not get engine buffer size"); -// } -// std::unique_ptr engine_buf{new char[engine_size]}; -// if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not call engine decryption function decrypt"); -// } -// // Deserialize engine -// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; -// if (trt_engine == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); -// } -// } else { -// // Set INT8 per tensor dynamic range -// if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { -//#if defined(_MSC_VER) 
-//#pragma warning(push) -//#pragma warning(disable : 4996) -//#endif -// trt_config->setInt8Calibrator(nullptr); -//#if defined(_MSC_VER) -//#pragma warning(pop) -//#endif -// if (!SetDynamicRange(*trt_network, dynamic_range_map)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); -// } -// } -// -// // Load timing cache from file. Create a fresh cache if the file doesn't exist -// std::unique_ptr timing_cache = nullptr; -// if (timing_cache_enable_) { -// std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); -// timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); -// if (timing_cache == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not create timing cache: " + timing_cache_path); -// } -// trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); -// if (detailed_build_log_) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; -// } -// } -// -// // Build engine -// std::chrono::steady_clock::time_point engine_build_start; -// if (detailed_build_log_) { -// engine_build_start = std::chrono::steady_clock::now(); -// } -// std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; -// if (serialized_engine == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name()); -// } -// trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); -// if (trt_engine == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name()); -// } -// if (detailed_build_log_) { -// auto engine_build_stop = std::chrono::steady_clock::now(); -// LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; -// } -// if (engine_cache_enable_) { -// // Serialize engine profile if it has explicit profiles -// if (has_explicit_profile) { -// SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; -// } -// -// if (engine_decryption_enable_) { -// // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. -// if (engine_encryption_ != nullptr) { -// if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP call to engine encryption library failed"); -// } -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; -// } else { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk";
-//          }
-//        } else {
-//          std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
-//          file.write(reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
-//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
-//        }
-//      }
-//      // serialize and save timing cache
-//      if (timing_cache_enable_) {
-//        auto timing_cache = trt_config->getTimingCache();
-//        std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
-//        if (timingCacheHostData == nullptr) {
-//          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-//                                 "TensorRT EP could not serialize timing cache: " + timing_cache_path);
-//        }
-//        saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
-//        if (detailed_build_log_) {
-//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
-//        }
-//      }
-//      // dump EP context node model
-//      if (dump_ep_context_model_) {
-//        // "ep_cache_context" node attribute should be a relative path to context model directory
-//        if (ep_cache_context_attr_.empty()) {
-//          auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
-//          ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string();
-//        }
-//        std::string compute_capability_hw_compat = compute_capability_;
-//        if (engine_cache_enable_ && engine_hw_compatible_) {
-//          compute_capability_hw_compat = "80+";
-//        }
+#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+  // switch optimization level
+  if (builder_optimization_level_ != 3) {
+    trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+  }
+
+  // limit auxiliary streams
+  if (auxiliary_streams_ >= 0) {
+    trt_config->setMaxAuxStreams(auxiliary_streams_);
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are set to " << auxiliary_streams_;
+  }
+#else
+  if (builder_optimization_level_ != 3) {
+    //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+  }
+  if (auxiliary_streams_ >= 0) {
+    //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+  }
+#endif
+
+  if (weight_stripped_engine_enable_) {
+#if NV_TENSORRT_MAJOR >= 10
+    trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+    trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+#else
+    //LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+#endif
+  }
+
+  // limit used tactic sources
+  if (!tactic_sources_.empty()) {
+    nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+    tactics |= GetTacticSourceFromString(tactic_sources_);
+    trt_config->setTacticSources(tactics);
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_;
+  }
+
+  // Build the TRT engine (if needed) and load the TRT engine if:
+  // (1) the graph has no dynamic shape input, or
+  // (2) all the dynamic shape inputs have associated explicit profiles specified by the user.
+  //
+  // Otherwise the engine will be handled at inference time.
+  std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+
+  std::string cache_path = "";
+  // Customize cache prefix if assigned
+  if (!cache_prefix_.empty()) {
+    // Generate cache suffix in case user would like to customize cache prefix
+    cache_suffix_ = "_" + GetCacheSuffix(node_name, trt_node_name_with_precision_);
+    cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix_;
+  } else {
+    cache_path = GetCachePath(cache_path_, trt_node_name_with_precision_);
+  }
+
+  std::string cache_hw_compat = "_sm" + compute_capability_;
+  // Enable hardware compatibility mode if assigned
+  if (engine_cache_enable_ && engine_hw_compatible_) {
+    trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
+    cache_hw_compat = "_sm80+";
+    //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache.";
+  }
+
+  // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
+  // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
+  const std::string cache_path_prefix = cache_path + cache_hw_compat;
+  std::string engine_cache_path = cache_path_prefix + ".engine";
+  const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+  const std::string profile_cache_path = cache_path_prefix + ".profile";
+
+  // If weight-stripped engine is enabled and refitted engine cache is not present,
+  // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
+  const std::filesystem::path engine_cache_fs_path = engine_cache_path;
+  if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
+    engine_cache_path = cache_path_prefix + ".stripped.engine";
+    weight_stripped_engine_refit_ = true;
+  }
+
+  // Generate file name for dumping ep context model
+  if (dump_ep_context_model_ && ctx_model_path_.empty()) {
+    ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_);
+  }
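+
+  // Illustrative example (hypothetical values, not part of this change): with
+  // cache_path_ = "/tmp/trt_cache", trt_node_name_with_precision_ =
+  // "TensorrtExecutionProvider_TRTKernel_graph_main_0_fp16" and a GPU of compute
+  // capability 8.6, the cache files above resolve to
+  //   /tmp/trt_cache/TensorrtExecutionProvider_TRTKernel_graph_main_0_fp16_sm86.engine
+  //   /tmp/trt_cache/TensorrtExecutionProvider_TRTKernel_graph_main_0_fp16_sm86.profile
+  // while engine_hw_compatible_ swaps "_sm86" for the portable "_sm80+" suffix.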
+
+  if (!has_dynamic_shape) {
+    std::string timing_cache_path = "";
+    bool engine_update = false;
+    if (timing_cache_enable_) {
+      timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
+    }
+    {
+      // ifstream file check, engine serialization/deserialization and engine build are in the critical section. It needs lock protection to prevent a race condition when inferencing with multithreading.
+      // auto lock = GetApiLock();  // TODO(leca)
+
+      // If the explicit profile flag is on and the engine cache enable flag is on,
+      // we need to compare the explicit profiles and the profiles used to build the engine in order to decide whether to rebuild the engine.
+      if (has_explicit_profile && engine_cache_enable_) {
+        engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_);
+        if (engine_update) {
+          //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built";
+        } else {
+          //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt";
+        }
+      }
+
+      std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+      if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) {
+        engine_file.seekg(0, std::ios::end);
+        size_t engine_size = engine_file.tellg();
+        engine_file.seekg(0, std::ios::beg);
+        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+        engine_file.read((char*)engine_buf.get(), engine_size);
+        trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
+        if (trt_engine == nullptr) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path).c_str());
+        }
+
+      } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
+        // Decrypt engine
+        size_t engine_size = 0;
+        if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size");
+        }
+        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+        if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt");
+        }
+        // Deserialize engine
+        trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
+        if (trt_engine == nullptr) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str());
+        }
+      } else {
+        // Set INT8 per tensor dynamic range
+        if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+          trt_config->setInt8Calibrator(nullptr);
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+          if (!SetDynamicRange(*trt_network, dynamic_range_map_)) {
+            return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not set INT8 dynamic range for fused node: " + std::string(node_name)).c_str());
+          }
+        }
+
+        // Load timing cache from file. Create a fresh cache if the file doesn't exist
+        std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
+        if (timing_cache_enable_) {
+          std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
+          timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
+          if (timing_cache == nullptr) {
+            return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str());
+          }
+          trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
+          if (detailed_build_log_) {
+            //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
+          }
+        }
+
+        // Build engine
+        std::chrono::steady_clock::time_point engine_build_start;
+        if (detailed_build_log_) {
+          engine_build_start = std::chrono::steady_clock::now();
+        }
+        std::unique_ptr<nvinfer1::IHostMemory> serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)};
+        if (serialized_engine == nullptr) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to create engine from network for fused node: " + std::string(node_name)).c_str());
+        }
+        trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
+        if (trt_engine == nullptr) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP failed to deserialize engine for fused node: " + std::string(node_name)).c_str());
+        }
+        if (detailed_build_log_) {
+          auto engine_build_stop = std::chrono::steady_clock::now();
+          //LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision_ << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
+        }
+        if (engine_cache_enable_) {
+          // Serialize engine profile if it has explicit profiles
+          if (has_explicit_profile) {
+            SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges);
+            //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
+          }
+
+          if (engine_decryption_enable_) {
+            // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
+            if (engine_encryption_ != nullptr) {
+              if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size())) {
+                return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP call to engine encryption library failed");
+              }
+              //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
+            } else {
+              //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
+            }
+          } else {
+            std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
+            file.write(reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
+            //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
+          }
+        }
+        // serialize and save timing cache
+        if (timing_cache_enable_) {
+          auto timing_cache = trt_config->getTimingCache();
+          std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
+          if (timingCacheHostData == nullptr) {
+            return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str());
+          }
+          saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
+          if (detailed_build_log_) {
+            //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
+          }
+        }
+        // dump EP context node model
+        if (dump_ep_context_model_) {
+          // "ep_cache_context" node attribute should be a relative path to context model directory
+          if (ep_cache_context_attr_.empty()) {
+            auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
+            ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string();
+          }
+          std::string compute_capability_hw_compat = compute_capability_;
+          if (engine_cache_enable_ && engine_hw_compatible_) {
+            compute_capability_hw_compat = "80+";
+          }
//          std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto{CreateCtxModel(graph_body_viewer,
//                                                                                 ep_cache_context_attr_,
//                                                                                 reinterpret_cast<char*>(serialized_engine->data()),
//                                                                                 serialized_engine->size(),
//                                                                                 ep_context_embed_mode_,
//                                                                                 compute_capability_hw_compat,
//                                                                                 model_path_,
//                                                                                 GetLogger())};
//          DumpCtxModel(model_proto.get(), ctx_model_path_);
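//          Informal sketch (attribute names per the EP context model convention;
//          illustrative values, not part of this change) of what the dumped model
//          contains: a single EPContext node carrying the engine, roughly
//
//            node {
//              op_type: "EPContext"
//              attribute { name: "embed_mode"            i: <1: engine bytes inline, 0: path only> }
//              attribute { name: "ep_cache_context"      s: <engine bytes, or the relative path from ep_cache_context_attr_> }
//              attribute { name: "hardware_architecture" s: <compute_capability_hw_compat, e.g. "80+"> }
//            }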
input_indexes[input_name] = iter->second; -// } -// } -// -// // Create output to index and type maps -// const auto& graph_output = model_proto->graph().output(); -// for (int i = 0; i < num_outputs; ++i) { -// const std::string& output_name = trt_network->getOutput(i)->getName(); -// const auto& iter = output_map.find(output_name); -// if (iter != output_map.end()) { -// output_indexes[output_name] = iter->second; -// } -// const auto& tensor_type = graph_output[i].type().tensor_type(); -// output_types[output_name] = tensor_type.elem_type(); -// } -// -// // Save TRT engine, other TRT objects and input/output info to map -// parsers_.emplace(fused_node.Name(), std::move(trt_parser)); -// engines_.emplace(fused_node.Name(), std::move(trt_engine)); -// contexts_.emplace(fused_node.Name(), std::move(trt_context)); -// networks_.emplace(fused_node.Name(), std::move(trt_network)); -// input_info_[fused_node.Name()].push_back(input_indexes); -// output_info_[fused_node.Name()].push_back(output_indexes); -// output_info_[fused_node.Name()].push_back(output_types); -// input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges; -// profiles_.emplace(fused_node.Name(), std::move(trt_profiles)); -// -// // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty engine. -// // TRT EP will serialize the model at inference time due to engine can be updated and the updated engine should be included in the model. -// // However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here. -// if (dump_ep_context_model_ && has_dynamic_shape) { -// // "ep_cache_context" node attribute should be a relative path to context model directory -// if (ep_cache_context_attr_.empty()) { -// auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); -// ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); -// } -// std::string compute_capability_hw_compat = compute_capability_; -// if (engine_cache_enable_ && engine_hw_compatible_) { -// compute_capability_hw_compat = "80+"; -// } -// model_proto_.reset(CreateCtxModel(graph_body_viewer, -// ep_cache_context_attr_, -// nullptr, -// 0, -// ep_context_embed_mode_, -// compute_capability_hw_compat, -// model_path_, -// GetLogger())); -// if (ep_context_embed_mode_ == 0) { -// DumpCtxModel(model_proto_.get(), ctx_model_path_); -// } -// } -// -// // Create function state -// // TODO: remove default capture -// NodeComputeInfo compute_info; -// compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { -// std::unique_ptr p = std::make_unique(); -// // translate tactic sources string to nvinfer1::TacticSources -// nvinfer1::TacticSources tactics = 0; -// if (!tactic_sources_.empty()) { -// tactics = GetTacticSourceFromString(tactic_sources_); -// } -// *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), -// &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], -// &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], -// input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, -// dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, -// engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], -// 
context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, -// engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, -// detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, -// auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix, engine_hw_compatible_}; -// *state = p.release(); -// return 0; -// }; -// -// // Release function state -// compute_info.release_state_func = [](FunctionState state) { -// delete static_cast(state); -// }; -// -// // Create compute function -// compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { -// Ort::KernelContext ctx(context); -// -// TensorrtFuncState* trt_state = reinterpret_cast(state); -// -// // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, -// // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. -// // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading -// std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); -// const std::unordered_map& input_indexes = (trt_state->input_info)[0]; -// const std::unordered_map& output_indexes = (trt_state->output_info)[0]; -// const std::unordered_map& output_types = (trt_state->output_info)[1]; -// auto fused_node_name = trt_state->fused_node_name; -// // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. -// // The info is used for both shape tensor and execution tensor: -// // tensor name->(dimension->[min, max, opt]) -// auto& shape_ranges = trt_state->input_shape_ranges; -// std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run -// std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input -// auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; -// auto trt_builder = trt_state->builder; -// auto trt_engine = trt_state->engine->get(); -// auto trt_context = trt_state->context->get(); -// auto trt_profiles = trt_state->profiles; -// auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; -// int num_inputs = static_cast(input_indexes.size()); -// int num_outputs = static_cast(output_indexes.size()); -// bool engine_update = false; -// bool context_update = false; -// std::unordered_set input_names; -// -// OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); -// OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); -// if (alloc_ == nullptr) { -// Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); -// } -// OrtAllocator* alloc = alloc_; -// -// void* cuda_stream; -// Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); -// cudaStream_t stream = static_cast(cuda_stream); -// -// // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache -// // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity -// // Prepare cache name -// 
std::string cache_path = ""; -// // Customize cache prefix if assigned -// if (!cache_prefix_.empty()) { -// cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; -// } else { -// cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); -// } -// -// // Enable hardware compatility mode if assigned -// std::string cache_hw_compat = "_sm" + compute_capability_; -// if (engine_cache_enable_ && engine_hw_compatible_) { -// cache_hw_compat = "_sm80+"; -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; -// } -// -// // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache -// // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity -// const std::string cache_path_prefix = cache_path + cache_hw_compat; -// std::string engine_cache_path = cache_path_prefix + ".engine"; -// const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; -// const std::string profile_cache_path = cache_path_prefix + ".profile"; -// std::string timing_cache_path = ""; -// if (timing_cache_enable_) { -// timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); -// } -// -// // If weight-stripped engine is enabled and refitted engine cache is not present, -// // TRT EP will use the engine cache with ".stripped.engine" appended to the end. -// const std::filesystem::path engine_cache_fs_path = engine_cache_path; -// if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { -// engine_cache_path = cache_path_prefix + ".stripped.engine"; -// weight_stripped_engine_refit_ = true; -// } -// -// // Load serialized engine -// if (trt_state->engine_cache_enable && trt_engine == nullptr) { -// std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); -// std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); -// if (engine_file && !trt_state->engine_decryption_enable && profile_file) { -// // Deserialize profile -// shape_ranges = DeserializeProfileV2(profile_file); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; -// -// // Prepare buffer -// engine_file.seekg(0, std::ios::end); -// size_t engine_size = engine_file.tellg(); -// engine_file.seekg(0, std::ios::beg); -// std::unique_ptr engine_buf{new char[engine_size]}; -// engine_file.read((char*)engine_buf.get(), engine_size); -// -// // Deserialize engine -// // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc -// // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading -// trt_state->engine->reset(); -// *(trt_state->engine) = std::unique_ptr( -// trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); -// if (!(*(trt_state->engine))) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); -// } -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; -// trt_engine = trt_state->engine->get(); -// context_update = true; -// -// } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { -// shape_ranges = DeserializeProfileV2(profile_file); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + 
profile_cache_path; -// // Decrypt engine -// size_t engine_size = 0; -// if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not get engine buffer size"); -// } -// std::unique_ptr engine_buf{new char[engine_size]}; -// if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not call engine decryption function decrypt"); -// } -// // Deserialize engine -// // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc -// // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading -// trt_state->engine->reset(); -// *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); -// if (!(*(trt_state->engine))) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); -// } -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; -// trt_engine = trt_state->engine->get(); -// context_update = true; -// } -// } -// -// // Check and update shape ranges for dynamic shape inputs. -// for (int i = 0, end = num_inputs; i < end; ++i) { -// auto input = trt_state->network->get()->getInput(i); -// const std::string& input_name = input->getName(); -// input_names.insert(input_name); -// -// // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. -// // TRT EP will help determine the min/max/opt profile values based on current input tensor value. -// if (shape_ranges.find(input_name) != shape_ranges.end()) { -// auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update); -// if (status != Status::OK()) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); -// } -// } -// } -// -// // Regenerate engine -// if (engine_update) { -// // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. 
-// trt_state->context->reset(); -// trt_state->engine->reset(); -// auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); -// trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr)); -// for (auto trt_profile : trt_profiles) { -// trt_config->addOptimizationProfile(trt_profile); -// } -// -// // Set INT8 Per Tensor Dynamic range -// if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { -//#if defined(_MSC_VER) -//#pragma warning(push) -//#pragma warning(disable : 4996) -//#endif -// trt_config->setInt8Calibrator(nullptr); -//#if defined(_MSC_VER) -//#pragma warning(pop) -//#endif -// if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); -// } -// } -// -// // Set precision -// if (trt_state->fp16_enable && trt_state->int8_enable) { -// trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); -// } else if (trt_state->fp16_enable) { -// trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); -// } else if (trt_state->int8_enable) { -// trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); -// } -// -// // Set DLA (DLA can only run with FP16 or INT8) -// if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; -// trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); -// trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); -// trt_config->setDLACore(trt_state->dla_core); -// } -// -// // enable sparse weights -// if (trt_state->sparsity_enable) { -// trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; -// } -//#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 -// // enable builder heuristics -// if (trt_state->build_heuristics_enable) { -// trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; -// } -//#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 -// // switch optimizaion level -// if (trt_state->builder_optimization_level != 3) { -// trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; -// } -// -// // limit auxiliary streams -// if (trt_state->auxiliary_streams >= 0) { -// trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; -// } -//#else -// if (trt_state->builder_optimization_level != 3) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; -// } -// if (trt_state->auxiliary_streams >= 0) { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; -// } -//#endif -// if (weight_stripped_engine_enable_) { -//#if NV_TENSORRT_MAJOR >= 10 -// trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; -// trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; 
-//#else -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; -//#endif -// } -// // limit used tactic sources -// if (trt_state->filter_tactic_sources) { -// nvinfer1::TacticSources tactics = trt_config->getTacticSources(); -// tactics |= trt_state->tactic_sources; -// trt_config->setTacticSources(tactics); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; -// } -// -// // Load timing cache from file. Create a fresh cache if the file doesn't exist -// std::unique_ptr timing_cache = nullptr; -// if (trt_state->timing_cache_enable) { -// std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); -// timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); -// if (timing_cache == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not create timing cache: " + timing_cache_path); -// } -// trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); -// if (detailed_build_log_) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; -// } -// } -// -// // Enable hardware compatility mode if assigned -// if (trt_state->engine_hw_compatible) { -// trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); -// LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; -// } -// -// // Build engine -// std::unique_ptr serialized_engine; -// { -// auto lock = GetApiLock(); -// std::chrono::steady_clock::time_point engine_build_start; -// if (detailed_build_log_) { -// engine_build_start = std::chrono::steady_clock::now(); -// } -// serialized_engine = std::unique_ptr( -// trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); -// if (!serialized_engine) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network."); -// } -// *(trt_state->engine) = std::unique_ptr( -// trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); -// if (!(*(trt_state->engine))) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine."); -// } -// if (detailed_build_log_) { -// auto engine_build_stop = std::chrono::steady_clock::now(); -// LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; -// } -// } -// if (!(*(trt_state->engine))) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); -// } -// trt_engine = trt_state->engine->get(); -// if (trt_state->engine_cache_enable) { -// // Serialize engine profile -// SerializeProfileV2(profile_cache_path, shape_ranges); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; -// -// // Serialize engine -// if (trt_state->engine_decryption_enable) { -// // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. 
-// if (trt_state->engine_encryption != nullptr) { -// if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not call engine encryption function encrypt"); -// } -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; -// } else { -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; -// } -// } else { -// std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); -// file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; -// } -// } -// -// // serialize and save timing cache -// if (trt_state->timing_cache_enable) { -// auto timing_cache = trt_config->getTimingCache(); -// std::unique_ptr timingCacheHostData{timing_cache->serialize()}; -// if (timingCacheHostData == nullptr) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, -// "TensorRT EP could not serialize timing cache: " + timing_cache_path); -// } -// saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); -// if (detailed_build_log_) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; -// } -// } -// -// // dump ep context model -// if (dump_ep_context_model_ && ep_context_embed_mode_) { -// UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); -// DumpCtxModel(model_proto_.get(), ctx_model_path_); -// } -// context_update = true; -// -// if (weight_stripped_engine_refit_) { -// auto status = RefitEngine(model_path_, -// onnx_model_folder_path_, -// engine_cache_path, -// false /* path check for security */, -// trt_engine, -// true /* serialize refitted engine to disk */, -// detailed_build_log_); -// if (status != Status::OK()) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); -// } -// } -// } -// -// if (context_update) { -// if (trt_state->context_memory_sharing_enable) { -//#if NV_TENSORRT_MAJOR < 10 -// *(trt_state->context) = std::unique_ptr( -// trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -//#else -// *(trt_state->context) = std::unique_ptr( -// trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -//#endif -// } else { -// *(trt_state->context) = std::unique_ptr( -// trt_state->engine->get()->createExecutionContext()); -// } -// if (!(*(trt_state->context))) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); -// } -// trt_context = trt_state->context->get(); -// } -// -// // Get input and output binding names -// int total_bindings = trt_engine->getNbIOTensors(); -// std::vector input_binding_names, output_binding_names; -// for (int i = 0, end = total_bindings; i < end; ++i) { -// auto const& name = trt_engine->getIOTensorName(i); -// auto const& mode = trt_engine->getTensorIOMode(name); -// if (mode == nvinfer1::TensorIOMode::kINPUT) { -// input_binding_names.push_back(name); -// } else { -// output_binding_names.push_back(name); -// } -// } -// -// /* -// * Set input shapes and bind input buffers -// */ -// std::vector> scratch_buffers; -// for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { -// char const* input_name = 
input_binding_names[i]; -// -// size_t input_index = 0; -// const auto iter = input_indexes.find(input_name); -// if (iter != input_indexes.end()) { -// input_index = iter->second; -// } -// auto input_tensor = ctx.GetInput(input_index); -// auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); -// const auto tensor_shapes = tensor_info.GetShape(); -// -// auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); -// if (status != Status::OK()) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); -// } -// } -// -// /* -// * Set output shapes and bind output buffers -// */ -// std::unordered_map buffers; -// buffers.reserve(num_outputs); -// using OutputOrtValue = Ort::UnownedValue; -// std::unordered_map output_tensors; -// output_tensors.reserve(num_outputs); -// std::unordered_map output_dim_sizes; -// output_dim_sizes.reserve(num_outputs); -// -// for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { -// char const* output_name = output_binding_names[i]; -// -// size_t output_index = 0; -// const auto& index_iter = output_indexes.find(output_name); -// if (index_iter != output_indexes.end()) { -// output_index = index_iter->second; -// } -// -// size_t output_type = 0; -// const auto type_iter = output_types.find(output_name); -// if (type_iter != output_types.end()) { -// output_type = type_iter->second; -// } -// -// Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, -// dds_output_allocator_map, scratch_buffers, alloc, buffers); -// if (status != Status::OK()) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); -// } -// } -// -// // Set execution context memory -// if (trt_state->context_memory_sharing_enable) { -//#if defined(_MSC_VER) -//#pragma warning(push) -//#pragma warning(disable : 4996) -//#endif -// size_t mem_size = trt_engine->getDeviceMemorySize(); -//#if defined(_MSC_VER) -//#pragma warning(pop) -//#endif -// if (mem_size > *max_context_mem_size_ptr) { -// *max_context_mem_size_ptr = mem_size; -// } -// trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + } + } + } + + if (weight_stripped_engine_refit_) { + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine.get(), + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != nullptr) { + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); + } + } + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = 
std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
+    }
+    if (!trt_context) {
+      return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not build execution context for fused node: " + std::string(node_name)).c_str());
+    }
+  }
+
+  // Create input to index map
+  for (int i = 0; i < num_inputs; ++i) {
+    auto input = trt_network->getInput(i);
+    const std::string& input_name = input->getName();
+    const auto& iter = input_map.find(input_name);
+    if (iter != input_map.end()) {
+      input_indexes[input_name] = iter->second;
+    }
+  }
+
+  // Create output to index and type maps
+  for (int i = 0; i < num_outputs; ++i) {
+    const std::string& output_name = trt_network->getOutput(i)->getName();
+    const auto& iter = output_map.find(output_name);
+    if (iter != output_map.end()) {
+      output_indexes[output_name] = iter->second;
+    }
+    output_types[output_name] = api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i);
+  }
+
+  // Save TRT engine, other TRT objects and input/output info to map
+  parsers_.emplace(node_name, std::move(trt_parser));
+  engines_.emplace(node_name, std::move(trt_engine));
+  contexts_.emplace(node_name, std::move(trt_context));
+  networks_.emplace(node_name, std::move(trt_network));
+  input_info_[node_name].push_back(input_indexes);
+  output_info_[node_name].push_back(output_indexes);
+  output_info_[node_name].push_back(output_types);
+  input_shape_ranges_[node_name] = input_implicit_shape_ranges;
+  profiles_.emplace(node_name, std::move(trt_profiles));
+
+  // For a model with dynamic shape inputs, TRT EP first creates a model proto that includes the inputs, the outputs and an empty engine.
+  // TRT EP will serialize the model at inference time, because the engine can be updated and the updated engine should be included in the model.
+  // However, if embed_mode is 0 (the model only includes the engine path), TRT EP serializes it here.
+  if (dump_ep_context_model_ && has_dynamic_shape) {
+    // "ep_cache_context" node attribute should be a relative path to context model directory
+    if (ep_cache_context_attr_.empty()) {
+      auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
+      ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string();
+    }
+    std::string compute_capability_hw_compat = compute_capability_;
+    if (engine_cache_enable_ && engine_hw_compatible_) {
+      compute_capability_hw_compat = "80+";
+    }
+//    model_proto_.reset(CreateCtxModel(graph_body_viewer,
+//                                      ep_cache_context_attr_,
+//                                      nullptr,
+//                                      0,
+//                                      ep_context_embed_mode_,
+//                                      compute_capability_hw_compat,
+//                                      model_path_,
+//                                      GetLogger()));
+//    if (ep_context_embed_mode_ == 0) {
+//      DumpCtxModel(model_proto_.get(), ctx_model_path_);
+//    }
-//
-//      // Start CUDA graph capture.
-//      // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
-//      // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
+  }
+
+  // Create function state
+  (*node_compute_funcs)->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int {
+    TensorrtExecutionProvider* this_ = reinterpret_cast<TensorrtExecutionProvider*>(extra_param);
+    std::unique_ptr<TensorrtFuncState> p = std::make_unique<TensorrtFuncState>();
+
+    // translate tactic sources string to nvinfer1::TacticSources
+    nvinfer1::TacticSources tactics = 0;
+    if (!this_->tactic_sources_.empty()) {
+      tactics = GetTacticSourceFromString(this_->tactic_sources_);
+    }
+    *p = {context->AllocateFunc, context->DestroyFunc, context->allocator_handle, context->node_name, this_->builder_.get(),
+          &(this_->parsers_[context->node_name]), &(this_->engines_[context->node_name]), &(this_->contexts_[context->node_name]),
+          &(this_->networks_[context->node_name]), this_->input_info_[context->node_name], this_->output_info_[context->node_name],
+          this_->input_shape_ranges_[context->node_name], /*&tensorrt_mu_,*/ this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_,
+          this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_,
+          this_->engine_cache_enable_, this_->cache_path_, this_->runtime_.get(), this_->profiles_[context->node_name],
+          this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_, this_->engine_decryption_enable_,
+          this_->engine_decryption_, this_->engine_encryption_, this_->timing_cache_enable_, this_->global_cache_path_, this_->force_timing_cache_match_,
+          this_->detailed_build_log_, this_->build_heuristics_enable_, this_->sparsity_enable_, this_->builder_optimization_level_,
+          this_->auxiliary_streams_, !(this_->tactic_sources_.empty()), tactics, this_->cuda_graph_enable_, this_->cache_prefix_, this_->cache_suffix_, this_->engine_hw_compatible_};
+    *state = p.release();
+    return 0;
+  };
+
+  // Release function state
+  (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) {
+    delete static_cast<TensorrtFuncState*>(state);
+  };
+
+  // Create compute function
+  (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr {
+    Ort::KernelContext ctx(context);
+    TensorrtExecutionProvider* this_ = reinterpret_cast<TensorrtExecutionProvider*>(extra_param);
+    TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
+
+    // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
+    // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
+    // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+    //std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr)); // TODO(leca)
+    const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
+    const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
+    const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
+    auto fused_node_name = trt_state->fused_node_name;
+    // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles.
+    // The info is used for both shape tensor and execution tensor:
+    // tensor name->(dimension->[min, max, opt])
+    auto& shape_ranges = trt_state->input_shape_ranges;
+    std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
+    std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
+    auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name];
+    auto trt_builder = trt_state->builder;
+    auto trt_engine = trt_state->engine->get();
+    auto trt_context = trt_state->context->get();
+    auto trt_profiles = trt_state->profiles;
+    auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
+    int num_inputs = static_cast<int>(input_indexes.size());
+    int num_outputs = static_cast<int>(output_indexes.size());
+    bool engine_update = false;
+    bool context_update = false;
+    std::unordered_set<std::string> input_names;
+
+    OrtMemoryInfo* mem_info = nullptr;
+    api->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info);
+    if (this_->alloc_ == nullptr) {
+      Ort::ThrowOnError(api->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_)));
+    }
+    OrtAllocator* alloc = this_->alloc_;
+
+    void* cuda_stream;
+    Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
+    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
+
+    // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+    // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+    // Prepare cache name
+    std::string cache_path = "";
+    // Customize cache prefix if assigned
+    if (!this_->cache_prefix_.empty()) {
+      cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix;
+    } else {
+      cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
+    }
+
+    // Enable hardware compatibility mode if assigned
+    std::string cache_hw_compat = "_sm" + this_->compute_capability_;
+    if (this_->engine_cache_enable_ && this_->engine_hw_compatible_) {
+      cache_hw_compat = "_sm80+";
+      //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache.";
+    }
+
+    // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+    // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+    const std::string cache_path_prefix = cache_path + cache_hw_compat;
+    std::string engine_cache_path = cache_path_prefix + ".engine";
+    const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+    const std::string profile_cache_path = cache_path_prefix + ".profile";
+    std::string timing_cache_path = "";
+    if (this_->timing_cache_enable_) {
+      timing_cache_path = GetTimingCachePath(this_->global_cache_path_, this_->compute_capability_);
+    }
+
+    // If weight-stripped engine is enabled and refitted engine cache is not present,
+    // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
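+    // Worked example of the cache naming scheme above (hypothetical node and directory names, not from this patch):
+    //   cache_path         = "trt_cache/TensorrtExecutionProvider_TRTKernel_graph_main_0_0_fp16"
+    //   cache_hw_compat    = "_sm89"  (or "_sm80+" when hardware compatibility mode is on)
+    //   engine_cache_path  = cache_path + cache_hw_compat + ".engine"
+    //   profile_cache_path = cache_path + cache_hw_compat + ".profile"
+    // so an engine built for one compute capability is never silently reused on another, and when the
+    // weight-stripped fallback below triggers, ".stripped.engine" replaces ".engine".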
+    const std::filesystem::path engine_cache_fs_path = engine_cache_path;
+    if (this_->weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
+      engine_cache_path = cache_path_prefix + ".stripped.engine";
+      this_->weight_stripped_engine_refit_ = true;
+    }
+
+    // Load serialized engine
+    if (trt_state->engine_cache_enable && trt_engine == nullptr) {
+      std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+      std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
+      if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
+        // Deserialize profile
+        shape_ranges = DeserializeProfileV2(profile_file);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
+
+        // Prepare buffer
+        engine_file.seekg(0, std::ios::end);
+        size_t engine_size = engine_file.tellg();
+        engine_file.seekg(0, std::ios::beg);
+        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+        engine_file.read((char*)engine_buf.get(), engine_size);
+
+        // Deserialize engine
+        // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
+        // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+        trt_state->engine->reset();
+        *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
+            trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
+        if (!(*(trt_state->engine))) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine.");
+        }
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
+        trt_engine = trt_state->engine->get();
+        context_update = true;
+
+      } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
+        shape_ranges = DeserializeProfileV2(profile_file);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
+        // Decrypt engine
+        size_t engine_size = 0;
+        if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size");
+        }
+        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+        if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt");
+        }
+        // Deserialize engine
+        // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
+        // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+        trt_state->engine->reset();
+        *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
+        if (!(*(trt_state->engine))) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str());
+        }
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
+        trt_engine = trt_state->engine->get();
+        context_update = true;
+      }
+    }
+
+    // Check and update shape ranges for dynamic shape inputs.
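+    // Illustrative shape_ranges entry (hypothetical tensor name and values), following the
+    // "tensor name -> (dimension -> [min, max, opt])" layout documented above:
+    //   shape_ranges = { "input_ids": { 0: [[1, 32, 8]] } }
+    // i.e. dimension 0 of "input_ids" is dynamic and profile 0 covers min=1, max=32, opt=8.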
+    for (int i = 0, end = num_inputs; i < end; ++i) {
+      auto input = trt_state->network->get()->getInput(i);
+      const std::string& input_name = input->getName();
+      input_names.insert(input_name);
+
+      // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
+      // TRT EP will help determine the min/max/opt profile values based on current input tensor value.
+      if (shape_ranges.find(input_name) != shape_ranges.end()) {
+        auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update);
+        if (status != nullptr) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
+        }
+      }
+    }
+
+    // Regenerate engine
+    if (engine_update) {
+      // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
+      trt_state->context->reset();
+      trt_state->engine->reset();
+      auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
+      trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr));
+      for (auto trt_profile : trt_profiles) {
+        trt_config->addOptimizationProfile(trt_profile);
+      }
+
+      // Set INT8 Per Tensor Dynamic range
+      if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+        trt_config->setInt8Calibrator(nullptr);
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+        if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
+        }
+      }
+
+      // Set precision
+      if (trt_state->fp16_enable && trt_state->int8_enable) {
+        trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
+      } else if (trt_state->fp16_enable) {
+        trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+      } else if (trt_state->int8_enable) {
+        trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
+      }
+
+      // Set DLA (DLA can only run with FP16 or INT8)
+      if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
+        trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
+        trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
+        trt_config->setDLACore(trt_state->dla_core);
+      }
+
+      // enable sparse weights
+      if (trt_state->sparsity_enable) {
+        trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
+      }
+#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
+      // enable builder heuristics
+      if (trt_state->build_heuristics_enable) {
+        trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
+      }
+#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+      // switch optimization level
+      if (trt_state->builder_optimization_level != 3) {
+        trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+      }
+
+      // limit auxiliary streams
+      if (trt_state->auxiliary_streams >= 0) {
+        trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are set to " << trt_state->auxiliary_streams;
+      }
+#else
+      if (trt_state->builder_optimization_level != 3) {
+        //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+      }
+      if (trt_state->auxiliary_streams >= 0) {
+        //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+      }
+#endif
+      if (this_->weight_stripped_engine_enable_) {
+#if NV_TENSORRT_MAJOR >= 10
+        trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+        trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+#else
+        //LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+#endif
+      }
+      // limit used tactic sources
+      if (trt_state->filter_tactic_sources) {
+        nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+        tactics |= trt_state->tactic_sources;
+        trt_config->setTacticSources(tactics);
+        //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
+      }
+
+      // Load timing cache from file. Create a fresh cache if the file doesn't exist
+      std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
+      if (trt_state->timing_cache_enable) {
+        std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
+        timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
+        if (timing_cache == nullptr) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str());
+        }
+        trt_config->setTimingCache(*timing_cache, this_->force_timing_cache_match_);
+        if (this_->detailed_build_log_) {
+          //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
+        }
+      }
+
+      // Enable hardware compatibility mode if assigned
+      if (trt_state->engine_hw_compatible) {
+        trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
+        //LOGS_DEFAULT(INFO) << "[TensorRT EP] Re-generate engine with hardware compatibility enabled.";
+      }
+
+      // Build engine
+      std::unique_ptr<nvinfer1::IHostMemory> serialized_engine;
+      {
+        //auto lock = GetApiLock(); // TODO(leca)
+        std::chrono::steady_clock::time_point engine_build_start;
+        if (this_->detailed_build_log_) {
+          engine_build_start = std::chrono::steady_clock::now();
+        }
+        serialized_engine = std::unique_ptr<nvinfer1::IHostMemory>(
+            trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config));
+        if (!serialized_engine) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create engine from network.");
+        }
+        *(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
+            trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
+        if (!(*(trt_state->engine))) {
+          return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to deserialize engine.");
+        }
+        if (this_->detailed_build_log_) {
+          auto engine_build_stop = std::chrono::steady_clock::now();
+          //LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
+        }
+      }
+      if
(!(*(trt_state->engine))) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + + // Serialize engine + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. + if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine encryption function encrypt"); + } + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + //LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } + } + + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (this_->detailed_build_log_) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; + } + } + + // dump ep context model + if (this_->dump_ep_context_model_ && this_->ep_context_embed_mode_) { + //UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); // TODO(leca) + //DumpCtxModel(model_proto_.get(), ctx_model_path_); + } + context_update = true; + + if (this_->weight_stripped_engine_refit_) { + auto status = RefitEngine(this_->model_path_, + this_->onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine, + true /* serialize refitted engine to disk */, + this_->detailed_build_log_); + if (status != nullptr) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + } + } + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { +#if NV_TENSORRT_MAJOR < 10 + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +#else + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for 
(int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, + dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + } + trt_context->setDeviceMemory(MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + } + + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
// if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; // cuda_graph_.SetStream(stream); // CaptureBegin(0); // } -// -// // Run TRT inference -// if (!trt_context->enqueueV3(stream)) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); -// } -// -// /* -// * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, -// * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: -// * -// * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. -// * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, -// * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. -// * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. -// * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream -// * is guaranteed. -// * -// * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. -// * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. -// */ -// if (sync_stream_after_enqueue_) { -// CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); -// } -// -// // Assign TRT output back to ORT output -// // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) -// // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output -// for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { -// char const* output_name = output_binding_names[i]; -// -// size_t output_type = 0; -// const auto& iter = output_types.find(output_name); -// if (iter != output_types.end()) { -// output_type = iter->second; -// } -// -// if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { -// size_t output_index = 0; -// const auto& index_iter = output_indexes.find(output_name); -// if (index_iter != output_indexes.end()) { -// output_index = index_iter->second; -// } -// auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); -// if (status != Status::OK()) { -// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); -// } -// } else { -// auto& output_tensor = output_tensors[i]; + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP execution context enqueue failed."); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. 
+       * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently,
+       * the trt execution context instance is shared by all the threads and each thread acquires a different stream from ORT.
+       * So TRT EP will end up having one trt execution context using multiple streams which is not suggested.
+       * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream
+       * is guaranteed.
+       *
+       * Therefore, TRT EP needs to call cudaStreamSynchronize(), which means to wait until the stream has completed all operations, to prevent the concurrency issue mentioned above.
+       * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture.
+       */
+      if (this_->sync_stream_after_enqueue_) {
+        CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+      }
+
+      // Assign TRT output back to ORT output
+      // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+      // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
+      for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+        char const* output_name = output_binding_names[i];
+
+        size_t output_type = 0;
+        const auto& iter = output_types.find(output_name);
+        if (iter != output_types.end()) {
+          output_type = iter->second;
+        }
+
+        if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) {
+          size_t output_index = 0;
+          const auto& index_iter = output_indexes.find(output_name);
+          if (index_iter != output_indexes.end()) {
+            output_index = index_iter->second;
+          }
+          auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);
+          if (status != nullptr) {
+            return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status));
+          }
+        } else {
+          auto& output_tensor = output_tensors[i];
//#if NV_TENSORRT_MAJOR < 10
//          if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
//            auto output_tensor_ptr = output_tensor.GetTensorMutableData();
@@ -1823,9 +2158,9 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
//            cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
//          }
//        }
-//      }
-//    }
-//
+      }
+    }
+
// // End CUDA graph capture.
// // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture
// // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
@@ -1840,9 +2175,9 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // IncrementRegularRunCountBeforeGraphCapture(); // } // } -// -// return nullptr; -// }; + + return nullptr; + }; return nullptr; } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 8b4ca2ff61ed3..805177a7bcf64 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -179,6 +179,9 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { bool serialize_refitted_engine, bool detailed_build_log); static const OrtApi* api_; + std::string trt_node_name_with_precision_; + std::unordered_map dynamic_range_map_; + std::string cache_suffix_; private: // mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index 97b9ffd91961c..2d652488b75d4 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -81,6 +81,168 @@ int GetNumProfiles(std::unordered_map>>>& shape_ranges) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In SerializeProfileV2()"; + // Serialize profile + flexbuffers::Builder builder; + auto tensor_map_start = builder.StartMap(); + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << tensor_it->first.c_str() << "'"; + builder.TypedVector(tensor_it->first.c_str(), [&] { + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { + size_t num_profiles = dim_it->second.size(); + for (size_t i = 0; i < num_profiles; i++) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] profile #" << i << ", dim is " << dim_it->first; + builder.Int(dim_it->first); + builder.Int(dim_it->second[i][0]); + builder.Int(dim_it->second[i][1]); + builder.Int(dim_it->second[i][2]); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; + } + } + }); + } + builder.EndMap(tensor_map_start); + builder.Finish(); + + // Save flexbuffer + std::ofstream file(file_name, std::ios::binary | std::ios::out); + auto buf = builder.GetBuffer(); + size_t size = builder.GetSize(); + file.write(reinterpret_cast(&buf[0]), size); + file.close(); +} + +std::unordered_map>>> DeserializeProfileV2(std::ifstream& infile) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In DeserializeProfileV2()"; + // Load flexbuffer + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + + // Deserialize profile + std::unordered_map>>> shape_ranges; + auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap(); + auto keys = tensors_range_entries.Keys(); + auto values = tensors_range_entries.Values(); + for (size_t i = 0, end = keys.size(); i < end; ++i) { // iterate tensors + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << keys[i].AsString().c_str() << "'"; + auto dim_range_vector = values[i].AsTypedVector(); + std::unordered_map>> inner_map; + std::vector> profile_vector; + + for (size_t k = 0; k < (dim_range_vector.size() / 4); k++) { // iterate 
dim, min, max, opt for all profiles + std::vector shape_vector; + auto idx = 4 * k; + auto dim = dim_range_vector[idx].AsInt64(); + shape_vector.push_back(dim_range_vector[idx + 1].AsInt64()); // min shape + shape_vector.push_back(dim_range_vector[idx + 2].AsInt64()); // max shape + shape_vector.push_back(dim_range_vector[idx + 3].AsInt64()); // opt shape + + if (inner_map.find(dim) == inner_map.end()) { + inner_map[dim] = profile_vector; + } + inner_map[dim].push_back(shape_vector); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim << ", " << shape_vector[0] << ", " << shape_vector[1] << ", " << shape_vector[2]; + } + shape_ranges[keys[i].AsString().c_str()] = inner_map; + } + return shape_ranges; +} + +bool CompareProfiles(const std::string& file_name, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes) { + std::ifstream profile_file(file_name, std::ios::binary | std::ios::in); + if (!profile_file) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << file_name << " doesn't exist."; + return true; + } + + std::unordered_map>>> shape_ranges; + shape_ranges = DeserializeProfileV2(profile_file); + + /* The format of the two data structures are below, for example: + * + * shape_ranges: + * { + * tensor_a: { + * dim_0: [[min_shape, max_shape, opt_shape]], + * dim_2: [[min_shape, max_shape, opt_shape]] + * }, + * tensor_b: { + * dim_1: [[min_shape, max_shape, opt_shape]] + * } + * } + * + * profile_min_shapes: + * { + * tensor_a: [[dim_0_value_0, dim_1_value_1, dim_2_value_2]], + * tensor_b: [[dim_0_value_3, dim_1_value_4, dim_2_value_5]] + * } + * + */ + + // Check number of dynamic shape inputs + if (profile_min_shapes.size() != shape_ranges.size()) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of dynamic shape inputs are not the same."; + return true; + } + + // Iterate through shape_ranges map + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + auto tensor_name = tensor_it->first; + if (profile_min_shapes.find(tensor_name) == profile_min_shapes.end()) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; + return true; + } + + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { // iterate dimensions + auto dim = dim_it->first; + auto num_profiles = GetNumProfiles(profile_min_shapes); + + if (dim_it->second.size() != static_cast(num_profiles)) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of profiles are not the same."; + return true; + } + + for (size_t i = 0; i < dim_it->second.size(); i++) { // iterate (multiple) profile(s) + auto shape_values = dim_it->second[i]; + if (dim > (profile_min_shapes[tensor_name][i].size() - 1)) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; + if (profile_min_shapes[tensor_name][i][dim] != shape_values[0]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are 
not the same"; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; + if (profile_max_shapes[tensor_name][i][dim] != shape_values[1]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; + if (profile_opt_shapes[tensor_name][i][dim] != shape_values[2]) { + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + } + } + } + return false; +} + std::string GetCachePath(const std::string& root, const std::string& name) { if (root.empty()) { return name; @@ -90,4 +252,53 @@ std::string GetCachePath(const std::string& root, const std::string& name) { return path.string(); } } + +std::string GetComputeCapacity(const cudaDeviceProp& prop) { + const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); + return compute_capability; +} + +std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { + // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache + const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" + + compute_cap + ".timing"; + return GetCachePath(root, timing_cache_name); +} + +std::vector split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +std::string join(const std::vector& vec, const std::string& delimiter) { + std::string result; + for (size_t i = 0; i < vec.size(); ++i) { + result += vec[i]; + if (i < vec.size() - 1) { + result += delimiter; + } + } + return result; +} + +std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) { + std::vector split_fused_node_name = split(fused_node_name, '_'); + if (split_fused_node_name.size() >= 3) { + // Get index of model hash from fused_node_name + std::string model_hash = split_fused_node_name[split_fused_node_name.size() - 3]; + size_t index = fused_node_name.find(model_hash); + // Parse suffix from trt_node_name_with_precision, as it has additional precision info + std::vector suffix_group = split(trt_node_name_with_precision.substr(index), '_'); + if (suffix_group.size() > 2) { + suffix_group.erase(suffix_group.begin() + 2); + } + return join(suffix_group, "_"); + } + return ""; +} } From 7bdb36a0b99589d0e7b26b1527b135a723878be9 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 26 Aug 2024 22:16:57 +0000 Subject: [PATCH 22/81] add override function implementation and cudart dependency for tensorrt --- samples/c_test/test.cpp | 9 +++- samples/tensorRTEp/CMakeLists.txt | 3 +- .../tensorRTEp/tensorrt_execution_provider.cc | 45 +++++++++++++++++++ 3 files changed, 55 
insertions(+), 2 deletions(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 7a6824d7a8c04..6826a3a5bc10e 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -18,6 +18,12 @@ void TestKernelBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "kernelEp", env, keys.data(), values.data(), keys.size())); } +void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { + THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); +} + int main() { const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtEnv* p_env = nullptr; @@ -27,7 +33,8 @@ int main() { THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); //TestCompileBasedEp(g_ort, p_env, so); - TestKernelBasedEp(g_ort, p_env, so); + //TestKernelBasedEp(g_ort, p_env, so); + TestTensorRTEp(g_ort, p_env, so); OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index 759c43060de4b..fb6770405537f 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -28,7 +28,8 @@ target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linu "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer_plugin.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so" - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/flatbuffers-build/libflatbuffers.a") + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/flatbuffers-build/libflatbuffers.a" + CUDA::cudart) # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" # "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index c00cbca0fef7b..2dd1f7741b864 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -230,6 +230,51 @@ inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::I oFile.close(); } +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. 
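+  // Growth-only strategy: a request that fits in allocated_size reuses the
+  // existing block, and since cudaFree(nullptr) is a no-op the very first
+  // call needs no special case.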
+ return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#endif + +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { + output_shapes.clear(); + output_shapes.reserve(dims.nbDims); + for (int i = 0; i < dims.nbDims; i++) { + output_shapes.push_back(dims.d[i]); + } +} + TensorrtLogger& GetTensorrtLogger(bool verbose_log) { const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING; static TensorrtLogger trt_logger(log_level); From 7d915b7a2660e9be3aabf0083342f3df90cb1d3a Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:06:08 +0800 Subject: [PATCH 23/81] add outOfTree tensorrt ep.1 (#21830) add GetCapability for tensorrt ep --- .../core/session/onnxruntime_c_api.h | 16 +- onnxruntime/core/session/onnxruntime_c_api.cc | 57 ++- onnxruntime/core/session/ort_apis.h | 14 +- samples/outTreeEp/out_tree_ep.cc | 2 +- samples/tensorRTEp/murmurhash3.cc | 349 +++++++++++++++++ samples/tensorRTEp/murmurhash3.h | 16 + samples/tensorRTEp/onnx_ctx_model_helper.cc | 104 ++++- samples/tensorRTEp/onnx_ctx_model_helper.h | 6 + .../tensorRTEp/tensorrt_execution_provider.cc | 362 +++++++++++++++++- .../tensorRTEp/tensorrt_execution_provider.h | 16 + 10 files changed, 933 insertions(+), 9 deletions(-) create mode 100644 samples/tensorRTEp/murmurhash3.cc create mode 100644 samples/tensorRTEp/murmurhash3.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c7de6f3538499..36d5c01d6237b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -307,6 +307,8 @@ ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(ExecutionProvider); ORT_RUNTIME_CLASS(ExecutionProviderFactory); ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(Graph); ORT_RUNTIME_CLASS(GraphViewer); ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(TypeConstraints); @@ -4757,7 +4759,19 @@ struct OrtApi { ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); - ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); + ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); + + ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); + + ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); + + ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); + + 
ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); + + ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); + + ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names); ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 26f3df9f647c2..92af3800123d3 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -100,6 +100,8 @@ using onnxruntime::common::Status; using namespace onnxruntime; +typedef std::unordered_map ModelMetaData; + #ifndef ORT_STATUS_PTR #ifdef _WIN32 #define ORT_STATUS_PTR _Check_return_ _Ret_maybenull_ OrtStatusPtr @@ -2419,14 +2421,53 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsConstantInitializer, const OrtGraphViewe return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order) { +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(); + const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(static_cast(execution_order)); *len = nodes.size(); *nodes_index_in_topological_order = nodes.data(); return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret) { + const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); + *ret = graph_ptr->IsSubgraph(); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph) { + const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); + *parent_graph = reinterpret_cast(graph_ptr->ParentGraph()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + *parent_node = reinterpret_cast(graph_viewer->ParentNode()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + *path = reinterpret_cast(&graph_viewer->ModelPath()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph) { + const ::onnxruntime::GraphViewer* graph_viewer_ptr = reinterpret_cast(graph_viewer); + *graph = reinterpret_cast(&graph_viewer_ptr->GetGraph()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const auto& inputs = graph_viewer->GetInputsIncludingInitializers(); + *num_inputs = inputs.size(); + *input_names = new const char*[*num_inputs]; + for (size_t i = 0; i < 
*num_inputs; i++) (*input_names)[i] = inputs[i]->Name().c_str(); + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); *node = reinterpret_cast(graph_viewer->GetNode(node_index)); @@ -2563,7 +2604,11 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOutputSize, const OrtNode* node, _Out_ s ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name) { const ::onnxruntime::Node* n = reinterpret_cast(node); assert(i < n->OutputDefs().size()); - *ith_output_name = n->OutputDefs()[i]->Name().c_str(); + if (n->OutputDefs()[i]->Exists()){ + *ith_output_name = n->OutputDefs()[i]->Name().c_str(); + } else { + *ith_output_name = nullptr; + } return nullptr; } @@ -3052,6 +3097,12 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtGraph_IsConstantInitializer, &OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, + &OrtApis::OrtGraph_IsSubgraph, + &OrtApis::OrtGraph_GetParentGraph, + &OrtApis::OrtGraph_GetParenNode, + &OrtApis::OrtGraph_GetModelPath, + &OrtApis::OrtGraph_GetOrtGraph, + &OrtApis::OrtGraph_GetInputsIncludingInitializers, &OrtApis::OrtGraph_GetOrtNode, &OrtApis::OrtGraph_GetNodesConsumingInput, &OrtApis::OrtGraph_GetNodeProducingOutput, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c92b33d3e91ce..e188515ceeeeb 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -541,7 +541,19 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOpt ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); -ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); +ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); + +ORT_API_STATUS_IMPL(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); + +ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); + +ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); + +ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); + +ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); + +ORT_API_STATUS_IMPL(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names); ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index f2602201efe96..71b950abc9c4a 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -11,7 +11,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe std::vector cache; size_t nodes_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, &nodes_count, &nodes_index); + api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &nodes_count, 
&nodes_index); for (size_t i = 0; i < nodes_count; i++) { const OrtNode* node = nullptr; api->OrtGraph_GetOrtNode(graph, nodes_index[i], &node); diff --git a/samples/tensorRTEp/murmurhash3.cc b/samples/tensorRTEp/murmurhash3.cc new file mode 100644 index 0000000000000..0f69de579b2b6 --- /dev/null +++ b/samples/tensorRTEp/murmurhash3.cc @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "murmurhash3.h" + +// Original source: https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. + +/* Modifications Copyright (c) Microsoft. */ + +#include "core/framework/endian.h" + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x, y) _rotl(x, y) +#define ROTL64(x, y) _rotl64(x, y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#define FORCE_INLINE inline __attribute__((always_inline)) + +inline uint32_t rotl32(uint32_t x, int8_t r) { + return (x << r) | (x >> (32 - r)); +} + +inline uint64_t rotl64(uint64_t x, int8_t r) { + return (x << r) | (x >> (64 - r)); +} + +#define ROTL32(x, y) rotl32(x, y) +#define ROTL64(x, y) rotl64(x, y) + +#define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) +#include +//----------------------------------------------------------------------------- +// Block read - on little-endian machines this is a single load, +// while on big-endian or unknown machines the byte accesses should +// still get optimized into the most efficient instruction. +// +// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/ +// were manually applied to original murmurhash3 source code. 
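+// On little-endian machines getblock32/getblock64 reduce to a plain array
+// load; on big-endian machines the bytes are reassembled LSB-first so that
+// the same input hashes to the same value on every platform.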
+FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint32_t)c[0] | + (uint32_t)c[1] << 8 | + (uint32_t)c[2] << 16 | + (uint32_t)c[3] << 24; + } +} + +FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint64_t)c[0] | + (uint64_t)c[1] << 8 | + (uint64_t)c[2] << 16 | + (uint64_t)c[3] << 24 | + (uint64_t)c[4] << 32 | + (uint64_t)c[5] << 40 | + (uint64_t)c[6] << 48 | + (uint64_t)c[7] << 56; + } +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +//---------- + +FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) { + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; +} + +//----------------------------------------------------------------------------- + +namespace onnxruntime { +void MurmurHash3::x86_32(const void* key, int len, + uint32_t seed, void* out) { + const uint8_t* data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + constexpr uint32_t c1 = 0xcc9e2d51; + constexpr uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 4); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock32(blocks, i); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + //---------- + // tail + + const uint8_t* tail = (const uint8_t*)(data + static_cast(nblocks) * 4); + + uint32_t k1 = 0; + + switch (len & 3) { + case 3: + k1 ^= tail[2] << 16; + [[fallthrough]]; + case 2: + k1 ^= tail[1] << 8; + [[fallthrough]]; + case 1: + k1 ^= tail[0]; + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix32(h1); + + *(uint32_t*)out = h1; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) { + const uint8_t* data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + + constexpr uint32_t c1 = 0x239b961b; + constexpr uint32_t c2 = 0xab0e9789; + constexpr uint32_t c3 = 0x38b34ae5; + constexpr uint32_t c4 = 0xa1e38b93; + + //---------- + // body + + const uint32_t* blocks = (const uint32_t*)(data + static_cast(nblocks) * 16); + + for (int i = -nblocks; i; i++) { + uint32_t k1 = getblock32(blocks, i * 4 + 0); + uint32_t k2 = getblock32(blocks, i * 4 + 1); + uint32_t k3 = getblock32(blocks, i * 4 + 2); + uint32_t k4 = getblock32(blocks, i * 4 + 3); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + + h1 = ROTL32(h1, 19); + h1 += h2; + h1 = h1 * 5 + 0x561ccd1b; + + k2 *= c2; + k2 = ROTL32(k2, 16); + k2 *= c3; + h2 ^= k2; + + h2 = ROTL32(h2, 17); + h2 += h3; + h2 = h2 * 5 + 0x0bcaa747; + + k3 *= c3; + k3 = ROTL32(k3, 17); + k3 *= c4; + h3 ^= k3; + + h3 = ROTL32(h3, 15); + h3 += h4; + h3 = 
h3 * 5 + 0x96cd1c35; + + k4 *= c4; + k4 = ROTL32(k4, 18); + k4 *= c1; + h4 ^= k4; + + h4 = ROTL32(h4, 13); + h4 += h1; + h4 = h4 * 5 + 0x32ac3b17; + } + + //---------- + // tail + + const uint8_t* tail = (const uint8_t*)(data + static_cast(nblocks) * 16); + + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + + switch (len & 15) { + case 15: + k4 ^= tail[14] << 16; + [[fallthrough]]; + case 14: + k4 ^= tail[13] << 8; + [[fallthrough]]; + case 13: + k4 ^= tail[12] << 0; + k4 *= c4; + k4 = ROTL32(k4, 18); + k4 *= c1; + h4 ^= k4; + [[fallthrough]]; + case 12: + k3 ^= tail[11] << 24; + [[fallthrough]]; + case 11: + k3 ^= tail[10] << 16; + [[fallthrough]]; + case 10: + k3 ^= tail[9] << 8; + [[fallthrough]]; + case 9: + k3 ^= tail[8] << 0; + k3 *= c3; + k3 = ROTL32(k3, 17); + k3 *= c4; + h3 ^= k3; + [[fallthrough]]; + case 8: + k2 ^= tail[7] << 24; + [[fallthrough]]; + case 7: + k2 ^= tail[6] << 16; + [[fallthrough]]; + case 6: + k2 ^= tail[5] << 8; + [[fallthrough]]; + case 5: + k2 ^= tail[4] << 0; + k2 *= c2; + k2 = ROTL32(k2, 16); + k2 *= c3; + h2 ^= k2; + [[fallthrough]]; + case 4: + k1 ^= tail[3] << 24; + [[fallthrough]]; + case 3: + k1 ^= tail[2] << 16; + [[fallthrough]]; + case 2: + k1 ^= tail[1] << 8; + [[fallthrough]]; + case 1: + k1 ^= tail[0] << 0; + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + h2 ^= len; + h3 ^= len; + h4 ^= len; + + h1 += h2; + h1 += h3; + h1 += h4; + h2 += h1; + h3 += h1; + h4 += h1; + + h1 = fmix32(h1); + h2 = fmix32(h2); + h3 = fmix32(h3); + h4 = fmix32(h4); + + h1 += h2; + h1 += h3; + h1 += h4; + h2 += h1; + h3 += h1; + h4 += h1; + + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +} // namespace onnxruntime diff --git a/samples/tensorRTEp/murmurhash3.h b/samples/tensorRTEp/murmurhash3.h new file mode 100644 index 0000000000000..ab86a3e591adf --- /dev/null +++ b/samples/tensorRTEp/murmurhash3.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +struct MurmurHash3 { + // generate 32-bit hash from input and write to 'out' + static void x86_32(const void* key, int len, uint32_t seed, void* out); + + // generate 128-bit hash from input and write to 'out'. 
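+  // Minimal usage sketch (illustrative only; 'data' is a hypothetical buffer):
+  //   uint32_t hash[4] = {0, 0, 0, 0};
+  //   MurmurHash3::x86_128(data.data(), static_cast<int>(data.size()), hash[0], &hash);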
+ static void x86_128(const void* key, int len, uint32_t seed, void* out); +}; +} // namespace onnxruntime diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index 9b8e16b0eb549..ee3ee3cb992d6 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -1,10 +1,109 @@ #include #include #include +#include +#include "murmurhash3.h" #include "onnx_ctx_model_helper.h" #include "tensorrt_execution_provider.h" namespace onnxruntime { + +HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { + HashValue model_hash = 0; + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + const OrtGraph* cur_graph = nullptr; + api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); + bool is_subgraph = false; + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + while (is_subgraph) { + const OrtGraph* parent_graph = nullptr; + api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); + cur_graph = parent_graph; + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + } + + const OrtGraph* main_graph = cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + const std::filesystem::path* model_path = nullptr; + api->OrtGraph_GetModelPath(graph_viewer, (const void**)&model_path); + + // Use the model's file name instead of the entire path to avoid cache regeneration if path changes + if (model_path->has_filename()) { + std::string model_name = model_path->filename(); + + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.size(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + // const std::vector& input_names = nullptr; + const char** input_names = nullptr; + size_t input_count = 0; + api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_count, &input_names); + for (size_t i = 0; i < input_count; ++i) { + hash_str(input_names[i]); + } + + // hashing output of each node + const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph_viewer); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + size_t nodes_count = 0; + const size_t* nodes_index = nullptr; + api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_count, &nodes_index); + for (const auto& index : nodes_vector) { + const OrtNode* node = nullptr; + api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); + size_t output_size = 0; + api->OrtNode_GetOutputSize(node, &output_size); + for (size_t i = 0; i < output_size; ++i) { + const char* output_name = nullptr; + api->OrtNode_GetIthOutputName(node, i, &output_name); + if (output_name != nullptr) { + hash_str(output_name); + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + +#ifdef ORT_VERSION + hash_str(ORT_VERSION); +#endif + +#ifdef CUDA_VERSION + hash_str(std::to_string(CUDA_VERSION)); +#endif + +#if 
defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) + std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); + hash_str(TRT_VERSION); +#endif + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + // return the current unique id + return model_hash; +} + bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); int maxNodeIndex = 0; @@ -12,9 +111,12 @@ bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { for (int i = 0; i < maxNodeIndex; ++i) { const OrtNode* node = nullptr; api->OrtGraph_GetOrtNode(graph_viewer, i, &node); + if (node == nullptr) { + continue; + } const char* opType = nullptr; api->OrtNode_GetOpType(node, &opType); - if (node != nullptr && strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { + if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { return true; } } diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h index 3fcb809b4bded..cf4db883616bd 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.h +++ b/samples/tensorRTEp/onnx_ctx_model_helper.h @@ -2,13 +2,18 @@ // Licensed under the MIT License. #pragma once +#include #include +#include #include #include #include "core/session/onnxruntime_c_api.h" #include "nv_includes.h" namespace onnxruntime { + +using HashValue = uint64_t; + static const std::string EPCONTEXT_OP = "EPContext"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; @@ -20,6 +25,7 @@ static const std::string EPCONTEXT_WARNING = make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ for the best model loading time"; +HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer); bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); std::string GetCtxModelPath(const std::string& ep_context_file_path, diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 2dd1f7741b864..8a64b5abe820a 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -5,7 +5,6 @@ #include "core/session/onnxruntime_cxx_api.h" // TODO(leca): we should be able to use cxx APIs which are built upon C API #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" -#include "onnx_ctx_model_helper.h" void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } @@ -15,6 +14,25 @@ template using IAllocatorUniquePtr = std::unique_ptr>; const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +// Check if cycle exists in the graph after partitioning +bool FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { + if (!visited[i]) { + visited[i] = true; + st[i] = true; + for (auto iter = adjacency_map[i].begin(); iter != adjacency_map[i].end(); ++iter) { + if (!visited[*iter] && FindCycleHelper(*iter, adjacency_map, visited, st, cycles)) { + cycles.push_back(*iter); + return true; + } else if (st[*iter]) { + cycles.push_back(*iter); + return true; + } + } + } + st[i] = false; + return false; +} + bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { size_t alloc_size = size; if (alignment == 0) { @@ -889,8 +907,335 @@ OrtStatusPtr 
BindKernelOutput(Ort::KernelContext& ctx, return nullptr; } +// Detect and remove cycles from supported node list +bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles) const { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + size_t node_count = 0; + const size_t* nodes_index = nullptr; + api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_count, &nodes_index); + bool trt_cycle = true, cycle_detected = false; + while (trt_cycle) { + trt_cycle = false; + std::unordered_map node_to_index_map; + std::unordered_map index_to_node_map; + std::unordered_map> input_to_nodes_map, node_to_outputs_map; + std::unordered_set non_trt_node_index; + for (size_t i = 0; i < node_count; ++i) { + non_trt_node_index.insert(nodes_index[i]); + } + size_t id = 0; + int subgraph_index = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + // Construct subgraph from node list + // std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); + OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); + + // Create node to inputs/outputs/index maps + const std::string node_name = subgraph->meta_def->name; + if (node_to_index_map.find(node_name) == node_to_index_map.end()) { + index_to_node_map[id] = node_name; + node_to_index_map[node_name] = id++; + } + + if (subgraph->meta_def != nullptr) { + for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { + input_to_nodes_map[std::string(subgraph->meta_def->inputs[j])].insert(node_name); + } + for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { + node_to_outputs_map[node_name].insert(std::string(subgraph->meta_def->outputs[j])); + } + } + + // Remove TensorRT nodes from node index list + for (const auto& index : group.first) { + non_trt_node_index.erase(nodes_index[index]); + } + subgraph_index++; + } + } + + // Add non TensorRT nodes to the maps + for (const auto& index : non_trt_node_index) { + const OrtNode* node = nullptr; + api->OrtGraph_GetOrtNode(graph, index, &node); + const char* node_name_char = nullptr; + api->OrtNode_GetName(node, &node_name_char); + const std::string node_name(node_name_char); + if (node_to_index_map.find(node_name) == node_to_index_map.end()) { + index_to_node_map[id] = node_name; + node_to_index_map[node_name] = id++; + } + + size_t input_count = 0; + api->OrtNode_GetInputSize(node, &input_count); + for (size_t i = 0; i < input_count; ++i) { + const char* input_name_char = nullptr; + api->OrtNode_GetIthInputName(node, i, &input_name_char); + input_to_nodes_map[std::string(input_name_char)].insert(node_name); + } + + size_t implicit_input_count = 0; + api->OrtNode_GetImplicitInputSize(node, &implicit_input_count); + for (size_t i = 0; i < implicit_input_count; ++i) { + const char* input_name_char = nullptr; + api->OrtNode_GetIthImplicitInputName(node, i, &input_name_char); + input_to_nodes_map[std::string(input_name_char)].insert(node_name); + } + + size_t output_count = 0; + api->OrtNode_GetOutputSize(node, &output_count); + for (size_t i = 0; i < output_count; ++i) { + const char* output_name_char = nullptr; + api->OrtNode_GetIthOutputName(node, i, &output_name_char); + node_to_outputs_map[node_name].insert(std::string(output_name_char)); + } + } + + // Create adjacency list + size_t graph_size = node_to_index_map.size(); + std::list* adjacency_map = new std::list[graph_size]; + for (const auto& node : 
node_to_outputs_map) {
+      for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) {
+        const auto& loc = input_to_nodes_map.find(*iter);
+        if (loc != input_to_nodes_map.end()) {
+          size_t parent_node_index = node_to_index_map.find(node.first)->second;
+          for (auto child_node : loc->second) {
+            size_t child_node_index = node_to_index_map.find(child_node)->second;
+            adjacency_map[parent_node_index].push_back(child_node_index);
+          }
+        }
+      }
+    }
+
+    // Check for a cycle in the graph
+    bool* visited = new bool[graph_size];
+    bool* st = new bool[graph_size];
+    for (size_t i = 0; i < graph_size; ++i) {
+      visited[i] = false;
+      st[i] = false;
+    }
+
+    std::vector<size_t> cycles;
+    bool has_cycle = false;
+    for (size_t i = 0; i < graph_size; ++i) {
+      if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) {
+        has_cycle = true;
+        cycle_detected = true;
+        break;
+      }
+    }
+
+    // Remove the TensorRT subgraph from the supported node list if it is part of the cycle
+    if (has_cycle && remove_cycles) {
+      for (size_t i = 0; i < cycles.size(); ++i) {
+        auto loc = index_to_node_map.find(cycles[i]);
+        if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) {
+          supported_nodes_vector.erase(supported_nodes_vector.begin() + cycles[i]);
+          trt_cycle = true;
+          break;
+        }
+      }
+    }
+
+    delete[] adjacency_map;
+    delete[] visited;
+    delete[] st;
+  }
+  return cycle_detected;
+}
+
+// Check whether the graph is a subgraph of a control flow op
+bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const {
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  const OrtGraph* cur_graph = nullptr;
+  api->OrtGraph_GetOrtGraph(graph, &cur_graph);
+  bool is_subgraph = false;
+  api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph);
+  if (is_subgraph) {
+    const OrtNode* node = nullptr;
+    api->OrtGraph_GetParenNode(graph, &node);
+    const char* node_op_type;
+    api->OrtNode_GetOpType(node, &node_op_type);
+    if (control_flow_op_set_.find(std::string(node_op_type)) != control_flow_op_set_.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Check whether all the nodes of the graph are assigned to a specific EP
+bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const {
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  std::vector<size_t> nodes_vector(api->OrtGraph_NumberOfNodes(graph));
+  std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
+  size_t node_count = 0;
+  const size_t* nodes_index = nullptr;
+  api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_count, &nodes_index);
+  for (const auto& index : nodes_vector) {
+    const OrtNode* node = nullptr;
+    api->OrtGraph_GetOrtNode(graph, nodes_index[index], &node);
+    const char* node_ep_type;
+    api->OrtNode_GetExecutionProviderType(node, &node_ep_type);
+    // Any node not assigned to provider_type means the graph is not fully covered.
+    if (strcmp(node_ep_type, provider_type.c_str()) != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Check whether all the nodes of the subgraph are supported
+bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const {
+  int number_of_trt_nodes = 0;
+  for (const auto& group : supported_nodes_vector) {
+    if (!group.first.empty()) {
+      number_of_trt_nodes += static_cast<int>(group.first.size());
+    }
+  }
+
+  return number_of_trt_nodes == number_of_ort_nodes;
+}
+
 TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() {
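+  // The members assigned below implement the OrtExecutionProvider C
+  // interface as captureless lambdas; each one receives the base `this_`
+  // pointer back from ORT and downcasts it to TensorrtExecutionProvider to
+  // reach provider state such as control_flow_op_set_ and min_subgraph_size_.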
+  OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    const TensorrtExecutionProvider* p = static_cast<const TensorrtExecutionProvider*>(this_);
+    // Get ModelPath
+    const std::filesystem::path* model_path = nullptr;
+    api->OrtGraph_GetModelPath(graph, (const void**)&model_path);
+    const auto& path_string = model_path->string();
+#ifdef _WIN32
+    strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1);
+#else
+    std::strncpy(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1);
+#endif
+    p->model_path_[sizeof(p->model_path_) - 1] = '\0';
+
+    if (api->OrtGraph_NumberOfNodes(graph) == 1 && GraphHasCtxNode(graph)) {
+      SubGraph_t supported_node_vector = {{0}, true};
+      // std::unique_ptr<OrtIndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0);
+      // result.push_back(ComputeCapability::Create(std::move(sub_graph)));
+      // return result;
+    }
+
+    // Generate unique kernel name for TRT graph
+    HashValue model_hash = TRTGenerateId(graph);
+
+    // Get supported node list from TensorRT parser
+    const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph);
+    std::vector<size_t> nodes_vector(number_of_ort_nodes);
+    std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
+
+    std::vector<size_t> filtered_nodes_vector;
+    size_t nodes_count = 0;
+    const size_t* nodes_index = nullptr;
+    api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &nodes_index);
+    for (const auto& index : nodes_vector) {
+      const OrtNode* node = nullptr;
+      api->OrtGraph_GetOrtNode(graph, nodes_index[index], &node);
+      const char* node_op_type;
+      api->OrtNode_GetOpType(node, &node_op_type);
+
+      /* If the current node is a control flow op, we take a different approach based on the following four cases:
+       *
+       * (1) the control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT.
+       * (2) the control flow op is supported by TRT, but not all of its subgraphs are supported by TRT. Don't assign this node to TRT.
+       * (3) the control flow op is not supported by TRT, but its subgraphs are all supported by TRT. Don't assign this node to TRT.
+       * (4) the control flow op is not supported by TRT, and not all of its subgraphs are supported by TRT. Don't assign this node to TRT.
+       *
+       * For cases 2, 3 and 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will still be fused and assigned to TRT EP.
+       */
+      if (p->control_flow_op_set_.find(std::string(node_op_type)) != p->control_flow_op_set_.end()) {
+        size_t subgraph_count = 0;
+        const OrtGraphViewer** subgraphs = nullptr;
+        api->OrtNode_GetSubgraphs(node, &subgraph_count, &subgraphs);
+        if (subgraph_count != 0) {
+          bool all_subgraphs_are_supported = true;
+          for (size_t i = 0; i < subgraph_count; i++) {
+            // TRT EP should consider an empty subgraph fully supported by TRT.
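+            // (e.g. an If branch that only forwards outer-scope values and
+            // therefore contains no nodes)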
+ if (api->OrtGraph_NumberOfNodes(subgraphs[i]) == 0) { + continue; + } + if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], kTensorrtExecutionProvider)) { + all_subgraphs_are_supported = false; + break; + } + } + if (!all_subgraphs_are_supported) { + // if not all its subgraphs are supported, we need to exclude this control flow op + continue; + } + } + } + filtered_nodes_vector.push_back(index); + } + + SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; + bool early_termination = false; + // supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + if (early_termination) { + supported_nodes_vector.clear(); + } + + // Remove subgraphs if its size is less than the predefined minimal size + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + const size_t subgraph_size = it->first.size(); + if (subgraph_size < p->min_subgraph_size_) { + supported_nodes_vector.erase(it--); + } + } + + // Detect and remove cycles from supported node list + p->DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash); + + // Consolidate supported node list + if (supported_nodes_vector.size() > 1) { + nodes_vector.clear(); + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + nodes_vector.insert(nodes_vector.end(), group.first.begin(), group.first.end()); + } + } + SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; + if (p->DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph"; + supported_nodes_vector = consolidated_supported_nodes_vector; + } + } + + // Handle the case where the graph is subgraph of control flow op. + // The purpose is to make control flow op as well as its subgraphs run on TRT. + // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. + if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { + } + + int number_of_trt_nodes = 0, subgraph_index = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + // std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); + // result.push_back(ComputeCapability::Create(std::move(sub_graph))); + number_of_trt_nodes += static_cast(group.first.size()); + subgraph_index++; + } + } + + const size_t number_of_subgraphs = supported_nodes_vector.size(); + if (number_of_trt_nodes == 0) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider"; + } else if (number_of_trt_nodes == number_of_ort_nodes) { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; + } + + // The context map is only used during EP compile time, release it to save memory space. 
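+    // (Mirrors the in-tree TensorRT EP, which drops this per-subgraph
+    // bookkeeping once GetCapability returns.)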
+ // subgraph_context_map_.clear(); + // return result; + }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { @@ -913,7 +1258,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const for (size_t i = 0; i < output_size; i++) { const char* ith_output_name = nullptr; api->OrtNode_GetIthOutputName(node[j], i, &ith_output_name); - output_map[ith_output_name] = i; + if (ith_output_name != nullptr) { + output_map[ith_output_name] = i; + } } OrtStatusPtr ret = nullptr; @@ -2542,6 +2889,17 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi return nullptr; } +SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, + const OrtGraphViewer& graph, bool* early_termination) const { + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + return nodes_list_output; +} + } // namespace onnxruntime #ifdef __cplusplus diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 805177a7bcf64..8d76381c4cd08 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -4,7 +4,9 @@ #include #include "core/session/onnxruntime_c_api.h" #include "core/framework/provider_options.h" +#include "core/graph/constants.h" #include "nv_includes.h" +#include "onnx_ctx_model_helper.h" #ifdef _WIN32 #define EXPORT_API __declspec(dllexport) @@ -178,6 +180,20 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, + const OrtGraphViewer& graph, bool* early_termination) const; + + bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles = true) const; + + /**Check the graph is the subgraph of control flow op*/ + bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; + + /**Check whether all the nodes of the graph are assigned to specific ep*/ + bool AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const; + + /**Check whether all the nodes of subgraph are supported*/ + bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; + static const OrtApi* api_; std::string trt_node_name_with_precision_; std::unordered_map dynamic_range_map_; From 4aea94b9eb265c76a8c36e49297b54c8ee190d43 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 28 Aug 2024 00:42:02 +0000 Subject: [PATCH 24/81] GetSupportedList --- .../core/session/onnxruntime_c_api.h | 1 - samples/c_test/CMakeLists.txt | 3 +- samples/c_test/test.cpp | 20 + samples/tensorRTEp/CMakeLists.txt | 4 +- samples/tensorRTEp/onnx_ctx_model_helper.cc | 98 ----- samples/tensorRTEp/onnx_ctx_model_helper.h | 8 +- .../tensorRTEp/tensorrt_execution_provider.cc | 362 +++++++++++++++++- .../tensorRTEp/tensorrt_execution_provider.h | 9 +- .../tensorrt_execution_provider_utils.h | 99 +++++ 9 files changed, 484 insertions(+), 120 
deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 36d5c01d6237b..a3025690fb666 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -307,7 +307,6 @@ ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(ExecutionProvider); ORT_RUNTIME_CLASS(ExecutionProviderFactory); ORT_RUNTIME_CLASS(Node); -ORT_RUNTIME_CLASS(Model); ORT_RUNTIME_CLASS(Graph); ORT_RUNTIME_CLASS(GraphViewer); ORT_RUNTIME_CLASS(KernelRegistry); diff --git a/samples/c_test/CMakeLists.txt b/samples/c_test/CMakeLists.txt index c8bf77b99c1a0..9a460ecb72560 100644 --- a/samples/c_test/CMakeLists.txt +++ b/samples/c_test/CMakeLists.txt @@ -7,4 +7,5 @@ project(TestOutTreeEp) add_executable(TestOutTreeEp test.cpp) target_include_directories(TestOutTreeEp PUBLIC "../../include/onnxruntime") -target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so") +#target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so") +target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so") diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 6826a3a5bc10e..5b03221ef9d41 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -22,6 +22,25 @@ void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); + + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options)); + THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options)); + + g_ort->ReleaseCUDAProviderOptions(cuda_options); +} + +void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) { + OrtTensorRTProviderOptionsV2* tensorrt_options = nullptr; + THROW_ON_ERROR(g_ort->CreateTensorRTProviderOptions(&tensorrt_options)); + THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_TensorRT_V2(so, tensorrt_options)); + + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options)); + THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options)); + + g_ort->ReleaseCUDAProviderOptions(cuda_options); + g_ort->ReleaseTensorRTProviderOptions(tensorrt_options); } int main() { @@ -35,6 +54,7 @@ int main() { //TestCompileBasedEp(g_ort, p_env, so); //TestKernelBasedEp(g_ort, p_env, so); TestTensorRTEp(g_ort, p_env, so); + //TestOriginalTensorRTEp(g_ort, so); OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index fb6770405537f..21ece12846d42 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -17,8 +17,8 @@ add_library(TensorRTEp SHARED ${tensorrt_src}) target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "/usr/local/cuda/include" "/home/leca/TensorRT-10.0.1.6/include" - 
"../../build/Linux/Debug/_deps/flatbuffers-src/include") -# "../../build/Linux/Debug/_deps/gsl-src/include" + "../../build/Linux/Debug/_deps/flatbuffers-src/include" + "../../build/Linux/Debug/_deps/gsl-src/include") # "../../build/Linux/Debug/_deps/onnx-src" # "../../build/Linux/Debug/_deps/onnx-build" # "../../build/Linux/Debug/_deps/protobuf-src/src") diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index ee3ee3cb992d6..afb88345675cb 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -1,109 +1,11 @@ #include #include #include -#include -#include "murmurhash3.h" #include "onnx_ctx_model_helper.h" #include "tensorrt_execution_provider.h" namespace onnxruntime { -HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { - HashValue model_hash = 0; - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - const OrtGraph* cur_graph = nullptr; - api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); - bool is_subgraph = false; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); - while (is_subgraph) { - const OrtGraph* parent_graph = nullptr; - api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); - cur_graph = parent_graph; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); - } - - const OrtGraph* main_graph = cur_graph; - uint32_t hash[4] = {0, 0, 0, 0}; - - auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); - }; - - const std::filesystem::path* model_path = nullptr; - api->OrtGraph_GetModelPath(graph_viewer, (const void**)&model_path); - - // Use the model's file name instead of the entire path to avoid cache regeneration if path changes - if (model_path->has_filename()) { - std::string model_name = model_path->filename(); - - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; - // Ensure enough characters are hashed in case model names are too short - const size_t model_name_length = model_name.size(); - constexpr size_t hash_string_length = 500; - std::string repeat_model_name = model_name; - for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { - repeat_model_name += model_name; - } - hash_str(repeat_model_name); - } else { - // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; - } - - // fingerprint current graph by hashing graph inputs - // const std::vector& input_names = nullptr; - const char** input_names = nullptr; - size_t input_count = 0; - api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_count, &input_names); - for (size_t i = 0; i < input_count; ++i) { - hash_str(input_names[i]); - } - - // hashing output of each node - const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph_viewer); - std::vector nodes_vector(number_of_ort_nodes); - std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - size_t nodes_count = 0; - const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_count, &nodes_index); - for (const auto& index : nodes_vector) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); - size_t output_size = 0; - api->OrtNode_GetOutputSize(node, &output_size); - for (size_t i = 0; i < output_size; ++i) { - const char* output_name = nullptr; - api->OrtNode_GetIthOutputName(node, i, &output_name); - if (output_name != nullptr) { - hash_str(output_name); - } - } - } - 
-#ifdef __linux__ - hash_str("LINUX"); -#elif defined(_WIN32) - hash_str("WINDOWS"); -#endif - -#ifdef ORT_VERSION - hash_str(ORT_VERSION); -#endif - -#ifdef CUDA_VERSION - hash_str(std::to_string(CUDA_VERSION)); -#endif - -#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) - std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); - hash_str(TRT_VERSION); -#endif - - model_hash = hash[0] | (uint64_t(hash[1]) << 32); - - // return the current unique id - return model_hash; -} - bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); int maxNodeIndex = 0; diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h index cf4db883616bd..c90574ebd4bae 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.h +++ b/samples/tensorRTEp/onnx_ctx_model_helper.h @@ -2,18 +2,15 @@ // Licensed under the MIT License. #pragma once -#include -#include -#include + #include #include +#include #include "core/session/onnxruntime_c_api.h" #include "nv_includes.h" namespace onnxruntime { -using HashValue = uint64_t; - static const std::string EPCONTEXT_OP = "EPContext"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; @@ -25,7 +22,6 @@ static const std::string EPCONTEXT_WARNING = make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ for the best model loading time"; -HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer); bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); std::string GetCtxModelPath(const std::string& ep_context_file_path, diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 8a64b5abe820a..df85934af3fab 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1,10 +1,12 @@ #include #include +#include #include #include #include "core/session/onnxruntime_cxx_api.h" // TODO(leca): we should be able to use cxx APIs which are built upon C API #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" +#include "onnx_ctx_model_helper.h" void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } @@ -1097,6 +1099,157 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t su return number_of_trt_nodes == number_of_ort_nodes; } +std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const { + size_t nodes_count = 0; + const size_t* node_index = nullptr; + api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &node_index); + std::unordered_set node_set; + node_set.reserve(graph_nodes_index.first.size()); + for (const auto& index : graph_nodes_index.first) { + node_set.insert(node_index[index]); + } + + // Get parent graph output names + std::unordered_set graph_output_names; + size_t graph_output_size = api_->OrtGraph_GetOutputSize(graph); + for (size_t i = 0; i < graph_output_size; i++) { + graph_output_names.insert(api_->OrtGraph_GetIthOutputName(graph, i)); + } + + // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->meta_def = new OrtMetaDef(); +// std::unordered_map fused_inputs, 
fused_outputs, fused_outputs_to_add, graph_outputs_to_add; +// std::unordered_set erased; +// int input_order = 0; +// int output_order = 0; +// +// std::vector initializers; +// for (const auto& index : graph_nodes_index.first) { +// sub_graph->Nodes().push_back(node_index[index]); +// const auto& node = graph.GetNode(node_index[index]); +// for (const auto& input : node->InputDefs()) { +// if (graph.IsConstantInitializer(input->Name(), true)) { +// initializers.push_back(input->Name()); +// continue; +// } +// const auto& it = fused_outputs.find(input); +// if (it != fused_outputs.end()) { +// fused_outputs.erase(it); +// erased.insert(input); +// } else if (erased.find(input) == erased.end()) { +// // Only when input is neither in output list nor erased list, add the input to input list +// fused_inputs[input] = input_order++; +// } +// } +// +// for (const auto& input : node->ImplicitInputDefs()) { +// if (graph.IsConstantInitializer(input->Name(), true)) { +// initializers.push_back(input->Name()); +// continue; +// } +// const auto& it = fused_outputs.find(input); +// if (it != fused_outputs.end()) { +// fused_outputs.erase(it); +// erased.insert(input); +// } else if (erased.find(input) == erased.end()) { +// // Only when input is neither in output list nor erased list, add the input to input list +// fused_inputs[input] = input_order++; +// } +// } +// +// // For output searching, there are two special cases, +// // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, +// // if the output is connected to nodes that don't belong to the subgraph, the output need to be added +// // to the output list +// // The other one is, if subgraph's node output is parent graph's output. the node output should +// // be also added to the subgraph's output list +// if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { +// for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { +// const auto& node_idx = it->GetNode().Index(); +// const onnxruntime::NodeArg* output; +// // The dst_arg_index from GetDstArgIndex() could be the index for explicit/implicit input defs of the node. +// // We need to get the correct input index accordingly. 
(See Graph::BuildConnections() in graph.cc for more details) +// if (it->GetDstArgIndex() < static_cast(it->GetNode().InputDefs().size())) { +// output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; +// } else { +// output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; +// } +// if (node_set.find(node_idx) != node_set.end()) { +// const auto& iter = fused_inputs.find(output); +// if (iter != fused_inputs.end()) { +// fused_inputs.erase(iter); +// erased.insert(output); +// } else if (erased.find(output) == erased.end()) { +// if (graph_output_names.find(output->Name()) != graph_output_names.end()) { +// graph_outputs_to_add[output] = output_order; +// } +// fused_outputs[output] = output_order++; +// } +// } else { +// fused_outputs_to_add[output] = output_order++; +// } +// } +// } else { +// for (const auto& output : node->OutputDefs()) { +// const auto& it = fused_inputs.find(output); +// if (it != fused_inputs.end()) { +// fused_inputs.erase(it); +// erased.insert(output); +// } +// // Only when output is neither in input list nor erased list, add the output to output list +// else if (erased.find(output) == erased.end()) { +// if (graph_output_names.find(output->Name()) != graph_output_names.end()) { +// graph_outputs_to_add[output] = output_order; +// } +// fused_outputs[output] = output_order++; +// } +// } +// } +// } +// +// fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); +// fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); +// +// // Sort inputs and outputs by the order they were added +// std::multimap inputs, outputs; +// for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { +// inputs.insert(std::pair(it->second, it->first)); +// } +// +// for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { +// outputs.insert(std::pair(it->second, it->first)); +// } +// +// // Generate unique kernel name for TRT subgraph +// std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); +// auto meta_def = IndexedSubGraph_MetaDef::Create(); +// const std::string graph_type = graph.IsSubgraph() ? 
"subgraph" : "graph"; +// meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + subgraph_id; +// LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT subgraph MetaDef name " + meta_def->name(); +// +// // Assign inputs and outputs to subgraph's meta_def +// for (const auto& input : inputs) { +// if (input.second->Exists()) { +// meta_def->inputs().push_back(input.second->Name()); +// } +// } +// +// for (const auto& initializer : initializers) { +// meta_def->constant_initializers().push_back(initializer); +// } +// +// for (const auto& output : outputs) { +// if (output.second->Exists()) { +// meta_def->outputs().push_back(output.second->Name()); +// } +// } + + sub_graph->meta_def->domain = "com.microsoft"; + sub_graph->meta_def->since_version = 1; + + return sub_graph; +} TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() { OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { @@ -1158,7 +1311,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const if (api->OrtGraph_NumberOfNodes(subgraphs[i]) == 0) { continue; } - if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], kTensorrtExecutionProvider)) { + if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "tensorrtEp")) { all_subgraphs_are_supported = false; break; } @@ -1174,7 +1327,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; bool early_termination = false; - // supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + supported_nodes_vector = p->GetSupportedList(parser_nodes_vector, 0, p->max_partition_iterations_, graph, &early_termination); if (early_termination) { supported_nodes_vector.clear(); } @@ -1231,11 +1384,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } else { // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } - - // The context map is only used during EP compile time, release it to save memory space. 
- // subgraph_context_map_.clear(); - // return result; - }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { @@ -2890,13 +3038,211 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi } SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, - const OrtGraphViewer& graph, bool* early_termination) const { + const OrtGraphViewer* graph, bool* early_termination) const { // Return if iterations are exceeding predefined number SubGraphCollection_t nodes_list_output; if (iterations > max_iterations) { *early_termination = true; return nodes_list_output; } + + std::unordered_set graph_output_names; + size_t output_size = api_->OrtGraph_GetOutputSize(graph); + for (size_t i = 0; i < output_size; i++) { + graph_output_names.insert(api_->OrtGraph_GetIthOutputName(graph, i)); + } + + iterations++; + size_t nodes_count = 0; + const size_t* node_index = nullptr; + api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &node_index); + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { +// auto model_build = graph.CreateModel(*GetLogger()); +// auto& graph_build = model_build->MainGraph(); +// bool has_control_flow_op = false; +// +// // Add node and node args +// // If node output is also parent graph output, the output will be added to the +// // subgraph's output list +// std::vector subgraph_output_names; +// for (const auto& index : group.first) { +// const auto& node = graph.GetNode(node_index[index]); +// std::vector inputs, outputs; +// for (auto input : node->InputDefs()) { +// auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); +// inputs.push_back(&n_input); +// const ONNX_NAMESPACE::TensorProto* initializer = nullptr; +// if (graph.GetInitializedTensor(input->Name(), initializer)) { +// const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; +// if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { +// graph_build.AddInitializedTensor(*(initializer)); +// } +// } +// } +// +// for (auto input : node->ImplicitInputDefs()) { +// const ONNX_NAMESPACE::TensorProto* initializer = nullptr; +// if (graph.GetInitializedTensor(input->Name(), initializer)) { +// const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; +// if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { +// graph_build.AddInitializedTensor(*(initializer)); +// } +// } +// } +// for (auto output : node->OutputDefs()) { +// auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); +// outputs.push_back(&n_output); +// const auto name = output->Name(); +// if (graph_output_names.find(name) != graph_output_names.end()) { +// subgraph_output_names.push_back(name); +// } +// } +// +// if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { +// has_control_flow_op = true; +// } +// +// // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization. +// // Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto. 
+// if (node->GetAttributes().size() > 0) { +// auto node_proto = ONNX_NAMESPACE::NodeProto::Create(); +// // we need to update any GraphProto attributes for subgraphs so that any changes made by things +// // such as the optimizers are captured. otherwise we can end up saving an invalid graph. +// node->ToProto(*node_proto, /* update_subgraphs */ true); +// const int num_attributes = node_proto->attribute_size(); +// auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); +// node_attributes->reserve(num_attributes); +// +// for (int i = 0; i < num_attributes; ++i) { +// auto& attr = node_proto->attribute(i); +// node_attributes->emplace(attr.name(), attr); +// } +// +// // The GraphProto attributes are the updated ones. +// graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, node_attributes.get(), node->Domain()); +// } else { +// // The GraphProto attributes are the original ones. +// graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); +// } +// } +// +// // Only if the newly built graph has control flow op as well as it has parent node, +// // it needs to handle outer scope values before calling graph.Resolve(). +// if (has_control_flow_op && graph.ParentNode()) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); +// BuildSubGraphContext(graph_build); +// SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); +// SetAllGraphInputs(graph_build); +// } +// +// ORT_ENFORCE(graph_build.Resolve().IsOK()); +// +// // Add parent graph output to the subgraph +// int i = 0; +// std::vector subgraph_outputs; +// subgraph_outputs.resize(subgraph_output_names.size()); +// for (auto& name : subgraph_output_names) { +// auto output_arg = graph.GetNodeArg(name); +// auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); +// subgraph_outputs[i] = &subgraph_output_arg; +// ++i; +// } +// auto& graph_build_outputs = graph_build.GetOutputs(); +// subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); +// graph_build.SetOutputs(graph_build_outputs); +// ORT_ENFORCE(graph_build.Resolve().IsOK()); +// +// // Check if input tensors have shapes +// if (iterations > 1) { +// auto graph_inputs = graph_build.GetInputs(); +// for (auto input_arg : graph_inputs) { +// bool has_dim_value_or_param = true; +// auto input_shape = input_arg->Shape(); +// if (input_shape != nullptr) { +// auto dim_size = input_shape->dim_size(); +// for (int i = 0; i < dim_size; ++i) { +// auto& dim = input_shape->dim(i); +// if (!dim.has_dim_value() && !dim.has_dim_param()) { +// has_dim_value_or_param = false; +// break; +// } +// } +// } +// +// if (input_shape == nullptr || !has_dim_value_or_param) { +// ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, +// "TensorRT input: " + input_arg->Name() + " has no shape specified. " + +// "Please run shape inference on the onnx model first. Details can be found in " + +// "https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs")); +// } +// } +// } +// +// // Serialize modelproto to string +// auto graph_viewer = graph_build.CreateGraphViewer(); +// auto model = graph_viewer->CreateModel(*GetLogger()); +// auto model_proto = model->ToProto(); +// +// // ORT's default topological sort is using reversed DFS. 
+// // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. +// // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating +// // the model proto that has different node ordering compared to original onnx model. +// graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); +// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); +// +// std::string string_buf; +// model_proto->SerializeToString(string_buf); +// +// if (dump_subgraphs_) { +// // Dump TensorRT subgraph for debugging +// std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", std::ios::out | std::ios::trunc | std::ios::binary); +// model_proto->SerializeToOstream(dump); +// } + + void* buf_data = nullptr; + size_t buf_size = api_->OrtGraph_SerializeToArray(graph, &buf_data); + std::string string_buf(reinterpret_cast(buf_data), buf_size); + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#endif + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + + auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + SubGraphCollection_t next_nodes_list; + size_t subgraph_node_count = 0; + const size_t* subgraph_node_index = nullptr; + api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &subgraph_node_count, &subgraph_node_index); + next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + } + } + } return nodes_list_output; } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 8d76381c4cd08..b9b9e01e41b3c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -4,9 +4,7 @@ #include #include "core/session/onnxruntime_c_api.h" #include "core/framework/provider_options.h" -#include "core/graph/constants.h" #include "nv_includes.h" -#include "onnx_ctx_model_helper.h" #ifdef _WIN32 #define EXPORT_API __declspec(dllexport) @@ -15,7 +13,7 @@ #endif namespace onnxruntime { - +using HashValue = uint64_t; using AllocateFunc = void* (*)(void*, size_t, size_t); using DestroyFunc = void (*)(void*, void*); @@ -180,8 +178,11 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); + + std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, + const 
OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const; SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations, - const OrtGraphViewer& graph, bool* early_termination) const; + const OrtGraphViewer* graph, bool* early_termination) const; bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles = true) const; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index 2d652488b75d4..e9a9ff0cd46c1 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -1,10 +1,13 @@ #pragma once #include #include +#include #include #include +#include #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" +#include "murmurhash3.h" namespace fs = std::filesystem; @@ -265,6 +268,102 @@ std::string GetTimingCachePath(const std::string& root, std::string& compute_cap return GetCachePath(root, timing_cache_name); } +HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { + HashValue model_hash = 0; + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + const OrtGraph* cur_graph = nullptr; + api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); + bool is_subgraph = false; + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + while (is_subgraph) { + const OrtGraph* parent_graph = nullptr; + api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); + cur_graph = parent_graph; + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + } + + const OrtGraph* main_graph = cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + const std::filesystem::path* model_path = nullptr; + api->OrtGraph_GetModelPath(graph_viewer, (const void**)&model_path); + + // Use the model's file name instead of the entire path to avoid cache regeneration if path changes + if (model_path->has_filename()) { + std::string model_name = model_path->filename(); + + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.size(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); + } else { + // LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + // const std::vector& input_names = nullptr; + const char** input_names = nullptr; + size_t input_count = 0; + api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_count, &input_names); + for (size_t i = 0; i < input_count; ++i) { + hash_str(input_names[i]); + } + + // hashing output of each node + const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph_viewer); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + size_t nodes_count = 0; + const size_t* nodes_index = nullptr; + api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_count, &nodes_index); + for (const auto& index : nodes_vector) { + const OrtNode* node = 
nullptr; + api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); + size_t output_size = 0; + api->OrtNode_GetOutputSize(node, &output_size); + for (size_t i = 0; i < output_size; ++i) { + const char* output_name = nullptr; + api->OrtNode_GetIthOutputName(node, i, &output_name); + if (output_name != nullptr) { + hash_str(output_name); + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + +#ifdef ORT_VERSION + hash_str(ORT_VERSION); +#endif + +#ifdef CUDA_VERSION + hash_str(std::to_string(CUDA_VERSION)); +#endif + +#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) + std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); + hash_str(TRT_VERSION); +#endif + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + // return the current unique id + return model_hash; +} + std::vector split(const std::string& str, char delimiter) { std::vector tokens; std::string token; From 865a17f8b61a4c60c3d1176b11a7c6c92ee74bb7 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 29 Aug 2024 00:17:59 +0000 Subject: [PATCH 25/81] GetSubGraph and TensorrtExecutionProviderInfo --- .../core/session/onnxruntime_c_api.h | 16 +- onnxruntime/core/session/onnxruntime_c_api.cc | 12 +- onnxruntime/core/session/ort_apis.h | 4 +- samples/outTreeEp/out_tree_ep.cc | 14 +- .../tensorRTEp/tensorrt_execution_provider.cc | 247 ++++++++----- .../tensorRTEp/tensorrt_execution_provider.h | 3 +- .../tensorrt_execution_provider_info.cc | 340 ++++++++++++++++++ .../tensorrt_execution_provider_info.h | 67 ++++ .../tensorrt_execution_provider_utils.h | 4 +- 9 files changed, 588 insertions(+), 119 deletions(-) create mode 100644 samples/tensorRTEp/tensorrt_execution_provider_info.cc create mode 100644 samples/tensorRTEp/tensorrt_execution_provider_info.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a3025690fb666..fad01dc90d5f3 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -710,18 +710,18 @@ typedef struct OrtCreateStream { } OrtCreateStream; typedef struct OrtMetaDef { - const char* name; - const char* domain; + char* name; + char* domain; int since_version; - const char** inputs; + char** inputs; size_t input_len; - const char** outputs; + char** outputs; size_t output_len; - const char** constant_initializers; + char** constant_initializers; size_t initializer_len; - const char* doc_string; + char* doc_string; } OrtMetaDef; typedef struct OrtIndexedSubGraph { @@ -4756,11 +4756,13 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); - ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); + ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* 
ret); ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 92af3800123d3..c3807dac76d17 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2415,6 +2415,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtS return nullptr; } +ORT_API(const char*, OrtApis::OrtGraph_GetName, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph); + return graph_viewer->Name().c_str(); +} + ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph); *ret = graph_viewer->IsConstantInitializer(name, check_outer_scope); @@ -2429,9 +2434,9 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const Ort return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret) { - const ::onnxruntime::Graph* graph_ptr = reinterpret_cast<const ::onnxruntime::Graph*>(graph); - *ret = graph_ptr->IsSubgraph(); +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph); + *ret = graph_viewer->IsSubgraph(); return nullptr; } @@ -3095,6 +3100,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::RegisterOrtExecutionProviderLibrary, &OrtApis::SessionOptionsAppendOrtExecutionProvider, + &OrtApis::OrtGraph_GetName, &OrtApis::OrtGraph_IsConstantInitializer, &OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, &OrtApis::OrtGraph_IsSubgraph, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index e188515ceeeeb..bdd1077d2624c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -539,6 +539,8 @@ ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* l ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); +ORT_API(const char*, OrtGraph_GetName, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; + ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); @@ -549,7 +551,7 @@ ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); -ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret); ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index 71b950abc9c4a..e8362eacbb024 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@
-27,18 +27,20 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe subgraph->meta_def->name = "Relu_subgraph"; subgraph->meta_def->input_len = 0; api->OrtNode_GetInputSize(node, &(subgraph->meta_def->input_len)); - subgraph->meta_def->inputs = new const char* [subgraph->meta_def->input_len]; + subgraph->meta_def->inputs = new char* [subgraph->meta_def->input_len]; for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { - subgraph->meta_def->inputs[j] = nullptr; - api->OrtNode_GetIthInputName(node, j, &(subgraph->meta_def->inputs[j])); + const char* input_j = nullptr; + api->OrtNode_GetIthInputName(node, j, &input_j); + subgraph->meta_def->inputs[j] = const_cast(input_j); } subgraph->meta_def->output_len = 0; api->OrtNode_GetOutputSize(node, &(subgraph->meta_def->output_len)); - subgraph->meta_def->outputs = new const char* [subgraph->meta_def->output_len]; + subgraph->meta_def->outputs = new char* [subgraph->meta_def->output_len]; for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { - subgraph->meta_def->outputs[j] = nullptr; - api->OrtNode_GetIthOutputName(node, j, &(subgraph->meta_def->outputs[j])); + const char* output_j = nullptr; + api->OrtNode_GetIthOutputName(node, j, &output_j); + subgraph->meta_def->outputs[j] = const_cast(output_j); } cache.push_back(subgraph); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index df85934af3fab..bb8aa4d6c1e8d 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -930,8 +930,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { // Construct subgraph from node list - // std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); + std::unique_ptr subgraph = GetSubGraph(group, graph, model_hash, subgraph_index); // Create node to inputs/outputs/index maps const std::string node_name = subgraph->meta_def->name; @@ -1053,7 +1052,7 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* const OrtGraph* cur_graph = nullptr; api->OrtGraph_GetOrtGraph(graph, &cur_graph); bool is_subgraph = false; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + api->OrtGraph_IsSubgraph(graph, &is_subgraph); if (is_subgraph) { const OrtNode* node = nullptr; api->OrtGraph_GetParenNode(graph, &node); @@ -1118,46 +1117,62 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Find inputs and outputs of the subgraph std::unique_ptr sub_graph = std::make_unique(); + sub_graph->node_index_len = graph_nodes_index.first.size(); + sub_graph->node_index = new size_t [sub_graph->node_index_len]; sub_graph->meta_def = new OrtMetaDef(); -// std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; -// std::unordered_set erased; -// int input_order = 0; -// int output_order = 0; -// -// std::vector initializers; -// for (const auto& index : graph_nodes_index.first) { -// sub_graph->Nodes().push_back(node_index[index]); -// const auto& node = graph.GetNode(node_index[index]); -// for (const auto& input : node->InputDefs()) { -// if (graph.IsConstantInitializer(input->Name(), true)) { -// initializers.push_back(input->Name()); -// continue; -// } -// const auto& it = fused_outputs.find(input); -// if (it != fused_outputs.end()) { -// 
fused_outputs.erase(it); -// erased.insert(input); -// } else if (erased.find(input) == erased.end()) { -// // Only when input is neither in output list nor erased list, add the input to input list -// fused_inputs[input] = input_order++; -// } -// } -// -// for (const auto& input : node->ImplicitInputDefs()) { -// if (graph.IsConstantInitializer(input->Name(), true)) { -// initializers.push_back(input->Name()); -// continue; -// } -// const auto& it = fused_outputs.find(input); -// if (it != fused_outputs.end()) { -// fused_outputs.erase(it); -// erased.insert(input); -// } else if (erased.find(input) == erased.end()) { -// // Only when input is neither in output list nor erased list, add the input to input list -// fused_inputs[input] = input_order++; -// } -// } -// + std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_set erased; + int input_order = 0; + int output_order = 0; + + std::vector initializers; + int i = 0; + for (const auto& index : graph_nodes_index.first) { + sub_graph->node_index[i++] = node_index[index]; + const OrtNode* node = nullptr; + api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); + size_t input_size = 0; + api_->OrtNode_GetInputSize(node, &input_size); + for (size_t j = 0; j < input_size; j++) { + const char* input_name = nullptr; + api_->OrtNode_GetIthInputName(node, j, &input_name); + bool is_constant_initializer = false; + api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_constant_initializer); + if (is_constant_initializer) { + initializers.push_back(input_name); + continue; + } + const auto& it = fused_outputs.find(input_name); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); + erased.insert(input_name); + } else if (erased.find(input_name) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + fused_inputs[input_name] = input_order++; + } + } + + size_t implicit_input_size = 0; + api_->OrtNode_GetImplicitInputSize(node, &implicit_input_size); + for (size_t j = 0; j < implicit_input_size; j++) { + const char* input_name = nullptr; + api_->OrtNode_GetIthImplicitInputName(node, j, &input_name); + bool is_constant_initializer = false; + api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_constant_initializer); + if (is_constant_initializer) { + initializers.push_back(input_name); + continue; + } + const auto& it = fused_outputs.find(input_name); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); + erased.insert(input_name); + } else if (erased.find(input_name) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + fused_inputs[input_name] = input_order++; + } + } + // // For output searching, there are two special cases, // // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, // // if the output is connected to nodes that don't belong to the subgraph, the output need to be added @@ -1191,59 +1206,74 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // } // } // } else { -// for (const auto& output : node->OutputDefs()) { -// const auto& it = fused_inputs.find(output); -// if (it != fused_inputs.end()) { -// fused_inputs.erase(it); -// erased.insert(output); -// } -// // Only when output is neither in input list nor erased list, add the output to output list -// else if (erased.find(output) == erased.end()) { -// if (graph_output_names.find(output->Name()) != 
graph_output_names.end()) { -// graph_outputs_to_add[output] = output_order; -// } -// fused_outputs[output] = output_order++; -// } -// } -// } -// } -// -// fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); -// fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); -// -// // Sort inputs and outputs by the order they were added -// std::multimap inputs, outputs; -// for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { -// inputs.insert(std::pair(it->second, it->first)); -// } -// -// for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { -// outputs.insert(std::pair(it->second, it->first)); -// } -// -// // Generate unique kernel name for TRT subgraph -// std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); -// auto meta_def = IndexedSubGraph_MetaDef::Create(); -// const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph"; -// meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + subgraph_id; -// LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT subgraph MetaDef name " + meta_def->name(); -// -// // Assign inputs and outputs to subgraph's meta_def -// for (const auto& input : inputs) { -// if (input.second->Exists()) { -// meta_def->inputs().push_back(input.second->Name()); -// } -// } -// -// for (const auto& initializer : initializers) { -// meta_def->constant_initializers().push_back(initializer); -// } -// -// for (const auto& output : outputs) { -// if (output.second->Exists()) { -// meta_def->outputs().push_back(output.second->Name()); + size_t output_size = 0; + api_->OrtNode_GetOutputSize(node, &output_size); + for (size_t j = 0; j < output_size; j++) { + const char* output_name = nullptr; + api_->OrtNode_GetIthOutputName(node, j, &output_name); + const auto& it = fused_inputs.find(output_name); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output_name); + } + // Only when output is neither in input list nor erased list, add the output to output list + else if (erased.find(output_name) == erased.end()) { + if (graph_output_names.find(output_name) != graph_output_names.end()) { + graph_outputs_to_add[output_name] = output_order; + } + fused_outputs[output_name] = output_order++; + } + } // } -// } + } + + fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); + fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); + + // Sort inputs and outputs by the order they were added + std::multimap inputs, outputs; + for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { + inputs.insert(std::pair(it->second, it->first)); + } + + for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { + outputs.insert(std::pair(it->second, it->first)); + } + + // Generate unique kernel name for TRT subgraph + std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); + bool is_subgraph = false; + api_->OrtGraph_IsSubgraph(graph, &is_subgraph); + const std::string graph_type = is_subgraph ? 
"subgraph" : "graph"; + const char* graph_name = api_->OrtGraph_GetName(graph); + std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; + sub_graph->meta_def->name = new char [meta_def_name.length() + 1]; + strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); + + // Assign inputs and outputs to subgraph's meta_def + sub_graph->meta_def->input_len = inputs.size(); + sub_graph->meta_def->inputs = new char* [sub_graph->meta_def->input_len]; + i = 0; + for (const auto& input : inputs) { + sub_graph->meta_def->inputs[i] = new char [input.second.length() + 1]; + strcpy(sub_graph->meta_def->inputs[i++], input.second.c_str()); + } + + sub_graph->meta_def->initializer_len = initializers.size(); + sub_graph->meta_def->constant_initializers = new char* [sub_graph->meta_def->initializer_len]; + i = 0; + for (const auto& initializer : initializers) { + sub_graph->meta_def->constant_initializers[i] = new char [initializer.length() + 1]; + strcpy(sub_graph->meta_def->constant_initializers[i++], initializer.c_str()); + } + + sub_graph->meta_def->output_len = outputs.size(); + sub_graph->meta_def->outputs = new char* [sub_graph->meta_def->output_len]; + i = 0; + for (const auto& output : outputs) { + sub_graph->meta_def->outputs[i] = new char [output.second.length() + 1]; + strcpy(sub_graph->meta_def->outputs[i++], output.second.c_str()); + } sub_graph->meta_def->domain = "com.microsoft"; sub_graph->meta_def->since_version = 1; @@ -1268,9 +1298,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const if (api->OrtGraph_NumberOfNodes(graph) == 1 && GraphHasCtxNode(graph)) { SubGraph_t supported_node_vector = {{0}, true}; - // std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); - // result.push_back(ComputeCapability::Create(std::move(sub_graph))); - // return result; + std::unique_ptr sub_graph = p->GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); + *cnt = 1; + *indexed_sub_graph = new OrtIndexedSubGraph* [1]; + (*indexed_sub_graph)[0] = sub_graph.release(); + return; } // Generate unique kernel name for TRT graph @@ -1360,17 +1392,26 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } } + std::vector cache; // Handle the case where the graph is subgraph of control flow op. // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. 
if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { + bool all_subgraphs_are_supported = true; + + if (all_subgraphs_are_supported) { + for (const auto& group : supported_nodes_vector) { + + } + return; + } } int number_of_trt_nodes = 0, subgraph_index = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - // std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - // result.push_back(ComputeCapability::Create(std::move(sub_graph))); + std::unique_ptr sub_graph = p->GetSubGraph(group, graph, model_hash, subgraph_index); + cache.push_back(sub_graph.release()); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } @@ -1384,6 +1425,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } else { // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } + + *cnt = cache.size(); + *indexed_sub_graph = new OrtIndexedSubGraph* [*cnt]; + for (size_t i = 0; i < *cnt; i++) { + (*indexed_sub_graph)[i] = cache[i]; + } }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { @@ -1477,6 +1524,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const }; api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device); + + info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info); } TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index b9b9e01e41b3c..2a12de3c10f9c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -4,6 +4,7 @@ #include #include "core/session/onnxruntime_c_api.h" #include "core/framework/provider_options.h" +#include "tensorrt_execution_provider_info.h" #include "nv_includes.h" #ifdef _WIN32 @@ -200,7 +201,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::unordered_map dynamic_range_map_; std::string cache_suffix_; private: -// mutable TensorrtExecutionProviderInfo info_; + mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; int max_partition_iterations_ = 1000; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.cc b/samples/tensorRTEp/tensorrt_execution_provider_info.cc new file mode 100644 index 0000000000000..a6caab6642662 --- /dev/null +++ b/samples/tensorRTEp/tensorrt_execution_provider_info.cc @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
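One loose end in the GetCapability changes above: the IsSubGraphOfControlFlowOp branch hard-codes all_subgraphs_are_supported to true and leaves its per-group loop empty, so for now every fully supported control-flow subgraph takes that path unconditionally. The in-tree TensorRT EP only fuses once every body of the parent If/Loop/Scan node is entirely assigned to TRT, the same per-body check this GetCapability already performs earlier via AllNodesAssignedToSpecificEP. A sketch of the missing verification, assuming the parent node's subgraph viewers are available and AllNodesAssignedToSpecificEP is accessible (sketch only, not the patch's code):

// bodies: the parent control-flow node's subgraph viewers (hypothetical input).
static bool AllBodiesAssignedToTrt(TensorrtExecutionProvider* p, const OrtApi* api,
                                   const OrtGraphViewer** bodies, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    if (api->OrtGraph_NumberOfNodes(bodies[i]) == 0) continue;  // empty body is fine
    if (!p->AllNodesAssignedToSpecificEP(bodies[i], "tensorrtEp")) {
      return false;  // one unsupported body blocks fusing the control flow op
    }
  }
  return true;
}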
+ +//#include "core/providers/tensorrt/tensorrt_execution_provider_info.h" +//#include "core/providers/tensorrt/tensorrt_provider_options.h" +// +//#include "core/common/make_string.h" +//#include "core/common/parse_string.h" +//#include "core/framework/provider_options_utils.h" +//#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace tensorrt { +namespace provider_option_names { +constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; +constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations"; +constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; +constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; +constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; +constexpr const char* kDLAEnable = "trt_dla_enable"; +constexpr const char* kDLACore = "trt_dla_core"; +constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; +constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; +constexpr const char* kEngineCachePath = "trt_engine_cache_path"; +constexpr const char* kWeightStrippedEngineEnable = "trt_weight_stripped_engine_enable"; +constexpr const char* kOnnxModelFolderPath = "trt_onnx_model_folder_path"; +constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix"; +constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; +constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; +constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; +// add new provider option name here. 
+constexpr const char* kContextMemorySharingEnable = "trt_context_memory_sharing_enable"; +constexpr const char* kLayerNormFP32Fallback = "trt_layer_norm_fp32_fallback"; +constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable"; +constexpr const char* kTimingCachePath = "trt_timing_cache_path"; +constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache"; +constexpr const char* kDetailedBuildLog = "trt_detailed_build_log"; +constexpr const char* kBuildHeuristics = "trt_build_heuristics_enable"; +constexpr const char* kSparsityEnable = "trt_sparsity_enable"; +constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level"; +constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams"; +constexpr const char* kTacticSources = "trt_tactic_sources"; +constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; +constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; +constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; +constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; +constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; +constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; +constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; +constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; +constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; + +} // namespace provider_option_names +} // namespace tensorrt + +TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { + TensorrtExecutionProviderInfo info{}; +// void* user_compute_stream = nullptr; +// ORT_THROW_IF_ERROR( +// ProviderOptionsParser{} +// .AddValueParser( +// tensorrt::provider_option_names::kDeviceId, +// [&info](const std::string& value_str) -> Status { +// ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); +// int num_devices{}; +// CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices)); +// ORT_RETURN_IF_NOT( +// 0 <= info.device_id && info.device_id < num_devices, +// "Invalid device ID: ", info.device_id, +// ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); +// return Status::OK(); +// }) +// .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations) +// .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) +// .AddValueParser( +// tensorrt::provider_option_names::kUserComputeStream, +// [&user_compute_stream](const std::string& value_str) -> Status { +// size_t address; +// ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); +// user_compute_stream = reinterpret_cast(address); +// return Status::OK(); +// }) +// .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) +// .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) +// .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) +// .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) +// 
.AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) +// .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) +// .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) +// .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback) +// .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path) +// .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log) +// .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kSparsityEnable, info.sparsity_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level) +// .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams) +// .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources) +// .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths) +// .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) +// .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) +// .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) +// .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) +// .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) +// .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible) +// .Parse(options)); // add new provider option here. 
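With the ProviderOptionsParser chain above commented out (it depends on ORT-internal helpers such as ParseStringWithClassicLocale and ORT_THROW_IF_ERROR that a standalone EP cannot link against), FromProviderOptions currently returns the defaults and silently ignores the caller's options. A minimal dependency-free substitute for two representative keys, assuming ProviderOptions is the string-to-string map already used by SessionOptionsAppendOrtExecutionProvider; the remaining keys would follow the same pattern:

#include <cstdlib>  // std::atoi

// Sketch of a hand-rolled stand-in for the commented-out parser chain.
// Validation is elided: the real parser range-checks device ids, etc.
static void ParseBasicTrtOptions(const ProviderOptions& options,
                                 TensorrtExecutionProviderInfo& info) {
  if (auto it = options.find(tensorrt::provider_option_names::kDeviceId); it != options.end()) {
    info.device_id = std::atoi(it->second.c_str());
  }
  if (auto it = options.find(tensorrt::provider_option_names::kFp16Enable); it != options.end()) {
    info.fp16_enable = (it->second == "1" || it->second == "true");
  }
}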
+// +// info.user_compute_stream = user_compute_stream; +// info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; +} + +//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, +// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, +// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, +// {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, +// // add new provider option here. 
+// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, +// {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, +// {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, +// {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, +// {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, +// {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, +// {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, +// {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.sparsity_enable)}, +// {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, +// {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, +// {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, +// {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, +// {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, +// {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, +// {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, +// {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, +// {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, +// {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, +// {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, +// {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, +// }; +// return options; +//} +// +//ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { +// auto empty_if_null = [](const char* s) { return s != nullptr ? 
std::string{s} : std::string{}; }; +// const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); +// const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); +// const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); +// const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); +// const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); +// const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); +// const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); +// const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); +// const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); +// const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); +// const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); +// const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); +// +// const ProviderOptions options{ +// {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, +// {tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +// {tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, +// {tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)}, +// {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, +// {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, +// {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, +// {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, +// {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, +// {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, +// {tensorrt::provider_option_names::kDLAEnable, MakeStringWithClassicLocale(info.trt_dla_enable)}, +// {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, +// {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, +// {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, +// {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, +// {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, +// {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, +// {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, +// {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, +// {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, +// {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, +// {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, +// 
{tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)},
+//       {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)},
+//       {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_},
+//       {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)},
+//       {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)},
+//       {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)},
+//       {tensorrt::provider_option_names::kSparsityEnable, MakeStringWithClassicLocale(info.trt_sparsity_enable)},
+//       {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)},
+//       {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)},
+//       {tensorrt::provider_option_names::kTacticSources, kTacticSources_},
+//       {tensorrt::provider_option_names::kExtraPluginLibPaths, kExtraPluginLibPaths_},
+//       {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_},
+//       {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_},
+//       {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_},
+//       {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)},
+//       {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_},
+//       {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)},
+//       {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)},
+//       {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
+//   };
+//   return options;
+//}
+//
+///**
+// * Update OrtTensorRTProviderOptionsV2 instance with ProviderOptions (map of string-based key-value pairs)
+// *
+// * Please note that it will reset the OrtTensorRTProviderOptionsV2 instance first and then set up the provided provider options.
+// * See TensorrtExecutionProviderInfo::FromProviderOptions() for more details. This function is also called by the C API UpdateTensorRTProviderOptions().
+// *
+// * \param provider_options - a pointer to an OrtTensorRTProviderOptionsV2 instance
+// * \param options - a reference to a ProviderOptions instance
+// * \param string_copy - if it's true, it uses strncpy() to copy the 'provider option' string from the ProviderOptions instance to where the 'provider option' const char pointer in the OrtTensorRTProviderOptionsV2 instance points to.
+// *                      if it's false, it only saves the pointer and does no strncpy().
+// *
+// * Note: If there is strncpy involved, please remember to deallocate, or simply call the C API ReleaseTensorRTProviderOptions.
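+// *
+// * A hypothetical call sequence, for illustration only (the option keys follow the
+// * in-tree TRT EP naming and are an assumption here):
+// *
+// *   OrtTensorRTProviderOptionsV2 v2_opts{};
+// *   ProviderOptions opts{{"trt_engine_cache_enable", "1"}, {"trt_engine_cache_path", "trt_cache"}};
+// *   TensorrtExecutionProviderInfo::UpdateProviderOptions(&v2_opts, opts, /*string_copy=*/true);
+// *   // ... use v2_opts ...
+// *   // string_copy == true means the const char* members now own heap copies, so
+// *   // release them afterwards, e.g. via the C API ReleaseTensorRTProviderOptions.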
+// */ +//void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) { +// if (provider_options == nullptr) { +// return; +// } +// auto copy_string_if_needed = [&](std::string& s_in) { +// if (string_copy) { +// char* dest = nullptr; +// auto str_size = s_in.size(); +// if (str_size == 0) { +// return (const char*)nullptr; +// } else { +// dest = new char[str_size + 1]; +//#ifdef _MSC_VER +// strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); +//#else +// strncpy(dest, s_in.c_str(), str_size); +//#endif +// dest[str_size] = '\0'; +// return (const char*)dest; +// } +// } else { +// return s_in.c_str(); +// } +// }; +// +// TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); +// auto& trt_provider_options_v2 = *reinterpret_cast(provider_options); +// trt_provider_options_v2.device_id = internal_options.device_id; +// +// // The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well +// // We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options or user_compute_stream is provided +// if (options.find("has_user_compute_stream") != options.end()) { +// trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream; +// } +// if (options.find("user_compute_stream") != options.end() && internal_options.user_compute_stream != nullptr) { +// trt_provider_options_v2.user_compute_stream = internal_options.user_compute_stream; +// trt_provider_options_v2.has_user_compute_stream = true; +// } +// +// trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations; +// trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size; +// trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size; +// trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; +// trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; +// +// trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); +// +// trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table; +// trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable; +// trt_provider_options_v2.trt_dla_core = internal_options.dla_core; +// trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; +// trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; +// trt_provider_options_v2.trt_weight_stripped_engine_enable = internal_options.weight_stripped_engine_enable; +// trt_provider_options_v2.trt_onnx_model_folder_path = copy_string_if_needed(internal_options.onnx_model_folder_path); +// +// trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); +// trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix); +// trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); +// +// trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; +// +// trt_provider_options_v2.trt_engine_decryption_lib_path = copy_string_if_needed(internal_options.engine_decryption_lib_path); +// 
+// trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build; +// trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable; +// trt_provider_options_v2.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback; +// trt_provider_options_v2.trt_timing_cache_enable = internal_options.timing_cache_enable; +// trt_provider_options_v2.trt_force_timing_cache = internal_options.force_timing_cache; +// trt_provider_options_v2.trt_detailed_build_log = internal_options.detailed_build_log; +// trt_provider_options_v2.trt_build_heuristics_enable = internal_options.build_heuristics_enable; +// trt_provider_options_v2.trt_sparsity_enable = internal_options.sparsity_enable; +// trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level; +// trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams; +// +// trt_provider_options_v2.trt_tactic_sources = copy_string_if_needed(internal_options.tactic_sources); +// trt_provider_options_v2.trt_extra_plugin_lib_paths = copy_string_if_needed(internal_options.extra_plugin_lib_paths); +// trt_provider_options_v2.trt_profile_min_shapes = copy_string_if_needed(internal_options.profile_min_shapes); +// trt_provider_options_v2.trt_profile_max_shapes = copy_string_if_needed(internal_options.profile_max_shapes); +// trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes); +// +// trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; +// trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model; +// trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; +// trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); +// trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; +//} +} // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.h b/samples/tensorRTEp/tensorrt_execution_provider_info.h new file mode 100644 index 0000000000000..f64e3a972e01b --- /dev/null +++ b/samples/tensorRTEp/tensorrt_execution_provider_info.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +//#include "core/framework/ortdevice.h" +//#include "core/framework/provider_options.h" +//#include "core/framework/framework_provider_common.h" +//#include "core/session/onnxruntime_c_api.h" +//#include "core/framework/library_handles.h" + +#define TRT_DEFAULT_OPTIMIZER_LEVEL 3 + +namespace onnxruntime { +// Information needed to construct trt execution providers. 
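+//
+// For orientation, a sketch of the kind of string map FromProviderOptions consumes
+// (the keys follow the in-tree TRT EP's provider_option_names and are assumptions here):
+//
+//   ProviderOptions options{
+//       {"device_id", "0"},
+//       {"trt_fp16_enable", "1"},
+//       {"trt_engine_cache_enable", "1"},
+//       {"trt_engine_cache_path", "trt_cache"},
+//   };
+//   TensorrtExecutionProviderInfo info = TensorrtExecutionProviderInfo::FromProviderOptions(options);
+//   // Any key not present keeps the brace-initialized default shown in the struct below.
+//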
+struct TensorrtExecutionProviderInfo { + int device_id{0}; + bool has_user_compute_stream{false}; + void* user_compute_stream{nullptr}; + bool has_trt_options{false}; + int max_partition_iterations{1000}; + int min_subgraph_size{1}; + size_t max_workspace_size{1 << 30}; + bool fp16_enable{false}; + bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; + bool dla_enable{false}; + int dla_core{0}; + bool dump_subgraphs{false}; + bool engine_cache_enable{false}; + std::string engine_cache_path{""}; + bool weight_stripped_engine_enable{false}; + std::string onnx_model_folder_path{""}; + bool engine_decryption_enable{false}; + std::string engine_decryption_lib_path{""}; + bool force_sequential_engine_build{false}; + bool context_memory_sharing_enable{false}; + bool layer_norm_fp32_fallback{false}; + bool timing_cache_enable{false}; + std::string timing_cache_path{""}; + bool force_timing_cache{false}; + bool detailed_build_log{false}; + bool build_heuristics_enable{false}; + bool sparsity_enable{false}; + int builder_optimization_level{3}; + int auxiliary_streams{-1}; + std::string tactic_sources{""}; + std::string extra_plugin_lib_paths{""}; + std::string profile_min_shapes{""}; + std::string profile_max_shapes{""}; + std::string profile_opt_shapes{""}; + bool cuda_graph_enable{false}; + bool dump_ep_context_model{false}; + std::string ep_context_file_path{""}; + int ep_context_embed_mode{0}; + std::string engine_cache_prefix{""}; + bool engine_hw_compatible{false}; + + static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); +// static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); +// static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); +// static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy); +// +// std::vector custom_op_domain_list; +}; +} // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index e9a9ff0cd46c1..f0f0374865087 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -274,12 +274,12 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { const OrtGraph* cur_graph = nullptr; api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); bool is_subgraph = false; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph); while (is_subgraph) { const OrtGraph* parent_graph = nullptr; api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); cur_graph = parent_graph; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph); } const OrtGraph* main_graph = cur_graph; From 281154132757e1900463f8494aedb0d1bd48f5ce Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 28 Aug 2024 22:54:51 -0700 Subject: [PATCH 26/81] Add simple CUDA allocators for TRT EP (#21901) Only implemented two simple CUDA allocators (without BFC) for now. Note: several TODO. 
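Since the allocators in this patch are driven purely through C function pointers, it may help to see the calling convention from the consumer side. A minimal hedged sketch (the helper name and the byte count are illustrative, not part of the patch):

    void use_allocator(OrtAllocator* alloc) {
      // OrtAllocator is a C-style vtable: every call passes the object itself first
      void* p = alloc->Alloc(alloc, 256);
      const OrtMemoryInfo* mi = alloc->Info(alloc);  // reports name, device id and memory type
      (void)mi;
      alloc->Free(alloc, p);
    }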
--- .../core/session/onnxruntime_c_api.h | 1 + onnxruntime/core/framework/provider_adapter.h | 10 +++ samples/tensorRTEp/tensorrt_cuda_allocator.cc | 79 ++++++++++++++++++ samples/tensorRTEp/tensorrt_cuda_allocator.h | 82 +++++++++++++++++++ .../tensorRTEp/tensorrt_execution_provider.cc | 11 +++ 5 files changed, 183 insertions(+) create mode 100644 samples/tensorRTEp/tensorrt_cuda_allocator.cc create mode 100644 samples/tensorRTEp/tensorrt_cuda_allocator.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index fad01dc90d5f3..89c8471344ff7 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -753,6 +753,7 @@ typedef struct OrtExecutionProvider { void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); + int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators); const char* type; OrtCreateStream* create_stream; const OrtDevice* default_device; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index ce3792cd94fe6..5a750cbbfecec 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -126,6 +126,16 @@ class ExecutionProviderAdapter : public IExecutionProvider { } virtual std::shared_ptr GetKernelRegistry() const override { return kernel_registry_; } + + virtual std::vector CreatePreferredAllocators() override { + std::vector ret; + OrtAllocator** ort_allocators = nullptr; + int cnt = ep_impl_ -> CreatePreferredAllocators(ep_impl_, &ort_allocators); + for (int i = 0; i < cnt; i++) { + ret.push_back(std::make_shared(ort_allocators[i])); + } + return ret; + } private: OrtExecutionProvider* ep_impl_; std::shared_ptr kernel_registry_; // TODO(leca): should be static local diff --git a/samples/tensorRTEp/tensorrt_cuda_allocator.cc b/samples/tensorRTEp/tensorrt_cuda_allocator.cc new file mode 100644 index 0000000000000..044bba043f5a4 --- /dev/null +++ b/samples/tensorRTEp/tensorrt_cuda_allocator.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
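+//
+// (Note: CUDA_RETURN_IF_ERROR is forward-declared below and defined in
+// tensorrt_execution_provider.cc, where it simply abort()s on any cudaError_t failure;
+// as a consequence, Alloc() in this file terminates the process on an allocation
+// failure instead of reporting it to the caller.)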
+ +#include +#include +#include "tensorrt_cuda_allocator.h" + +void CUDA_RETURN_IF_ERROR(cudaError_t res); + +namespace onnxruntime { +void CUDAAllocator::CheckDevice(bool throw_when_fail) const { +#ifndef NDEBUG + // check device to match at debug build + // if it's expected to change, call cudaSetDevice instead of the check + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + assert(current_device == CUDAAllocator::GetDeviceId()); + } else if (throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +#else + ORT_UNUSED_PARAMETER(throw_when_fail); +#endif +} + +void CUDAAllocator::SetDevice(bool throw_when_fail) const { + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + int allocator_device_id = CUDAAllocator::GetDeviceId(); + if (current_device != allocator_device_id) { + cuda_err = cudaSetDevice(allocator_device_id); + } + } + + if (cuda_err != cudaSuccess && throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +} + +void* CUDAAllocator::Alloc(size_t size) { + SetDevice(true); + CheckDevice(true); + void* p = nullptr; + if (size > 0) { + // BFCArena was updated recently to handle the exception and adjust the request size + CUDA_RETURN_IF_ERROR(cudaMalloc((void**)&p, size)); + } + return p; +} + +void CUDAAllocator::Free(void* p) { + SetDevice(false); + CheckDevice(false); // ignore CUDA failure when free + cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown +} + +const OrtMemoryInfo* CUDAAllocator::Info() const { + return mem_info_; +} + +void* CUDAPinnedAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + CUDA_RETURN_IF_ERROR(cudaMallocHost((void**)&p, size)); + } + return p; +} + +void CUDAPinnedAllocator::Free(void* p) { + CUDA_RETURN_IF_ERROR(cudaFreeHost(p)); +} + +const OrtMemoryInfo* CUDAPinnedAllocator::Info() const { + return mem_info_; +} + +} // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_cuda_allocator.h b/samples/tensorRTEp/tensorrt_cuda_allocator.h new file mode 100644 index 0000000000000..cf360536561e3 --- /dev/null +++ b/samples/tensorRTEp/tensorrt_cuda_allocator.h @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
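+//
+// (Note: this header defines two allocators. CUDAAllocator hands out device memory via
+// cudaMalloc; CUDAPinnedAllocator hands out page-locked host memory via cudaMallocHost.
+// Pinned buffers are what make host/device copies truly asynchronous, e.g.:
+//
+//   cudaMemcpyAsync(device_ptr, pinned_host_ptr, num_bytes, cudaMemcpyHostToDevice, stream);
+//
+// with pageable host memory the same call may be staged through an internal buffer and
+// synchronize. The pointer names above are placeholders.)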
+
+#pragma once
+#include
+#include "core/session/onnxruntime_c_api.h"
+#define ORT_API_MANUAL_INIT
+#include "core/session/onnxruntime_cxx_api.h"
+
+namespace onnxruntime {
+
+// Following names are originally defined in allocator.h
+constexpr const char* CUDA_ALLOCATOR = "Cuda";
+constexpr const char* CUDA_PINNED_ALLOCATOR = "CudaPinned";
+
+using DeviceId = int16_t;
+
+struct CUDAAllocator : OrtAllocator {
+  CUDAAllocator(DeviceId device_id, const char* name = onnxruntime::CUDA_ALLOCATOR) {
+    OrtAllocator::version = ORT_API_VERSION;
+    OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAAllocator*>(this_)->Alloc(size); };
+    OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAAllocator*>(this_)->Free(p); };
+    OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAAllocator*>(this_)->Info(); };
+
+    device_id_ = device_id;
+
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    api->CreateMemoryInfo(name,
+                          OrtAllocatorType::OrtDeviceAllocator,
+                          static_cast<int>(device_id),
+                          OrtMemType::OrtMemTypeDefault,
+                          &mem_info_);
+  }
+  //~CUDAAllocator();
+
+  void* Alloc(size_t size);
+  void Free(void* p);
+  const OrtMemoryInfo* Info() const;
+  DeviceId GetDeviceId() const { return device_id_; };
+
+ private:
+  CUDAAllocator(const CUDAAllocator&) = delete;
+  CUDAAllocator& operator=(const CUDAAllocator&) = delete;
+
+  void CheckDevice(bool throw_when_fail) const;
+  void SetDevice(bool throw_when_fail) const;
+
+  DeviceId device_id_;
+  OrtMemoryInfo* mem_info_ = nullptr;
+};
+
+struct CUDAPinnedAllocator : OrtAllocator {
+  CUDAPinnedAllocator(const char* name = onnxruntime::CUDA_PINNED_ALLOCATOR) {
+    OrtAllocator::version = ORT_API_VERSION;
+    OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAPinnedAllocator*>(this_)->Alloc(size); };
+    OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAPinnedAllocator*>(this_)->Free(p); };
+    OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAPinnedAllocator*>(this_)->Info(); };
+    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    api->CreateMemoryInfo(name,
+                          OrtAllocatorType::OrtDeviceAllocator,
+                          0 /* CPU device always with id 0 */,
+                          OrtMemType::OrtMemTypeDefault,
+                          &mem_info_);
+  }
+  //~CUDAPinnedAllocator();
+
+  void* Alloc(size_t size);
+  void Free(void* p);
+  const OrtMemoryInfo* Info() const;
+
+  DeviceId GetDeviceId() const { return device_id_; };
+
+ private:
+  CUDAPinnedAllocator(const CUDAPinnedAllocator&) = delete;
+  CUDAPinnedAllocator& operator=(const CUDAPinnedAllocator&) = delete;
+
+  DeviceId device_id_ = 0;
+  OrtMemoryInfo* mem_info_ = nullptr;
+};
+
+
+} // namespace onnxruntime
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index bb8aa4d6c1e8d..cf37450c7e219 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -6,6 +6,7 @@
 #include "core/session/onnxruntime_cxx_api.h" // TODO(leca): we should be able to use cxx APIs which are built upon C API
 #include "tensorrt_execution_provider.h"
 #include "tensorrt_execution_provider_utils.h"
+#include "tensorrt_cuda_allocator.h"
 #include "onnx_ctx_model_helper.h"

 void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); }
@@ -1515,6 +1516,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
         return nullptr;
     };

+    OrtExecutionProvider::CreatePreferredAllocators = [](OrtExecutionProvider* this_, OrtAllocator*** ort_allocators) -> int {
+        TensorrtExecutionProvider* p =
static_cast(this_); + int ret = 2; + *ort_allocators = new OrtAllocator * [2]; + (*ort_allocators)[0] = new CUDAAllocator(static_cast(p->device_id_)); // TODO(Chi): Add BFC Arena implementation + (*ort_allocators)[1] = new CUDAPinnedAllocator(); + // TODO(Chi): Free allocators' memory + return ret; + }; + type = ep_type; create_stream = new OrtCreateStream(); create_stream->CreateStreamFunc = [](const OrtDevice* device) -> void* { From c97b19f2d5180909ad3314443b0cdaf79a4b9244 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Thu, 29 Aug 2024 20:27:28 +0800 Subject: [PATCH 27/81] add constructor for tensorrt ep and refine GetCapability (#21914) --- .../tensorRTEp/tensorrt_execution_provider.cc | 506 +++++++++++++++++- .../tensorRTEp/tensorrt_execution_provider.h | 44 ++ .../tensorrt_execution_provider_info.cc | 14 +- .../tensorrt_execution_provider_info.h | 10 +- 4 files changed, 556 insertions(+), 18 deletions(-) diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index cf37450c7e219..cfd366732c896 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -9,6 +9,18 @@ #include "tensorrt_cuda_allocator.h" #include "onnx_ctx_model_helper.h" +#ifdef _WIN32 +#include +#define LIBTYPE HINSTANCE +#define OPENLIB(libname) LoadLibrary(libname) +#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn)) +#else +#include +#define LIBTYPE void* +#define OPENLIB(libname) dlopen((libname), RTLD_LAZY) +#define LIBFUNC(lib, fn) dlsym((lib), (fn)) +#endif + void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } namespace onnxruntime { @@ -1398,14 +1410,78 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const // The purpose is to make control flow op as well as its subgraphs run on TRT. // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. if (p->IsSubGraphOfControlFlowOp(graph) && p->IsSubGraphFullySupported(supported_nodes_vector, number_of_ort_nodes)) { - bool all_subgraphs_are_supported = true; + bool all_subgraphs_are_supported = true; + + // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. + // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. 
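+      // (That is: when this GetCapability call is for the "then" branch, the loop below
+      // inspects the sibling "else" branch, and vice versa; fusing at the control-flow-op
+      // level is only worthwhile if both bodies can run on TRT.)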
+      const OrtNode* parent_node = nullptr;
+      api->OrtGraph_GetParenNode(graph, &parent_node);
+      const char* parent_node_op_type = nullptr;
+      api->OrtNode_GetOpType(parent_node, &parent_node_op_type);
+      if (strcmp(parent_node_op_type, "If") == 0) {
+        all_subgraphs_are_supported = false;
+        SubGraphCollection_t subgraph_supported_nodes_vector;
+        size_t subgraph_count = 0;
+        const OrtGraphViewer** subgraphs = nullptr;
+        api->OrtNode_GetSubgraphs(parent_node, &subgraph_count, &subgraphs);
+        const OrtGraph* origin_graph = nullptr;
+        api->OrtGraph_GetOrtGraph(graph, &origin_graph);
+        for (size_t i = 0; i < subgraph_count; i++) {
+          const OrtGraph* subgraph = nullptr;
+          api->OrtGraph_GetOrtGraph(subgraphs[i], &subgraph);
+          if (subgraph == origin_graph) {
+            continue;
+          }
+          const int number_of_ort_subgraph_nodes = api->OrtGraph_NumberOfNodes(subgraphs[i]);
+          std::vector<size_t> subgraph_nodes_vector(number_of_ort_subgraph_nodes);
+          std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0);
+          SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}};
+          bool subgraph_early_termination = false;
+
+          // The other subgraph of the "If" control flow op has no nodes.
+          // In this case, TRT EP should consider this empty subgraph fully supported by TRT.
+          if (number_of_ort_subgraph_nodes == 0) {
+            all_subgraphs_are_supported = true;
+            break;
+          }
+          // The other subgraph of the "If" control flow op has been parsed by GetCapability before, and all of the subgraph's nodes were assigned to TRT EP.
+          else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], "TensorrtExecutionProvider")) {
+            all_subgraphs_are_supported = true;
+            break;
+          }
+          // The other subgraph of the "If" control flow op has been parsed by GetCapability, but not all of the subgraph's nodes were assigned to TRT EP.
+          // (Note: GetExecutionProviderType() returns "" meaning the node has not yet been assigned to any EP)
+          else if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "")) {
+            all_subgraphs_are_supported = false;
+            break;
+          }
-    if (all_subgraphs_are_supported) {
-      for (const auto& group : supported_nodes_vector) {
+          // The other subgraph of the "If" control flow op has not yet been parsed by GetCapability.
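+          // In that case parse it now: GetSupportedList() repeatedly asks the TRT parser
+          // which groups of nodes it can take (refining up to max_partition_iterations_
+          // times), and the result is checked for full support right below.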
+          subgraph_supported_nodes_vector = p->GetSupportedList(parser_subgraph_nodes_vector, 0, p->max_partition_iterations_, subgraphs[i], &subgraph_early_termination);
+          all_subgraphs_are_supported = p->IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes);
+          break;
+        }
+      }
+
+
+      if (all_subgraphs_are_supported) {
+        // for (const auto& group : supported_nodes_vector) {
+        //   if (!group.first.empty()) {
+        //     *cnt = group.first.size();
+        //     *indexed_sub_graph = new OrtIndexedSubGraph* [group.first.size()];
+        //     int i = 0;
+        //     for (const auto& index : group.first) {
+        //       (*indexed_sub_graph)[i]->node_index_len = 1;
+        //       (*indexed_sub_graph)[i]->node_index = new size_t [(*indexed_sub_graph)[i]->node_index_len];
+        //       (*indexed_sub_graph)[i]->node_index[0] = node_index[index];
+        //       i++;
+        //     }
+        //   }
+        // }
+        // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
+        return;
       }
-      return;
-    }
   }

   int number_of_trt_nodes = 0, subgraph_index = 0;
@@ -1522,7 +1598,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
         *ort_allocators = new OrtAllocator * [2];
         (*ort_allocators)[0] = new CUDAAllocator(static_cast(p->device_id_)); // TODO(Chi): Add BFC Arena implementation
         (*ort_allocators)[1] = new CUDAPinnedAllocator();
-        // TODO(Chi): Free allocators' memory
+        // TODO(Chi): Free allocators' memory
         return ret;
     };
@@ -1537,6 +1613,424 @@
     api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device);

     info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info);
+
+    std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
+
+    // In case the EP context is dumped, the engine cache has to be enabled
+    auto enable_engine_cache_for_ep_context_model = [this]() {
+        if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) {
+            engine_cache_enable_ = true;
+        }
+    };
+
+    // Get environment variables
+    if (info_.has_trt_options) {
+        max_partition_iterations_ = info_.max_partition_iterations;
+        min_subgraph_size_ = info_.min_subgraph_size;
+        max_workspace_size_ = info_.max_workspace_size;
+        fp16_enable_ = info_.fp16_enable;
+        int8_enable_ = info_.int8_enable;
+        if (int8_enable_) {
+            int8_calibration_cache_name_ = info_.int8_calibration_table_name;
+            int8_use_native_tensorrt_calibration_table_ = info_.int8_use_native_calibration_table;
+        }
+        if (fp16_enable_ || int8_enable_) {  // DLA can only be enabled with FP16 or INT8
+            dla_enable_ = info_.dla_enable;
+            dla_core_ = info_.dla_core;
+        }
+        dump_subgraphs_ = info_.dump_subgraphs;
+        engine_cache_enable_ = info_.engine_cache_enable;
+        weight_stripped_engine_enable_ = info_.weight_stripped_engine_enable;
+        onnx_model_folder_path_ = info_.onnx_model_folder_path;
+        timing_cache_enable_ = info_.timing_cache_enable;
+        force_timing_cache_match_ = info_.force_timing_cache;
+        detailed_build_log_ = info_.detailed_build_log;
+        dump_ep_context_model_ = info_.dump_ep_context_model;
+        ep_context_file_path_ = info_.ep_context_file_path;
+        ep_context_embed_mode_ = info_.ep_context_embed_mode;
+        enable_engine_cache_for_ep_context_model();
+        if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
+            cache_path_ = info_.engine_cache_path;
+            cache_prefix_ = info_.engine_cache_prefix;
+        }
+        // use a more global cache if given
+        if (timing_cache_enable_) {
+            if (!info_.timing_cache_path.empty()) {
+                global_cache_path_ =
info_.timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } + engine_decryption_enable_ = info_.engine_decryption_enable; + if (engine_decryption_enable_) { + engine_decryption_lib_path_ = info_.engine_decryption_lib_path; + } + force_sequential_engine_build_ = info_.force_sequential_engine_build; + context_memory_sharing_enable_ = info_.context_memory_sharing_enable; + if (fp16_enable_) { + layer_norm_fp32_fallback_ = info_.layer_norm_fp32_fallback; + } + build_heuristics_enable_ = info_.build_heuristics_enable; + sparsity_enable_ = info_.sparsity_enable; + builder_optimization_level_ = info_.builder_optimization_level; + auxiliary_streams_ = info_.auxiliary_streams; + tactic_sources_ = info_.tactic_sources; + profile_min_shapes = info_.profile_min_shapes; + profile_max_shapes = info_.profile_max_shapes; + profile_opt_shapes = info_.profile_opt_shapes; + cuda_graph_enable_ = info_.cuda_graph_enable; + engine_hw_compatible_ = info_.engine_hw_compatible; + } else { + try { + // const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); + // if (!max_partition_iterations_env.empty()) { + // max_partition_iterations_ = std::stoi(max_partition_iterations_env); + // } + + // const std::string min_subgraph_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMinSubgraphSize); + // if (!min_subgraph_size_env.empty()) { + // min_subgraph_size_ = std::stoi(min_subgraph_size_env); + // } + + // const std::string max_workspace_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxWorkspaceSize); + // if (!max_workspace_size_env.empty()) { + // max_workspace_size_ = std::stoull(max_workspace_size_env); + // } + + // const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kFP16Enable); + // if (!fp16_enable_env.empty()) { + // fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + // } + + // const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8Enable); + // if (!int8_enable_env.empty()) { + // int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + // } + + // if (int8_enable_) { + // const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8CalibrationTableName); + // if (!int8_calibration_cache_name_env.empty()) { + // int8_calibration_cache_name_ = int8_calibration_cache_name_env; + // } + + // const std::string int8_use_native_tensorrt_calibration_table_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8UseNativeTensorrtCalibrationTable); + // if (!int8_use_native_tensorrt_calibration_table_env.empty()) { + // int8_use_native_tensorrt_calibration_table_ = (std::stoi(int8_use_native_tensorrt_calibration_table_env) == 0 ? false : true); + // } + // } + + // if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 + // const std::string dla_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLAEnable); + // if (!dla_enable_env.empty()) { + // dla_enable_ = (std::stoi(dla_enable_env) == 0 ? 
false : true); + // } + + // if (dla_enable_) { + // const std::string dla_core_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLACore); + // if (!dla_core_env.empty()) { + // dla_core_ = std::stoi(dla_core_env); + // } + // } + // } + + // const std::string dump_subgraphs_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpSubgraphs); + // if (!dump_subgraphs_env.empty()) { + // dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true); + // } + + // const std::string engine_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable); + // if (!engine_cache_enable_env.empty()) { + // engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true); + // } + + // const std::string weight_stripped_engine_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kWeightStrippedEngineEnable); + // if (!weight_stripped_engine_enable_env.empty()) { + // weight_stripped_engine_enable_ = std::stoi(weight_stripped_engine_enable_env) != 0; + // } + + // const std::string onnx_model_folder_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOnnxModelFolderPath); + // if (!onnx_model_folder_path_env.empty()) { + // onnx_model_folder_path_ = onnx_model_folder_path_env; + // } + + // const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable); + // if (!timing_cache_enable_env.empty()) { + // timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true); + // } + + // const std::string detailed_build_log_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDetailedBuildLog); + // if (!detailed_build_log_env.empty()) { + // detailed_build_log_ = (std::stoi(detailed_build_log_env) == 0 ? false : true); + // } + + // const std::string timing_force_match_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceTimingCache); + // if (!timing_force_match_env.empty()) { + // force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true); + // } + + // const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); + // if (!dump_ep_context_model_env.empty()) { + // dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? 
false : true); + // } + + // const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); + // if (!ep_context_file_path_env.empty()) { + // ep_context_file_path_ = ep_context_file_path_env; + // } + + // const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); + // if (!ep_context_embed_mode_env.empty()) { + // ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); + // } + // // incase the EP context is dumped the engine cache has to be enabled + // if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + // engine_cache_enable_ = true; + // } + + // enable_engine_cache_for_ep_context_model(); + + // if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + // const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); + // cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); + // cache_prefix_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePrefix); + // if (!engine_cache_path.empty() && cache_path_.empty()) { + // cache_path_ = engine_cache_path; + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; + // } + // } + // if (timing_cache_enable_) { + // std::string timing_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCachePath); + // // use a more global cache if given + // if (!timing_cache_path.empty()) { + // global_cache_path_ = timing_cache_path; + // } else { + // global_cache_path_ = cache_path_; + // } + // } + + // const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); + // if (!engine_decryption_enable_env.empty()) { + // engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? false : true); + // } + + // if (engine_decryption_enable_) { + // engine_decryption_lib_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath); + // } + + // const std::string force_sequential_engine_build_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kForceSequentialEngineBuild); + // if (!force_sequential_engine_build_env.empty()) { + // force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true); + // } + + // const std::string context_memory_sharing_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kContextMemorySharingEnable); + // if (!context_memory_sharing_enable_env.empty()) { + // context_memory_sharing_enable_ = (std::stoi(context_memory_sharing_enable_env) == 0 ? false : true); + // } + + // const std::string layer_norm_fp32_fallback_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kLayerNormFP32Fallback); + // if (!layer_norm_fp32_fallback_env.empty()) { + // layer_norm_fp32_fallback_ = (std::stoi(layer_norm_fp32_fallback_env) == 0 ? false : true); + // } + + // const std::string build_heuristics_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuildHeuristics); + // if (!build_heuristics_env.empty()) { + // build_heuristics_enable_ = (std::stoi(build_heuristics_env) == 0 ? false : true); + // } + + // const std::string sparsity_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kSparsityEnable); + // if (!sparsity_enable_env.empty()) { + // sparsity_enable_ = (std::stoi(sparsity_enable_env) == 0 ? 
false : true);
+    //   }
+
+    //   const std::string builder_optimization_level_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBuilderOptimizationLevel);
+    //   if (!builder_optimization_level_env.empty()) {
+    //     builder_optimization_level_ = std::stoi(builder_optimization_level_env);
+    //   }
+
+    //   const std::string auxiliary_streams_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kAuxiliaryStreams);
+    //   if (!auxiliary_streams_env.empty()) {
+    //     auxiliary_streams_ = std::stoi(auxiliary_streams_env);
+    //   }
+
+    //   const std::string tactic_sources_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTacticSources);
+    //   if (!tactic_sources_env.empty()) {
+    //     tactic_sources_ = tactic_sources_env;
+    //   }
+
+    //   profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes);
+    //   profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes);
+    //   profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes);
+
+    //   const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable);
+    //   if (!cuda_graph_enable_env.empty()) {
+    //     cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
+    //   }
+
+    } catch (const std::invalid_argument& ex) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
+    } catch (const std::out_of_range& ex) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Out Of Range Error (from environment variables): " << ex.what();
+    } catch (...) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Unknown Exception (from environment variables)";
+    }
+    }
+
+    // Validate settings
+    if (max_partition_iterations_ <= 0) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000";
+        max_partition_iterations_ = 1000;
+    }
+    if (min_subgraph_size_ <= 0) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_min_subgraph_size must be a positive integer value. Set it to 1";
+        min_subgraph_size_ = 1;
+    }
+    if (max_workspace_size_ <= 0) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_workspace_size must be a positive integer value. Set it to 1073741824 (1GB)";
+        max_workspace_size_ = 1 << 30;
+    }
+    if (dla_core_ < 0) {
+        // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_dla_core must be a non-negative integer value. Set it to 0";
+        dla_core_ = 0;
+    }
+
+    // If ep_context_file_path_ is provided as a directory, create it if it does not exist
+    if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) {
+        if (!std::filesystem::create_directory(ep_context_file_path_)) {
+            throw std::runtime_error("Failed to create directory " + ep_context_file_path_);
+        }
+    }
+
+    // If dump_ep_context_model_ is enabled, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_.
+    // For example,
+    // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir"
+    // - original cache path = "" -> new cache path = "./context_model_dir"
+    // The new cache path will be saved as the "ep_cache_context" node attribute of the EP context node.
+    // For security reasons, the engine cache must be saved inside the context model directory.
+ if (dump_ep_context_model_ && engine_cache_enable_) { + if (IsAbsolutePath(cache_path_)) { + // LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " << cache_path_; + } + if (IsRelativePathToParentPath(cache_path_)) { + // LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + } + + // Engine cache relative path to context model directory. + // It's used when dumping the "ep_cache_context" node attribute. + engine_cache_relative_path_to_context_model_dir = cache_path_; + + // Make cache_path_ to be the relative path of ep_context_file_path_ + cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); + } + + // Hardware compatibility: pre-check on environment + if (engine_cache_enable_ && engine_hw_compatible_) { +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + if (std::stoi(compute_capability_) < 80) { + // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as GPU arch < 80. "; + engine_hw_compatible_ = false; + } else if (std::stoi(compute_capability_) == 87) { + // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled on Jetson Orin. "; + engine_hw_compatible_ = false; + } +#else + // LOGS_DEFAULT(WARNING) << "Engine hardware compatibility cannot be enabled as TRT < 8.6. "; + engine_hw_compatible_ = false; +#endif + } + + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { + if (!fs::create_directory(cache_path_)) { + throw std::runtime_error("Failed to create directory " + cache_path_); + } + } + if (!global_cache_path_.empty() && !fs::is_directory(global_cache_path_)) { + if (!fs::create_directory(global_cache_path_)) { + throw std::runtime_error("Failed to create directory " + global_cache_path_); + } + } + } + + if (engine_decryption_enable_) { + LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); + if (handle == nullptr) { + // TODO(yang) + // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + // "TensorRT EP could not open shared library from " + engine_decryption_lib_path_)); + } + engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); + engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); + if (engine_decryption_ == nullptr) { + // TODO(yang) + // ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + // "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + } + } + + if (int8_enable_) { + int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + } + + /* + * Parse explicit min/max/opt profile shapes from provider options. + * + * The format of min/max/opt profile shapes is defined as below: + * "input1:dim1xdim2...,input2:dim1xdim2...,...,input1:dim3xdim4...,input2:dim3xdim4...,..." + * + * (Note: if multiple shapes with same input name are specified, TRT EP will consider them as multiple profiles. 
+ * Please refer to ParseProfileShapes() for more details)
+ *
+ */
+    bool status = true;
+    // if (status) {
+    //   status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_);
+    //   if (!status) {
+    //     profile_min_shapes_.clear();
+    //     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdim2...,input2:dim1xdim2...,...'";
+    //   }
+    // }
+
+    // if (status) {
+    //   status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_);
+    //   if (!status) {
+    //     profile_max_shapes_.clear();
+    //     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdim2...,input2:dim1xdim2...,...'";
+    //   }
+    // }
+
+    // if (status) {
+    //   status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_);
+    //   if (!status) {
+    //     profile_opt_shapes_.clear();
+    //     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdim2...,input2:dim1xdim2...,...'";
+    //   }
+    // }
+
+    // if (status) {
+    //   status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_);
+    //   if (!status) {
+    //     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have the same input names and number of profiles.";
+    //     // LOGS_DEFAULT(WARNING) << "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you.";
+    //     profile_min_shapes_.clear();
+    //     profile_max_shapes_.clear();
+    //     profile_opt_shapes_.clear();
+    //   }
+    // }
+
+    // cuda graph:
+    // cudaStreamSynchronize() is not allowed in cuda graph capture.
+    //
+    // external stream:
+    // If the user provides an "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently.
+    // So, no need to synchronize different streams after enqueueV3.
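+    // (For reference, a capture sequence of the kind this flag protects -- between begin
+    // and end capture every call on the stream is recorded rather than executed, and no
+    // synchronization is allowed; trt_context stands in for the TRT execution context:
+    //
+    //   cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
+    //   trt_context->enqueueV3(stream);             // recorded into the CUDA graph
+    //   cudaStreamEndCapture(stream, &cuda_graph);  // no cudaStreamSynchronize() in between
+    // )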
+ if (cuda_graph_enable_ || external_stream_) { + sync_stream_after_enqueue_ = false; + } + + { + // auto lock = GetApiLock(); // TODO(leca) + runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); + } } TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 2a12de3c10f9c..52eca02531743 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -14,6 +14,50 @@ #endif namespace onnxruntime { + +namespace tensorrt_env_vars { +static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_ITERATIONS"; +static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; +static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; +static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; +static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE"; +static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"; +static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE"; +static const std::string kDLAEnable = "ORT_TENSORRT_DLA_ENABLE"; +static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE"; +static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS"; +static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"; +static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH"; +static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE"; +static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH"; +// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path +static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH"; +static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; +static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; +static const std::string kForceSequentialEngineBuild = "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD"; +static const std::string kContextMemorySharingEnable = "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE"; +static const std::string kLayerNormFP32Fallback = "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK"; +static const std::string kTimingCacheEnable = "ORT_TENSORRT_TIMING_CACHE_ENABLE"; +static const std::string kForceTimingCache = "ORT_TENSORRT_FORCE_TIMING_CACHE_ENABLE"; +static const std::string kDetailedBuildLog = "ORT_TENSORRT_DETAILED_BUILD_LOG_ENABLE"; +static const std::string kBuildHeuristics = "ORT_TENSORRT_BUILD_HEURISTICS_ENABLE"; +static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE"; +static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL"; +static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS"; +static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES"; +static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS"; +static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES"; +static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES"; +static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES"; +static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE"; 
+static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL"; +static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE"; +static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE"; +static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; +// Old env variable for backward compatibility +static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; +} // namespace tensorrt_env_vars + using HashValue = uint64_t; using AllocateFunc = void* (*)(void*, size_t, size_t); using DestroyFunc = void (*)(void*, void*); diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.cc b/samples/tensorRTEp/tensorrt_execution_provider_info.cc index a6caab6642662..55f8d331ce078 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_info.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider_info.cc @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -//#include "core/providers/tensorrt/tensorrt_execution_provider_info.h" -//#include "core/providers/tensorrt/tensorrt_provider_options.h" -// -//#include "core/common/make_string.h" -//#include "core/common/parse_string.h" -//#include "core/framework/provider_options_utils.h" -//#include "core/providers/cuda/cuda_common.h" +#include "tensorrt_execution_provider_info.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" + +#include "core/common/make_string.h" +#include "core/common/parse_string.h" +#include "core/framework/provider_options_utils.h" +// #include "onnxruntime/core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace tensorrt { diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.h b/samples/tensorRTEp/tensorrt_execution_provider_info.h index f64e3a972e01b..4a021a032c825 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_info.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_info.h @@ -3,11 +3,11 @@ #pragma once -//#include "core/framework/ortdevice.h" -//#include "core/framework/provider_options.h" -//#include "core/framework/framework_provider_common.h" -//#include "core/session/onnxruntime_c_api.h" -//#include "core/framework/library_handles.h" +#include "core/framework/ortdevice.h" +#include "core/framework/provider_options.h" +#include "core/framework/framework_provider_common.h" +#include "core/session/onnxruntime_c_api.h" +// #include "core/framework/library_handles.h" #define TRT_DEFAULT_OPTIMIZER_LEVEL 3 From 36f97b53edd27c424ae6e942b90d6e3f4008d2c2 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 29 Aug 2024 18:11:11 +0000 Subject: [PATCH 28/81] relu can work on out tree TRT now --- samples/c_test/test.cpp | 9 ++++----- samples/tensorRTEp/tensorrt_execution_provider.cc | 3 ++- samples/tensorRTEp/tensorrt_execution_provider_info.cc | 6 ------ samples/tensorRTEp/tensorrt_execution_provider_info.h | 5 +---- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 5b03221ef9d41..0bc54782169e5 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -23,11 +23,10 @@ void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - 
From 36f97b53edd27c424ae6e942b90d6e3f4008d2c2 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Thu, 29 Aug 2024 18:11:11 +0000
Subject: [PATCH 28/81] Relu can work on out-of-tree TRT now

---
 samples/c_test/test.cpp                                | 9 ++++-----
 samples/tensorRTEp/tensorrt_execution_provider.cc      | 3 ++-
 samples/tensorRTEp/tensorrt_execution_provider_info.cc | 6 ------
 samples/tensorRTEp/tensorrt_execution_provider_info.h  | 5 +----
 4 files changed, 7 insertions(+), 16 deletions(-)

diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
index 5b03221ef9d41..0bc54782169e5 100644
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -23,11 +23,10 @@ void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
   std::vector<const char*> keys{"int_property", "str_property"}, values{"3", "strvalue"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size()));

-  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
-  THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options));
-  THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options));
-
-  g_ort->ReleaseCUDAProviderOptions(cuda_options);
+//  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+//  THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options));
+//  THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options));
+//  g_ort->ReleaseCUDAProviderOptions(cuda_options);
 }

 void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) {
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index cfd366732c896..be7d2a9375a44 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <iostream>
 #include
 #include "core/session/onnxruntime_cxx_api.h"  // TODO(leca): we should be able to use cxx APIs which are built upon C API
 #include "tensorrt_execution_provider.h"
@@ -3269,7 +3270,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
 //        IncrementRegularRunCountBeforeGraphCapture();
 //      }
 //    }
-
+      std::cout << "end of ComputeFunc in TRTEp's CreateNodeComputeInfoFromGraph()\n";
       return nullptr;
     };
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.cc b/samples/tensorRTEp/tensorrt_execution_provider_info.cc
index 55f8d331ce078..5349e4c879e81 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider_info.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider_info.cc
@@ -2,12 +2,6 @@
 // Licensed under the MIT License.

 #include "tensorrt_execution_provider_info.h"
-#include "core/providers/tensorrt/tensorrt_provider_options.h"
-
-#include "core/common/make_string.h"
-#include "core/common/parse_string.h"
-#include "core/framework/provider_options_utils.h"
-// #include "onnxruntime/core/providers/cuda/cuda_common.h"

 namespace onnxruntime {
 namespace tensorrt {
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.h b/samples/tensorRTEp/tensorrt_execution_provider_info.h
index 4a021a032c825..92a14daf539e8 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider_info.h
+++ b/samples/tensorRTEp/tensorrt_execution_provider_info.h
@@ -3,11 +3,8 @@

 #pragma once

-#include "core/framework/ortdevice.h"
+#include
 #include "core/framework/provider_options.h"
-#include "core/framework/framework_provider_common.h"
-#include "core/session/onnxruntime_c_api.h"
-// #include "core/framework/library_handles.h"

 #define TRT_DEFAULT_OPTIMIZER_LEVEL 3
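Note: the patch that follows rebuilds an onnx::ModelProto from the new graph C API. The motivation is that a serialized model is what TensorRT's ONNX parser consumes when it reports per-node support. A sketch of that hand-off, assuming nvonnxparser's supportsModel() entry point (newer TensorRT releases also add supportsModelV2) and a model proto populated as in the patch below:

    #include <string>
    #include "NvOnnxParser.h"
    #include "onnx/onnx_pb.h"

    // Assumes an IParser already created against the target network/logger.
    // SubGraphCollection_t comes from NvOnnxParser.h and pairs node-index
    // lists with a supported/unsupported flag.
    bool IsSupportedByTrt(nvonnxparser::IParser& parser, const onnx::ModelProto& m,
                          SubGraphCollection_t& supported) {
      std::string bytes;
      m.SerializeToString(&bytes);
      return parser.supportsModel(bytes.data(), bytes.size(), supported);
    }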
"../../build/Linux/Debug/_deps/protobuf-src/src") -# + "../../build/Linux/Debug/_deps/gsl-src/include" + "../../build/Linux/Debug/_deps/onnx-src" + "../../build/Linux/Debug/_deps/onnx-build" + "../../build/Linux/Debug/_deps/protobuf-src/src") + ## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer_plugin.so" "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so" "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/flatbuffers-build/libflatbuffers.a" - CUDA::cudart) -# "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" -# "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" -# "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" -# "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotocd.a") + CUDA::cudart + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" + "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotocd.a") diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index be7d2a9375a44..3b67578223d3f 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -9,6 +9,7 @@ #include "tensorrt_execution_provider_utils.h" #include "tensorrt_cuda_allocator.h" #include "onnx_ctx_model_helper.h" +#include "onnx/onnx_pb.h" #ifdef _WIN32 #include @@ -3617,6 +3618,59 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (group.second) { nodes_list_output.push_back(group); } else { + onnx::ModelProto m; + m.set_ir_version(3); + onnx::OperatorSetIdProto* p = m.add_opset_import(); + p->set_domain(""); + p->set_version(10); + onnx::GraphProto* g = m.mutable_graph(); + for (size_t i = 0; i < nodes_count; i++) { + onnx::NodeProto* n = g->add_node(); + const OrtNode* node = nullptr; + api_->OrtGraph_GetOrtNode(graph, i, &node); + + const char* op_type = nullptr; + api_->OrtNode_GetOpType(node, &op_type); + n->set_op_type(op_type); + + const char* name = nullptr; + api_->OrtNode_GetName(node, &name); + n->set_name(name); + + // TODO(leca): Implicit input? 
diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index be7d2a9375a44..3b67578223d3f 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -9,6 +9,7 @@
 #include "tensorrt_execution_provider_utils.h"
 #include "tensorrt_cuda_allocator.h"
 #include "onnx_ctx_model_helper.h"
+#include "onnx/onnx_pb.h"

 #ifdef _WIN32
 #include
@@ -3617,6 +3618,59 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
     if (group.second) {
       nodes_list_output.push_back(group);
     } else {
+      onnx::ModelProto m;
+      m.set_ir_version(3);
+      onnx::OperatorSetIdProto* p = m.add_opset_import();
+      p->set_domain("");
+      p->set_version(10);
+      onnx::GraphProto* g = m.mutable_graph();
+      for (size_t i = 0; i < nodes_count; i++) {
+        onnx::NodeProto* n = g->add_node();
+        const OrtNode* node = nullptr;
+        api_->OrtGraph_GetOrtNode(graph, i, &node);
+
+        const char* op_type = nullptr;
+        api_->OrtNode_GetOpType(node, &op_type);
+        n->set_op_type(op_type);
+
+        const char* name = nullptr;
+        api_->OrtNode_GetName(node, &name);
+        n->set_name(name);
+
+        // TODO(leca): Implicit input? & attributes
+        size_t input_size = 0;
+        api_->OrtNode_GetInputSize(node, &input_size);
+        for (size_t j = 0; j < input_size; j++) {
+          const char* jth_input_name = nullptr;
+          api_->OrtNode_GetIthInputName(node, j, &jth_input_name);
+          n->add_input(jth_input_name, strlen(jth_input_name));
+        }
+
+        size_t output_size = 0;
+        api_->OrtNode_GetOutputSize(node, &output_size);
+        for (size_t j = 0; j < output_size; j++) {
+          const char* jth_output_name = nullptr;
+          api_->OrtNode_GetIthOutputName(node, j, &jth_output_name);
+          n->add_output(jth_output_name, strlen(jth_output_name));
+        }
+      }
+
+      // TODO(leca): set_elem_type, set_dim_value for graph input and output
+      size_t graph_inputs = 0;
+      const char** graph_input_names = nullptr;
+      api_->OrtGraph_GetInputsIncludingInitializers(graph, &graph_inputs, &graph_input_names);
+      for (size_t i = 0; i < graph_inputs; i++) {
+        onnx::ValueInfoProto* input = g->add_input();
+        input->set_name(graph_input_names[i]);
+      }
+
+      size_t graph_outputs = api_->OrtGraph_GetOutputSize(graph);
+      for (size_t i = 0; i < graph_outputs; i++) {
+        onnx::ValueInfoProto* output = g->add_output();
+        output->set_name(api_->OrtGraph_GetIthOutputName(graph, i));
+        output->mutable_type()->mutable_tensor_type()->set_elem_type(api_->OrtGraph_GetIthOutputElemType(graph, i));
+      }
+
       // auto model_build = graph.CreateModel(*GetLogger());
       // auto& graph_build = model_build->MainGraph();
       // bool has_control_flow_op = false;

From 4ad6993dceab090fe42546f0f35b0dcb165c1744 Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Mon, 2 Sep 2024 17:48:43 +0800
Subject: [PATCH 30/81] complete the GetCapability (#21956)

---
 .../tensorRTEp/tensorrt_execution_provider.cc | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index 3b67578223d3f..b59933086e5f2 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -1468,19 +1468,22 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const

     if (all_subgraphs_are_supported) {
-      // for (const auto& group : supported_nodes_vector) {
-      //   if (!group.first.empty()) {
-      //     *cnt = group.first.size();
-      //     *indexed_sub_graph = new OrtIndexedSubGraph* [group.first.size()];
-      //     int i = 0;
-      //     for (const auto& index : group.first) {
-      //       (*indexed_sub_graph)[i]->node_index_len = 1;
-      //       (*indexed_sub_graph)[i]->node_index = new size_t [(*indexed_sub_graph)[i]->node_index_len];
-      //       (*indexed_sub_graph)[i]->node_index[0] = node_index[index];
-      //       i++;
-      //     }
-      //   }
-      // }
+      for (const auto& group : supported_nodes_vector) {
+        if (!group.first.empty()) {
+          for (const auto& index : group.first) {
+            std::unique_ptr<OrtIndexedSubGraph> sub_graph = std::make_unique<OrtIndexedSubGraph>();
+            sub_graph->node_index_len = 1;
+            sub_graph->node_index = new size_t [sub_graph->node_index_len];
+            sub_graph->node_index[0] = nodes_index[index];
+            cache.push_back(sub_graph.release());
+          }
+        }
+      }
+      *cnt = cache.size();
+      *indexed_sub_graph = new OrtIndexedSubGraph* [*cnt];
+      for (size_t i = 0; i < *cnt; i++) {
+        (*indexed_sub_graph)[i] = cache[i];
+      }
       // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
       return;
     }
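Note: patch 30 above fills in the GetCapability() output contract: heap-allocated OrtIndexedSubGraph entries whose ownership passes to the runtime through the out-parameters. A condensed sketch of the same contract (struct layout taken from these patches, not a definitive API), claiming all given node indices as one fused unit; meta_def is left null just as the single-node path above leaves it:

    #include <cstddef>

    // Returns one OrtIndexedSubGraph covering node_count node indices.
    // The caller (the ONNX Runtime core) takes ownership of all allocations.
    void ClaimAllNodes(const size_t* node_indices, size_t node_count,
                       size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
      OrtIndexedSubGraph* sub_graph = new OrtIndexedSubGraph();
      sub_graph->node_index_len = node_count;
      sub_graph->node_index = new size_t[node_count];
      for (size_t i = 0; i < node_count; i++) sub_graph->node_index[i] = node_indices[i];
      *cnt = 1;
      *indexed_sub_graph = new OrtIndexedSubGraph*[1];
      (*indexed_sub_graph)[0] = sub_graph;
    }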
From 53c736f1011dc698db70078d8efc19e74ecfbb81 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Wed, 4 Sep 2024 18:21:30 +0000
Subject: [PATCH 31/81] Chi's fix and reorder EPs for registering shared
 resources

---
 .../core/session/onnxruntime_c_api.h          |  2 +-
 onnxruntime/core/framework/provider_adapter.h | 13 +++++++---
 onnxruntime/core/framework/session_state.cc   | 26 ++++++++++++++++---
 samples/c_test/test.cpp                       |  1 +
 4 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 89c8471344ff7..80ecf138a9fd8 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -745,7 +745,7 @@ typedef struct OrtNodeComputeInfo {

 typedef struct OrtExecutionProvider {
 #ifdef __cplusplus
-  OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr},
+  OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr},
                            extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {}
 #endif
   void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***);
diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h
index 5a750cbbfecec..7bae1c91fe87e 100644
--- a/onnxruntime/core/framework/provider_adapter.h
+++ b/onnxruntime/core/framework/provider_adapter.h
@@ -4,6 +4,9 @@
 #pragma once
 #include "core/session/onnxruntime_c_api.h"
 #include "core/framework/compute_capability.h"
+#include "core/framework/error_code_helper.h"
+#include "core/framework/kernel_registry.h"
+#include "core/session/allocator_adapters.h"

 namespace onnxruntime {

@@ -129,10 +132,12 @@ class ExecutionProviderAdapter : public IExecutionProvider {

   virtual std::vector<AllocatorPtr> CreatePreferredAllocators() override {
     std::vector<AllocatorPtr> ret;
-    OrtAllocator** ort_allocators = nullptr;
-    int cnt = ep_impl_->CreatePreferredAllocators(ep_impl_, &ort_allocators);
-    for (int i = 0; i < cnt; i++) {
-      ret.push_back(std::make_shared<IAllocatorImplWrappingOrtAllocator>(ort_allocators[i]));
+    if (ep_impl_->CreatePreferredAllocators) {
+      OrtAllocator** ort_allocators = nullptr;
+      int cnt = ep_impl_->CreatePreferredAllocators(ep_impl_, &ort_allocators);
+      for (int i = 0; i < cnt; i++) {
+        ret.push_back(std::make_shared<IAllocatorImplWrappingOrtAllocator>(ort_allocators[i]));
+      }
     }
     return ret;
   }
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 42fb7b392283a..142055d0c37b1 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -13,6 +13,7 @@
 #include "core/framework/node_index_info.h"
 #include "core/framework/op_kernel.h"
 #include "core/framework/ort_value_pattern_planner.h"
+#include "core/framework/provider_adapter.h"
 #include "core/framework/session_state_utils.h"
 #include "core/framework/utils.h"
 #include "core/providers/cpu/controlflow/utils.h"
@@ -61,6 +62,17 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry {
 };
 #endif

+#ifdef ORT_ENABLE_STREAM
+static std::string ShouldPostPoneRegisterResourceFor(IExecutionProvider* ep, const ExecutionProviders& all_ep) {
+  ExecutionProviderAdapter* ep_adapter = dynamic_cast<ExecutionProviderAdapter*>(ep);
+  if (ep_adapter == nullptr) return "";  // TODO(leca): or add a member property for performance?
+  for (auto& any_ep : all_ep) {
+    if (any_ep->Type() != ep->Type() && any_ep->GetOrtDeviceByMemType(OrtMemTypeDefault) == ep->GetOrtDeviceByMemType(OrtMemTypeDefault)) return any_ep->Type();
+  }
+  return "";
+}
+#endif
+
 SessionState::SessionState(Graph& graph,
                            const ExecutionProviders& execution_providers,
                            concurrency::ThreadPool* thread_pool,
@@ -1367,9 +1379,17 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string
-    ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_);
+  std::string register_resource_after = "";
+  IExecutionProvider* out_tree_ep = nullptr;
+  for (auto& ep : execution_providers_) {
+    if (register_resource_after == "") {
+      register_resource_after = ShouldPostPoneRegisterResourceFor(ep.get(), execution_providers_);
+      if (register_resource_after == "") ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_);
+      else out_tree_ep = ep.get();
+    } else {
+      ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_);
+      if (register_resource_after == ep->Type()) out_tree_ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_);
+    }
   }
 #endif
diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
index 0bc54782169e5..f34f3f8429456 100644
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -78,5 +78,6 @@ int main() {
   std::cout<<"Result:\n";
   for (size_t i = 0; i < 4; i++) std::cout << output_tensor_data[i] << " ";
+  g_ort->ReleaseEnv(p_env);
   return 0;
 }
From 5fcb972cccff03854bf6a074a0a91e9ceee80a14 Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Thu, 5 Sep 2024 18:56:08 +0800
Subject: [PATCH 32/81] complete the GetSubGraph (#21998)

---
 .../tensorRTEp/tensorrt_execution_provider.cc | 139 +++++++-----------
 1 file changed, 54 insertions(+), 85 deletions(-)

diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index b59933086e5f2..433c4f4c8d0f1 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -1135,8 +1135,9 @@ std::unique_ptr<OrtIndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGr
   sub_graph->node_index_len = graph_nodes_index.first.size();
   sub_graph->node_index = new size_t [sub_graph->node_index_len];
   sub_graph->meta_def = new OrtMetaDef();
-  std::unordered_map<std::string, int> fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
   std::unordered_set<std::string> erased;
+  std::vector<std::string> inputs;
+  std::vector<std::string> outputs;
   int input_order = 0;
   int output_order = 0;
@@ -1157,13 +1158,18 @@ std::unique_ptr<OrtIndexedSubGraph>
TensorrtExecutionProvider::GetSubGraph(SubGr initializers.push_back(input_name); continue; } - const auto& it = fused_outputs.find(input_name); - if (it != fused_outputs.end()) { - fused_outputs.erase(it); - erased.insert(input_name); - } else if (erased.find(input_name) == erased.end()) { - // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input_name] = input_order++; - } - } - -// // For output searching, there are two special cases, -// // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, -// // if the output is connected to nodes that don't belong to the subgraph, the output need to be added -// // to the output list -// // The other one is, if subgraph's node output is parent graph's output. the node output should -// // be also added to the subgraph's output list -// if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { -// for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { -// const auto& node_idx = it->GetNode().Index(); -// const onnxruntime::NodeArg* output; -// // The dst_arg_index from GetDstArgIndex() could be the index for explicit/implicit input defs of the node. -// // We need to get the correct input index accordingly. (See Graph::BuildConnections() in graph.cc for more details) -// if (it->GetDstArgIndex() < static_cast(it->GetNode().InputDefs().size())) { -// output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; -// } else { -// output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; -// } -// if (node_set.find(node_idx) != node_set.end()) { -// const auto& iter = fused_inputs.find(output); -// if (iter != fused_inputs.end()) { -// fused_inputs.erase(iter); -// erased.insert(output); -// } else if (erased.find(output) == erased.end()) { -// if (graph_output_names.find(output->Name()) != graph_output_names.end()) { -// graph_outputs_to_add[output] = output_order; -// } -// fused_outputs[output] = output_order++; -// } -// } else { -// fused_outputs_to_add[output] = output_order++; -// } -// } -// } else { - size_t output_size = 0; - api_->OrtNode_GetOutputSize(node, &output_size); - for (size_t j = 0; j < output_size; j++) { - const char* output_name = nullptr; - api_->OrtNode_GetIthOutputName(node, j, &output_name); - const auto& it = fused_inputs.find(output_name); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output_name); - } - // Only when output is neither in input list nor erased list, add the output to output list - else if (erased.find(output_name) == erased.end()) { - if (graph_output_names.find(output_name) != graph_output_names.end()) { - graph_outputs_to_add[output_name] = output_order; - } - fused_outputs[output_name] = output_order++; - } + const OrtNode* producer = nullptr; + api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); + // If the input is not produced by any node, it is a graph input + if (producer == nullptr) { + inputs.push_back(input_name); + continue; } -// } - } - - fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); - fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - - // Sort inputs and outputs by the order they were added - std::multimap inputs, outputs; - for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); - } + size_t producer_index = 
0; + api_->OrtNode_GetIndex(producer, &producer_index); + // If the producer node is not in the subgraph, the input is a graph input + if (node_set.find(producer_index) == node_set.end()) { + inputs.push_back(input_name); + } + } - for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { - outputs.insert(std::pair(it->second, it->first)); + size_t output_size = 0; + api_->OrtNode_GetOutputSize(node, &output_size); + for (size_t j = 0; j < output_size; j++) { + const char* output_name = nullptr; + api_->OrtNode_GetIthOutputName(node, j, &output_name); + // If the output is the graph output, it is a subgraph output + if (graph_output_names.find(output_name) != graph_output_names.end()) { + outputs.push_back(output_name); + continue; + } + size_t consumer_count = 0; + const OrtNode** consumers = nullptr; + api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumer_count, &consumers); + for (size_t k = 0; k < consumer_count; k++) { + size_t consumer_index = 0; + api_->OrtNode_GetIndex(consumers[k], &consumer_index); + // If the consumer node is not in the subgraph, the output is a subgraph output + if (node_set.find(consumer_index) == node_set.end()) { + outputs.push_back(output_name); + break; + } + } + } } // Generate unique kernel name for TRT subgraph @@ -1270,8 +1239,8 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr sub_graph->meta_def->inputs = new char* [sub_graph->meta_def->input_len]; i = 0; for (const auto& input : inputs) { - sub_graph->meta_def->inputs[i] = new char [input.second.length() + 1]; - strcpy(sub_graph->meta_def->inputs[i++], input.second.c_str()); + sub_graph->meta_def->inputs[i] = new char [input.length() + 1]; + strcpy(sub_graph->meta_def->inputs[i++], input.c_str()); } sub_graph->meta_def->initializer_len = initializers.size(); @@ -1286,8 +1255,8 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr sub_graph->meta_def->outputs = new char* [sub_graph->meta_def->output_len]; i = 0; for (const auto& output : outputs) { - sub_graph->meta_def->outputs[i] = new char [output.second.length() + 1]; - strcpy(sub_graph->meta_def->outputs[i++], output.second.c_str()); + sub_graph->meta_def->outputs[i] = new char [output.length() + 1]; + strcpy(sub_graph->meta_def->outputs[i++], output.c_str()); } sub_graph->meta_def->domain = "com.microsoft"; From c3bb437d754e219e5fcf23bd7f34ef34018f23f6 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Fri, 6 Sep 2024 00:27:23 +0000 Subject: [PATCH 33/81] run resnet18v1_7, crash on GetSubGraph() --- samples/c_test/test.cpp | 58 +++++++++++++++---- .../tensorRTEp/tensorrt_execution_provider.cc | 4 +- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index f34f3f8429456..b85b90b369a6e 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -42,19 +42,38 @@ void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) { g_ort->ReleaseTensorRTProviderOptions(tensorrt_options); } -int main() { - const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - OrtEnv* p_env = nullptr; - OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; - THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); - OrtSessionOptions* so = nullptr; - THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); +void RunResnet18v1_7(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { + // download resnet18-v1-7 model at: + // 
https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet18-v1-7.tar.gz + OrtSession* session = nullptr; + THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session)); + + const int input_data_cnt = 3 * 224 * 224; + float input_data[input_data_cnt]; + for (int i = 0; i < input_data_cnt; i++) { + input_data[i] = -1 + static_cast(rand()) / (static_cast(RAND_MAX/(2))); // [-1, 1) uniform distribution + } + const size_t input_len = input_data_cnt * sizeof(float); + const int64_t input_shape[] = {1, 3, 224, 224}; + const size_t shape_len = sizeof(input_shape)/sizeof(input_shape[0]); - //TestCompileBasedEp(g_ort, p_env, so); - //TestKernelBasedEp(g_ort, p_env, so); - TestTensorRTEp(g_ort, p_env, so); - //TestOriginalTensorRTEp(g_ort, so); + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + OrtValue* input_tensor = nullptr; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); + const char* input_names[] = {"data"}; + const char* output_names[] = {"resnetv15_dense0_fwd"}; + OrtValue* output_tensor = nullptr; + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, 1, output_names, 1, &output_tensor)); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); + std::cout<<"Result:\n"; + for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session)); @@ -77,6 +96,23 @@ int main() { THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); std::cout<<"Result:\n"; for (size_t i = 0; i < 4; i++) std::cout<GetApi(ORT_API_VERSION); + OrtEnv* p_env = nullptr; + OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; + THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); + OrtSessionOptions* so = nullptr; + THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); + + //TestCompileBasedEp(g_ort, p_env, so); + //TestKernelBasedEp(g_ort, p_env, so); + TestTensorRTEp(g_ort, p_env, so); + //TestOriginalTensorRTEp(g_ort, so); + + //RunRelu(g_ort, p_env, so); + RunResnet18v1_7(g_ort, p_env, so); g_ort->ReleaseEnv(p_env); return 0; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index b59933086e5f2..4e134093c908d 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -3621,7 +3621,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (group.second) { nodes_list_output.push_back(group); } else { - onnx::ModelProto m; +/* onnx::ModelProto m; m.set_ir_version(3); onnx::OperatorSetIdProto* p = m.add_opset_import(); p->set_domain(""); @@ -3673,7 +3673,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect output->set_name(api_->OrtGraph_GetIthOutputName(graph, i)); output->mutable_type()->mutable_tensor_type()->set_elem_type(api_->OrtGraph_GetIthOutputElemType(graph, i)); } - +*/ // auto model_build = graph.CreateModel(*GetLogger()); // auto& graph_build = model_build->MainGraph(); // bool has_control_flow_op = false; From 3efac9775d326b574464bcbb0c83134caa0b89f4 Mon Sep 17 
00:00:00 2001 From: Lei Cao Date: Fri, 6 Sep 2024 23:52:31 +0000 Subject: [PATCH 34/81] resnet18-v1-7 works for TRT EP, with next_nodes_list assignment commented out in GetSupportedList() --- samples/c_test/test.cpp | 2 +- .../tensorRTEp/tensorrt_execution_provider.cc | 17 +++++++++-------- .../tensorrt_execution_provider_info.cc | 3 +++ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index b85b90b369a6e..94fa534d9f704 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -20,7 +20,7 @@ void TestKernelBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); - std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; + std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); // OrtCUDAProviderOptionsV2* cuda_options = nullptr; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 31ab8529edb6a..1e0aa77f56063 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1578,15 +1578,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const type = ep_type; create_stream = new OrtCreateStream(); + create_stream->device_type = 1; // GPU create_stream->CreateStreamFunc = [](const OrtDevice* device) -> void* { cudaStream_t stream = nullptr; cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); return stream; }; - api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, 0, &default_device); - info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info); + device_id_ = info_.device_id; + api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -3590,7 +3591,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (group.second) { nodes_list_output.push_back(group); } else { -/* onnx::ModelProto m; + onnx::ModelProto m; m.set_ir_version(3); onnx::OperatorSetIdProto* p = m.add_opset_import(); p->set_domain(""); @@ -3599,7 +3600,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect for (size_t i = 0; i < nodes_count; i++) { onnx::NodeProto* n = g->add_node(); const OrtNode* node = nullptr; - api_->OrtGraph_GetOrtNode(graph, i, &node); + api_->OrtGraph_GetOrtNode(graph, node_index[i], &node); const char* op_type = nullptr; api_->OrtNode_GetOpType(node, &op_type); @@ -3642,7 +3643,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect output->set_name(api_->OrtGraph_GetIthOutputName(graph, i)); output->mutable_type()->mutable_tensor_type()->set_elem_type(api_->OrtGraph_GetIthOutputElemType(graph, i)); } -*/ + // auto model_build = graph.CreateModel(*GetLogger()); // auto& graph_build = model_build->MainGraph(); // bool has_control_flow_op = false; @@ -3816,9 +3817,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect 
api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &subgraph_node_count, &subgraph_node_index); next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { - for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; - } +// for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { +// next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; +// } nodes_list_output.push_back(next_nodes_list[i]); } } diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.cc b/samples/tensorRTEp/tensorrt_execution_provider_info.cc index 5349e4c879e81..4d3030a365f18 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_info.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider_info.cc @@ -54,6 +54,9 @@ constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { TensorrtExecutionProviderInfo info{}; + for (const auto& [k, v] : options) { + if (k == "device_id") info.device_id = std::atoi(v.c_str()); + } // void* user_compute_stream = nullptr; // ORT_THROW_IF_ERROR( // ProviderOptionsParser{} From 766fec9e940259cdfc5dbc630842f8ad1f2d91e5 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 9 Sep 2024 23:31:40 +0000 Subject: [PATCH 35/81] test cases for decoder and fast_rcnn, delete dynamic_cast in ShouldPostPoneRegisterResourceFor() --- .../core/framework/execution_provider.h | 4 + onnxruntime/core/framework/provider_adapter.h | 1 + onnxruntime/core/framework/session_state.cc | 3 +- samples/c_test/test.cpp | 85 +++++++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 49c3d1bdd088a..da0c665cb2c1b 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -75,6 +75,8 @@ class IExecutionProvider { */ const OrtDevice default_device_; + bool intree_ep = true; + public: virtual ~IExecutionProvider() = default; @@ -325,6 +327,8 @@ class IExecutionProvider { return InlinedVector(); } + bool IsIntreeEp() const { return intree_ep; } + private: const std::string type_; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 7bae1c91fe87e..a485e7be82433 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -37,6 +37,7 @@ class DataTransferAdapter : public IDataTransfer { class ExecutionProviderAdapter : public IExecutionProvider { public: ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type, ep->default_device ? 
*(ep->default_device) : OrtDevice()), ep_impl_(ep) { + intree_ep = false; if (ep_impl_->RegisterKernels) { kernel_registry_ = std::make_shared(); ep_impl_->RegisterKernels(reinterpret_cast(kernel_registry_.get())); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 142055d0c37b1..437c04d758931 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -64,8 +64,7 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry { #ifdef ORT_ENABLE_STREAM static std::string ShouldPostPoneRegisterResourceFor(IExecutionProvider* ep, const ExecutionProviders& all_ep) { - ExecutionProviderAdapter* ep_adapter = dynamic_cast(ep); - if (ep_adapter == nullptr) return ""; // TODO(leca): or add a member property for performance? + if (ep->IsIntreeEp()) return ""; // TODO(leca): Or use dynamic_cast to check is it ExecutionProviderAdapter instance? Need to disable onnxruntime_DISABLE_RTTI for (auto& any_ep : all_ep) { if (any_ep->Type() != ep->Type() && any_ep->GetOrtDeviceByMemType(OrtMemTypeDefault) == ep->GetOrtDeviceByMemType(OrtMemTypeDefault)) return any_ep->Type(); } diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 94fa534d9f704..49e6e2d30ea37 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -98,6 +98,91 @@ void RunRelu(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/decoder/decoder.onnx", so, &session)); + + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + std::vector input_tensors(28, nullptr); + + const int input_0_cnt = 16; + int64_t input_0_data[input_0_cnt]; + for (int i = 0; i < input_0_cnt; i++) input_0_data[i] = static_cast(rand()); + const size_t input_0_len = input_0_cnt * sizeof(int64_t); + const int64_t input_0_shape[] = {16, 1}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_0_data, input_0_len, input_0_shape, sizeof(input_0_shape)/sizeof(input_0_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &input_tensors[0])); + + const int input_1_cnt = 16; + bool input_1_data[input_1_cnt]; + for (int i = 0; i < input_1_cnt; i++) input_1_data[i] = false; + const size_t input_1_len = input_1_cnt * sizeof(bool); + const int64_t input_1_shape[] = {16, 1}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_1_data, input_1_len, input_1_shape, sizeof(input_1_shape)/sizeof(input_1_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, &input_tensors[1])); + + const int input_3_cnt = 16*256; + bool input_3_data[input_3_cnt]; + for (int i = 0; i < input_3_cnt; i++) input_3_data[i] = false; + const size_t input_3_len = input_3_cnt * sizeof(bool); + const int64_t input_3_shape[] = {16, 256}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_3_data, input_3_len, input_3_shape, sizeof(input_3_shape)/sizeof(input_3_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, &input_tensors[3])); + + for (int j = 2; j < 28; j++) { + if (j == 3) continue; + const int input_cnt = 16 * 256 * 1024; + float input_data[input_cnt]; + for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()) / static_cast(RAND_MAX); // [0, 1) + const size_t input_len = input_cnt * sizeof(float); + const int64_t input_shape[] = {16, 256, 1024}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, 
input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[j])); + } + + const char* input_names[] = {"input_ids", "input_mask", "encoder_states", "encoder_input_mask", "history_states_0", + "history_states_1", "history_states_2", "history_states_3", "history_states_4", "history_states_5", "history_states_6", + "history_states_7", "history_states_8", "history_states_9", "history_states_10", "history_states_11", "history_states_12", + "history_states_13", "history_states_14", "history_states_15", "history_states_16", "history_states_17", "history_states_18", + "history_states_19", "history_states_20", "history_states_21", "history_states_22", "history_states_23"}; + const char* output_names[] = {"lm_logits", "log_lm_logits", "hidden_states_0", "hidden_states_1", "hidden_states_2", + "hidden_states_3", "hidden_states_4", "hidden_states_5", "hidden_states_6", "hidden_states_7", "hidden_states_8", + "hidden_states_9", "hidden_states_10", "hidden_states_11", "hidden_states_12", "hidden_states_13", "hidden_states_14", + "hidden_states_15", "hidden_states_16", "hidden_states_17", "hidden_states_18", "hidden_states_19", "hidden_states_20", + "hidden_states_21", "hidden_states_22", "hidden_states_23", "hidden_states_24"}; + OrtValue* output_tensor = nullptr; + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)input_tensors.data(), sizeof(input_names)/sizeof(input_names[0]), output_names, sizeof(output_names)/sizeof(output_names[0]), &output_tensor)); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); + std::cout<<"Result:\n"; + for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/faster_rcnn/faster_rcnn_R_50_FPN_1x.onnx", so, &session)); + + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + + const int input_cnt = 3 * 800 * 1088; + int64_t input_data[input_cnt]; + for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()); + const size_t input_len = input_cnt * sizeof(float); + const int64_t input_shape[] = {3, 800, 1088}; + OrtValue* input_tensor = nullptr; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); + + const char* input_names[] = {"image"}; + const char* output_names[] = {"6379", "6381", "6383"}; + + OrtValue* output_tensor = nullptr; + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, sizeof(output_names)/sizeof(output_names[0]), &output_tensor)); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); + std::cout<<"Result:\n"; + for (size_t i = 0; i < 4; i++) std::cout<GetApi(ORT_API_VERSION); OrtEnv* p_env = nullptr; From ea2465ca379bc13612f8aaeb10ce3938bfbda4f9 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 11 Sep 2024 00:36:55 +0000 Subject: [PATCH 36/81] add tensorrt home in CMakeLists, add trt and CUDA ep for test, change trt type to tensorrtEp --- samples/c_test/test.cpp | 62 +++++++++++-------- samples/tensorRTEp/CMakeLists.txt | 32 +++++----- .../tensorRTEp/tensorrt_execution_provider.cc | 2 +- 3 files changed, 52 insertions(+), 44 deletions(-) diff --git 
a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 49e6e2d30ea37..2bffca67ec955 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -22,11 +22,17 @@ void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); +} + +void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { + THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; + THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); -// OrtCUDAProviderOptionsV2* cuda_options = nullptr; -// THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options)); -// THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options)); -// g_ort->ReleaseCUDAProviderOptions(cuda_options); + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options)); + THROW_ON_ERROR(g_ort->SessionOptionsAppendExecutionProvider_CUDA_V2(so, cuda_options)); + g_ort->ReleaseCUDAProviderOptions(cuda_options); } void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) { @@ -160,27 +166,27 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/faster_rcnn/faster_rcnn_R_50_FPN_1x.onnx", so, &session)); - OrtMemoryInfo* memory_info = nullptr; - THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); - - const int input_cnt = 3 * 800 * 1088; - int64_t input_data[input_cnt]; - for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()); - const size_t input_len = input_cnt * sizeof(float); - const int64_t input_shape[] = {3, 800, 1088}; - OrtValue* input_tensor = nullptr; - THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); - - const char* input_names[] = {"image"}; - const char* output_names[] = {"6379", "6381", "6383"}; - - OrtValue* output_tensor = nullptr; - THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, sizeof(output_names)/sizeof(output_names[0]), &output_tensor)); - - float* output_tensor_data = nullptr; - THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); - std::cout<<"Result:\n"; - for (size_t i = 0; i < 4; i++) std::cout<CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); +// +// const int input_cnt = 3 * 800 * 1088; +// int64_t input_data[input_cnt]; +// for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()) / static_cast(RAND_MAX); // [0, 1) +// const size_t input_len = input_cnt * sizeof(float); +// const int64_t input_shape[] = {3, 800, 1088}; +// OrtValue* input_tensor = nullptr; +// 
THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); +// +// const char* input_names[] = {"image"}; +// const char* output_names[] = {"6379", "6381", "6383"}; + +// OrtValue* output_tensor = nullptr; +// THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, sizeof(output_names)/sizeof(output_names[0]), &output_tensor)); +// +// float* output_tensor_data = nullptr; +// THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); +// std::cout<<"Result:\n"; +// for (size_t i = 0; i < 4; i++) std::cout<ReleaseEnv(p_env); return 0; diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index f3a4a35371925..f9711bd77537f 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=90 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/leca/TensorRT-10.4.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") # cmake --build ./ cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) @@ -16,21 +16,21 @@ file(GLOB tensorrt_src "./*.cc") add_library(TensorRTEp SHARED ${tensorrt_src}) target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "/usr/local/cuda/include" - "/home/leca/TensorRT-10.0.1.6/include" - "../../build/Linux/Debug/_deps/flatbuffers-src/include" - "../../build/Linux/Debug/_deps/gsl-src/include" - "../../build/Linux/Debug/_deps/onnx-src" - "../../build/Linux/Debug/_deps/onnx-build" - "../../build/Linux/Debug/_deps/protobuf-src/src") + ${TENSORRT_HOME}/include + "../../build/tensorrt/Debug/_deps/flatbuffers-src/include" + "../../build/tensorrt/Debug/_deps/gsl-src/include" + "../../build/tensorrt/Debug/_deps/onnx-src" + "../../build/tensorrt/Debug/_deps/onnx-build" + "../../build/tensorrt/Debug/_deps/protobuf-src/src") ## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol -target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" - "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer.so" - "/home/leca/TensorRT-10.0.1.6/lib/libnvinfer_plugin.so" - "/home/leca/TensorRT-10.0.1.6/lib/libnvonnxparser.so" - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/flatbuffers-build/libflatbuffers.a" +target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" + ${TENSORRT_HOME}/lib/libnvinfer.so + ${TENSORRT_HOME}/lib/libnvinfer_plugin.so + ${TENSORRT_HOME}/lib/libnvonnxparser.so + "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/flatbuffers-build/libflatbuffers.a" CUDA::cudart - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" - "/home/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotocd.a") + 
"/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx.a" + "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx_proto.a" + "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotobufd.a" + "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a") diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 1e0aa77f56063..841f5f70f27bd 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -2012,7 +2012,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { ProviderOptions options; for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i]; - std::unique_ptr ret = std::make_unique("TensorrtExecutionProvider", std::move(options)); + std::unique_ptr ret = std::make_unique("tensorrtEp", std::move(options)); return ret.release(); }; } From 76a9305eca897671382b20e5494ccb9de37cc6b9 Mon Sep 17 00:00:00 2001 From: cao lei Date: Wed, 18 Sep 2024 16:26:59 -0700 Subject: [PATCH 37/81] [WIP, DONT REVIEW] add initializer to graph proto (#22085) ### Description ### Motivation and Context --------- Co-authored-by: guyang3532 <62738430+guyang3532@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 30 ++- onnxruntime/core/session/onnxruntime_c_api.cc | 201 +++++++++++++++- onnxruntime/core/session/ort_apis.h | 14 ++ samples/c_test/test.cpp | 8 +- .../tensorRTEp/tensorrt_execution_provider.cc | 216 +----------------- 5 files changed, 258 insertions(+), 211 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 80ecf138a9fd8..a5cc35e3ee556 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -743,6 +743,20 @@ typedef struct OrtNodeComputeInfo { void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); } OrtNodeComputeInfo; +typedef struct OrtTensorRef { // TODO(leca): OrtValueInfoRef inside OrtTensorRef? 
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 80ecf138a9fd8..a5cc35e3ee556 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -743,6 +743,20 @@ typedef struct OrtNodeComputeInfo {
   void(ORT_API_CALL* DestroyFunctionStateFunc)(void*);
 } OrtNodeComputeInfo;

+typedef struct OrtTensorRef {  // TODO(leca): OrtValueInfoRef inside OrtTensorRef?
+  int64_t* shape;
+  size_t shape_len;
+  ONNXTensorElementDataType data_type;
+  const char* data;
+  size_t data_len;
+} OrtTensorRef;
+
+typedef struct OrtValueInfoRef {
+  int64_t* shape;
+  size_t shape_len;
+  ONNXTensorElementDataType data_type;
+} OrtValueInfoRef;
+
 typedef struct OrtExecutionProvider {
 #ifdef __cplusplus
   OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr},
@@ -4791,7 +4805,15 @@ struct OrtApi {

   int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

-  size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION;
+  bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**);
+
+  bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**);
+
+  size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION;  // TODO(leca): review and discuss
+
+  ORT_API2_STATUS(OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer**);  // TODO(leca): review and discuss
+
+  ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph);  // TODO(yang): review and discuss

   ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name);

@@ -4819,8 +4841,12 @@ struct OrtApi {

   ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index);

+  size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names);
+
   ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size);

+  int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;  // AttributeProto_AttributeType
+
   ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count);

   ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size);
@@ -4839,6 +4865,8 @@ struct OrtApi {

   int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

+  float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
   ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs);

   ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index c3807dac76d17..caece5505b2c0 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2521,6 +2521,49 @@ ORT_API(int32_t, OrtApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* g
   return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type();
 }

+ORT_API(bool, OrtApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  const onnx::TensorProto* initializer = 
nullptr; + if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) return false; + *out = new OrtTensorRef(); // TODO(leca): release + (*out)->shape_len = initializer->dims_size(); + (*out)->shape = new int64_t [initializer->dims_size()]; + for (size_t i = 0; i < (*out)->shape_len; i++) { + ((*out)->shape)[i] = initializer->dims(i); + } + + (*out)->data_type = static_cast(initializer->data_type()); + // see utils::ConvertRawDataInTensorProto() + switch (initializer->data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + (*out)->data_len = initializer->float_data_size(); + (*out)->data = reinterpret_cast(initializer->float_data().data()); + break; + } + return true; +} + +static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* type) { // onnxruntime\core\optimizer\transpose_optimization\ort_optimizer_api_impl.cc + if (!type || !utils::HasTensorType(*type) || !utils::HasElementType(*type)) return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + return static_cast(type->tensor_type().elem_type()); +} + +ORT_API(bool, OrtApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const NodeArg* node_arg = graph_viewer->GetNodeArg(name); + + *out = new OrtValueInfoRef(); // TODO(leca): release + const onnx::TypeProto* type = node_arg->TypeAsProto(); + (*out)->data_type = GetDataTypeFromTypeProto(type); + const auto& dims = utils::TryGetShape(*type)->dim(); + (*out)->shape_len = dims.size(); + (*out)->shape = new int64_t [(*out)->shape_len]; + for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[i]) ? dims[i].dim_value() : -1; + + return true; +} + ORT_API(size_t, OrtApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), @@ -2533,11 +2576,139 @@ ORT_API(size_t, OrtApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, onnx::ModelProto model_proto = model.ToProto(); GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true); size_t ret = model_proto.ByteSizeLong(); - *data = malloc(ret); + *data = malloc(ret); // TODO(leca): release model_proto.SerializeToArray(*data, ret); return ret; } +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer** ret) { + onnx::ModelProto model_proto; + if (!model_proto.ParseFromArray(data, len)) return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Parse model proto from array returns false"); + std::shared_ptr model; + Status status = Model::Load(std::move(model_proto), model, nullptr, logging::LoggingManager::DefaultLogger()); + if (status != Status::OK()) return ToOrtStatus(status); + std::unique_ptr graph_viewer = std::make_unique(model->MainGraph()); + *ret = reinterpret_cast(graph_viewer.release()); // TODO(leca): release from the caller + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + // Get parent graph output names + std::unordered_set graph_output_names; + for (const auto* output_arg : graph_viewer->GetOutputs()) { + 
+    graph_output_names.insert(output_arg->Name());
+  }
+  // TODO(leca): cannot use unique_ptr here, otherwise when this function exits, sub_graph_viewer->graph_->graph_proto_, which is from model_build->model_proto_, will be nullptr.
+  // Pay special attention when the Graph object is being released. We need to release model_build separately then.
+  Model* model_build = new Model(graph_viewer->Name(), true, ModelMetaData(), PathString(),
+#if !defined(ORT_MINIMAL_BUILD)
+                                 IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(),
+#else
+                                 IOnnxRuntimeOpSchemaRegistryList(), graph_viewer->DomainToVersionMap(),
+#endif  // ORT_MINIMAL_BUILD
+                                 std::vector<ONNX_NAMESPACE::FunctionProto>(), graph_viewer->GetGraph().GetLogger());
+
+  auto& graph_build = model_build->MainGraph();
+  // bool has_control_flow_op = false;
+
+  std::vector<std::string> subgraph_output_names;
+  const std::vector<size_t>& node_index = graph_viewer->GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
+  for (int i = 0; i < node_num; i++) {
+    const auto& node = graph_viewer->GetNode(node_index[node_indices[i]]);
+    std::vector<NodeArg*> inputs, outputs;
+    for (auto input : node->InputDefs()) {
+      auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
+      inputs.push_back(&n_input);
+      const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
+      if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) {
+        const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
+        if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) {
+          graph_build.AddInitializedTensor(*(initializer));
+        }
+      }
+    }
+    for (auto input : node->ImplicitInputDefs()) {
+      const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
+      if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) {
+        const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
+        if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) {
+          graph_build.AddInitializedTensor(*(initializer));
+        }
+      }
+    }
+    for (auto output : node->OutputDefs()) {
+      auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
+      outputs.push_back(&n_output);
+      const auto name = output->Name();
+      if (graph_output_names.find(name) != graph_output_names.end()) {
+        subgraph_output_names.push_back(name);
+      }
+    }
+
+    // TODO: handle control flow ops
+    // if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
+    //   has_control_flow_op = true;
+    // }
+
+    // If the node has a subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization.
+    // Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto.
+    if (node->GetAttributes().size() > 0) {
+      auto node_proto = std::make_unique<ONNX_NAMESPACE::NodeProto>();
+      // we need to update any GraphProto attributes for subgraphs so that any changes made by things
+      // such as the optimizers are captured. otherwise we can end up saving an invalid graph.
+      node->ToProto(*node_proto, /* update_subgraphs */ true);
+      const int num_attributes = node_proto->attribute_size();
+      NodeAttributes node_attributes;
+      node_attributes.reserve(num_attributes);
+
+      for (int i = 0; i < num_attributes; ++i) {
+        auto& attr = node_proto->attribute(i);
+        node_attributes.emplace(attr.name(), attr);
+      }
+
+      // The GraphProto attributes are the updated ones.
+      graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node_attributes, node->Domain());
+    } else {
+      // The GraphProto attributes are the original ones.
+      graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
+    }
+  }
+
+  // TODO(yang):
+  // Only if the newly built graph has a control flow op as well as a parent node
+  // does it need to handle outer scope values before calling graph.Resolve().
+  // if (has_control_flow_op && graph.ParentNode()) {
+  //   LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name();
+  //   BuildSubGraphContext(graph_build);
+  //   SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph());
+  //   SetAllGraphInputs(graph_build);
+  // }
+
+  common::Status status = graph_build.Resolve();
+  if (status != Status::OK()) return ToOrtStatus(status);
+
+  // Add parent graph output to the subgraph
+  int i = 0;
+  std::vector<const NodeArg*> subgraph_outputs;
+  subgraph_outputs.resize(subgraph_output_names.size());
+  for (auto& name : subgraph_output_names) {
+    auto output_arg = graph_viewer->GetNodeArg(name);
+    auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
+    subgraph_outputs[i] = &subgraph_output_arg;
+    ++i;
+  }
+  auto& graph_build_outputs = graph_build.GetOutputs();
+  subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
+  graph_build.SetOutputs(graph_build_outputs);
+  status = graph_build.Resolve();
+  if (status != Status::OK()) return ToOrtStatus(status);
+
+  auto sub_graph_viewer = std::make_unique<GraphViewer>(graph_build);
+  *subgraph = reinterpret_cast<const OrtGraphViewer*>(sub_graph_viewer.release());
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *name = n->Name().c_str();
@@ -2623,12 +2794,28 @@ ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIndex, const OrtNode* node, _Out_ size_t
   return nullptr;
 }
 
+ORT_API(size_t, OrtApis::OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  size_t ret = n->GetAttributes().size();
+  *names = new const char* [ret];
+  int i = 0;
+  for (const auto& [k, v] : n->GetAttributes()) {
+    (*names)[i++] = k.c_str();
+  }
+  return ret;
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *attr_size = n->GetAttributes().size();
   return nullptr;
 }
 
+ORT_API(int, OrtApis::OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return static_cast<int>(n->GetAttributes().at(attribute).type());
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count) {
   const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
   *count = n->GetAttributes().count(key);
@@ -2681,6 +2868,11 @@ ORT_API(int64_t, OrtApis::OrtNode_GetAttributeInt, const OrtNode* node, const ch
   return n->GetAttributes().at(key).i();
 }
 
+ORT_API(float, OrtApis::OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).f();
+}
+
 ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const
OrtGraphViewer*** subgraphs) { const ::onnxruntime::Node* n = reinterpret_cast(node); std::vector> subg = n->GetSubgraphs(); @@ -3117,7 +3309,11 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtGraph_GetOutputSize, &OrtApis::OrtGraph_GetIthOutputName, &OrtApis::OrtGraph_GetIthOutputElemType, + &OrtApis::OrtGraph_GetInitializerTensor, + &OrtApis::OrtGraph_GetValueInfo, &OrtApis::OrtGraph_SerializeToArray, + &OrtApis::OrtGraph_DeserializeFromArray, + &OrtApis::OrtGraph_GetSubGraph, &OrtApis::OrtNode_GetName, &OrtApis::OrtNode_GetDescription, &OrtApis::OrtNode_GetDomain, @@ -3131,7 +3327,9 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtNode_GetOutputSize, &OrtApis::OrtNode_GetIthOutputName, &OrtApis::OrtNode_GetIndex, + &OrtApis::OrtNode_GetAttributeNames, &OrtApis::OrtNode_GetAttributeSize, + &OrtApis::OrtNode_GetAttributeType, &OrtApis::OrtNode_GetAttributeKeyCount, &OrtApis::OrtNode_GetAttributeIntSize, &OrtApis::OrtNode_GetAttributeFloatSize, @@ -3141,6 +3339,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::OrtNode_GetAttributeIthStr, &OrtApis::OrtNode_GetAttributeStr, &OrtApis::OrtNode_GetAttributeInt, + &OrtApis::OrtNode_GetAttributeFloat, &OrtApis::OrtNode_GetSubgraphs, &OrtApis::OrtKernelRegistry_RegisterKernel, &OrtApis::CreateOrtTypeConstraints, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index bdd1077d2624c..fcf4659021c83 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -573,8 +573,16 @@ ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; +ORT_API(bool, OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); + +ORT_API(bool, OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); + ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** data); +ORT_API_STATUS_IMPL(OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer**); + +ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); + ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); @@ -601,8 +609,12 @@ ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Ou ORT_API_STATUS_IMPL(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index); +ORT_API(size_t, OrtNode_GetAttributeNames, const OrtNode* node, const char*** names); + ORT_API_STATUS_IMPL(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size); +ORT_API(int, OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) ORT_ALL_ARGS_NONNULL; + ORT_API_STATUS_IMPL(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count); ORT_API_STATUS_IMPL(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size); @@ -621,6 +633,8 @@ ORT_API(const char*, OrtNode_GetAttributeStr, const OrtNode* node, const char* k ORT_API(int64_t, OrtNode_GetAttributeInt, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; +ORT_API(float, OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; + 
ORT_API_STATUS_IMPL(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs); ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 2bffca67ec955..454ac7b8728ff 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -2,8 +2,13 @@ #include #include +const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + inline void THROW_ON_ERROR(OrtStatus* status) { - if (status != nullptr) abort(); + if (status != nullptr) { + std::cout<<"ErrorMessage:"<GetErrorMessage(status)<<"\n"; + abort(); + } } void TestCompileBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { @@ -190,7 +195,6 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { } int main() { - const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtEnv* p_env = nullptr; OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env)); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 841f5f70f27bd..c20861fbdf16a 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -3575,12 +3575,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect return nodes_list_output; } - std::unordered_set graph_output_names; - size_t output_size = api_->OrtGraph_GetOutputSize(graph); - for (size_t i = 0; i < output_size; i++) { - graph_output_names.insert(api_->OrtGraph_GetIthOutputName(graph, i)); - } - iterations++; size_t nodes_count = 0; const size_t* node_index = nullptr; @@ -3591,204 +3585,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (group.second) { nodes_list_output.push_back(group); } else { - onnx::ModelProto m; - m.set_ir_version(3); - onnx::OperatorSetIdProto* p = m.add_opset_import(); - p->set_domain(""); - p->set_version(10); - onnx::GraphProto* g = m.mutable_graph(); - for (size_t i = 0; i < nodes_count; i++) { - onnx::NodeProto* n = g->add_node(); - const OrtNode* node = nullptr; - api_->OrtGraph_GetOrtNode(graph, node_index[i], &node); - - const char* op_type = nullptr; - api_->OrtNode_GetOpType(node, &op_type); - n->set_op_type(op_type); - - const char* name = nullptr; - api_->OrtNode_GetName(node, &name); - n->set_name(name); - - // TODO(leca): Implicit input? 
& attributes - size_t input_size = 0; - api_->OrtNode_GetInputSize(node, &input_size); - for (size_t j = 0; j < input_size; j++) { - const char* jth_input_name = nullptr; - api_->OrtNode_GetIthInputName(node, j, &jth_input_name); - n->add_input(jth_input_name, strlen(jth_input_name)); - } - size_t output_size = 0; - api_->OrtNode_GetOutputSize(node, &output_size); - for (size_t j = 0; j < output_size; j++) { - const char* jth_output_name = nullptr; - api_->OrtNode_GetIthOutputName(node, j, &jth_output_name); - n->add_output(jth_output_name, strlen(jth_output_name)); - } - } - - // TODO(leca): set_elem_type, set_dim_value for graph input and output - size_t graph_inputs = 0; - const char** graph_input_names = nullptr; - api_->OrtGraph_GetInputsIncludingInitializers(graph, &graph_inputs, &graph_input_names); - for (size_t i = 0; i < graph_inputs; i++) { - onnx::ValueInfoProto* input = g->add_input(); - input->set_name(graph_input_names[i]); - } - - size_t graph_outputs = api_->OrtGraph_GetOutputSize(graph); - for (size_t i = 0; i < graph_outputs; i++) { - onnx::ValueInfoProto* output = g->add_output(); - output->set_name(api_->OrtGraph_GetIthOutputName(graph, i)); - output->mutable_type()->mutable_tensor_type()->set_elem_type(api_->OrtGraph_GetIthOutputElemType(graph, i)); - } - -// auto model_build = graph.CreateModel(*GetLogger()); -// auto& graph_build = model_build->MainGraph(); -// bool has_control_flow_op = false; -// -// // Add node and node args -// // If node output is also parent graph output, the output will be added to the -// // subgraph's output list -// std::vector subgraph_output_names; -// for (const auto& index : group.first) { -// const auto& node = graph.GetNode(node_index[index]); -// std::vector inputs, outputs; -// for (auto input : node->InputDefs()) { -// auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); -// inputs.push_back(&n_input); -// const ONNX_NAMESPACE::TensorProto* initializer = nullptr; -// if (graph.GetInitializedTensor(input->Name(), initializer)) { -// const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; -// if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { -// graph_build.AddInitializedTensor(*(initializer)); -// } -// } -// } -// -// for (auto input : node->ImplicitInputDefs()) { -// const ONNX_NAMESPACE::TensorProto* initializer = nullptr; -// if (graph.GetInitializedTensor(input->Name(), initializer)) { -// const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; -// if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { -// graph_build.AddInitializedTensor(*(initializer)); -// } -// } -// } -// for (auto output : node->OutputDefs()) { -// auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); -// outputs.push_back(&n_output); -// const auto name = output->Name(); -// if (graph_output_names.find(name) != graph_output_names.end()) { -// subgraph_output_names.push_back(name); -// } -// } -// -// if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { -// has_control_flow_op = true; -// } -// -// // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization. -// // Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto. 
-// if (node->GetAttributes().size() > 0) { -// auto node_proto = ONNX_NAMESPACE::NodeProto::Create(); -// // we need to update any GraphProto attributes for subgraphs so that any changes made by things -// // such as the optimizers are captured. otherwise we can end up saving an invalid graph. -// node->ToProto(*node_proto, /* update_subgraphs */ true); -// const int num_attributes = node_proto->attribute_size(); -// auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); -// node_attributes->reserve(num_attributes); -// -// for (int i = 0; i < num_attributes; ++i) { -// auto& attr = node_proto->attribute(i); -// node_attributes->emplace(attr.name(), attr); -// } -// -// // The GraphProto attributes are the updated ones. -// graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, node_attributes.get(), node->Domain()); -// } else { -// // The GraphProto attributes are the original ones. -// graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); -// } -// } -// -// // Only if the newly built graph has control flow op as well as it has parent node, -// // it needs to handle outer scope values before calling graph.Resolve(). -// if (has_control_flow_op && graph.ParentNode()) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); -// BuildSubGraphContext(graph_build); -// SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); -// SetAllGraphInputs(graph_build); -// } -// -// ORT_ENFORCE(graph_build.Resolve().IsOK()); -// -// // Add parent graph output to the subgraph -// int i = 0; -// std::vector subgraph_outputs; -// subgraph_outputs.resize(subgraph_output_names.size()); -// for (auto& name : subgraph_output_names) { -// auto output_arg = graph.GetNodeArg(name); -// auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); -// subgraph_outputs[i] = &subgraph_output_arg; -// ++i; -// } -// auto& graph_build_outputs = graph_build.GetOutputs(); -// subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); -// graph_build.SetOutputs(graph_build_outputs); -// ORT_ENFORCE(graph_build.Resolve().IsOK()); -// -// // Check if input tensors have shapes -// if (iterations > 1) { -// auto graph_inputs = graph_build.GetInputs(); -// for (auto input_arg : graph_inputs) { -// bool has_dim_value_or_param = true; -// auto input_shape = input_arg->Shape(); -// if (input_shape != nullptr) { -// auto dim_size = input_shape->dim_size(); -// for (int i = 0; i < dim_size; ++i) { -// auto& dim = input_shape->dim(i); -// if (!dim.has_dim_value() && !dim.has_dim_param()) { -// has_dim_value_or_param = false; -// break; -// } -// } -// } -// -// if (input_shape == nullptr || !has_dim_value_or_param) { -// ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, -// "TensorRT input: " + input_arg->Name() + " has no shape specified. " + -// "Please run shape inference on the onnx model first. Details can be found in " + -// "https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs")); -// } -// } -// } -// -// // Serialize modelproto to string -// auto graph_viewer = graph_build.CreateGraphViewer(); -// auto model = graph_viewer->CreateModel(*GetLogger()); -// auto model_proto = model->ToProto(); -// -// // ORT's default topological sort is using reversed DFS. 
-// // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. -// // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating -// // the model proto that has different node ordering compared to original onnx model. -// graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); -// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); -// -// std::string string_buf; -// model_proto->SerializeToString(string_buf); -// -// if (dump_subgraphs_) { -// // Dump TensorRT subgraph for debugging -// std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", std::ios::out | std::ios::trunc | std::ios::binary); -// model_proto->SerializeToOstream(dump); -// } + const OrtGraphViewer* sub_graph_viewer = nullptr; + api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); void* buf_data = nullptr; - size_t buf_size = api_->OrtGraph_SerializeToArray(graph, &buf_data); - std::string string_buf(reinterpret_cast(buf_data), buf_size); + size_t buf_size = api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data); // Get supported node list recursively SubGraphCollection_t parser_nodes_list; @@ -3806,7 +3608,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect #pragma warning(push) #pragma warning(disable : 4996) #endif - trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); + trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -3814,12 +3616,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect SubGraphCollection_t next_nodes_list; size_t subgraph_node_count = 0; const size_t* subgraph_node_index = nullptr; - api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &subgraph_node_count, &subgraph_node_index); - next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph, early_termination); + api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_count, &subgraph_node_index); + next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { -// for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { -// next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; -// } + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; + } nodes_list_output.push_back(next_nodes_list[i]); } } From 330cdb6ad7bb6c416d8c293c9733845f4d9f9256 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 19 Sep 2024 23:45:41 +0000 Subject: [PATCH 38/81] use parameter ExecutionOrder::PRIORITY_BASED for GraphViewerToProto() to make graph partition the same, and improve test --- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- samples/c_test/test.cpp | 29 +++++++++++++------ .../tensorRTEp/tensorrt_execution_provider.cc | 5 +--- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index caece5505b2c0..8216e3d2c1490 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ 
b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2574,7 +2574,7 @@ ORT_API(size_t, OrtApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, #endif graph_viewer->DomainToVersionMap(), std::vector(), graph_viewer->GetGraph().GetLogger()); onnx::ModelProto model_proto = model.ToProto(); - GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true); + GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED); size_t ret = model_proto.ByteSizeLong(); *data = malloc(ret); // TODO(leca): release model_proto.SerializeToArray(*data, ret); diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 454ac7b8728ff..1e97efbdaae52 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -194,22 +194,33 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { // for (size_t i = 0; i < 4; i++) std::cout<CreateEnv(log_level, "", &p_env)); OrtSessionOptions* so = nullptr; THROW_ON_ERROR(g_ort->CreateSessionOptions(&so)); - //TestCompileBasedEp(g_ort, p_env, so); - //TestKernelBasedEp(g_ort, p_env, so); - //TestTensorRTEp(g_ort, p_env, so); - TestTensorRTAndCudaEp(g_ort, p_env, so); - //TestOriginalTensorRTEp(g_ort, so); + if (strcmp(argv[1], "c") == 0) { + TestCompileBasedEp(g_ort, p_env, so); + } else if (strcmp(argv[1], "k") == 0) { + TestKernelBasedEp(g_ort, p_env, so); + } else if (strcmp(argv[1], "t") == 0) { + TestTensorRTEp(g_ort, p_env, so); + } else if (strcmp(argv[1], "tc") == 0) { + TestTensorRTAndCudaEp(g_ort, p_env, so); + } else if (strcmp(argv[1], "otc") == 0) { + TestOriginalTensorRTEp(g_ort, so); + } - //RunRelu(g_ort, p_env, so); - //RunResnet18v1_7(g_ort, p_env, so); - RunFastRcnn(g_ort, p_env, so); + if (!strcmp(argv[2], "relu")) { + RunRelu(g_ort, p_env, so); + } else if (!strcmp(argv[2], "resnet")) { + RunResnet18v1_7(g_ort, p_env, so); + } else if (!strcmp(argv[2], "rcnn")) { + RunFastRcnn(g_ort, p_env, so); + } g_ort->ReleaseEnv(p_env); return 0; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index c20861fbdf16a..d3ebb40031ac9 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -936,10 +936,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::unordered_map node_to_index_map; std::unordered_map index_to_node_map; std::unordered_map> input_to_nodes_map, node_to_outputs_map; - std::unordered_set non_trt_node_index; - for (size_t i = 0; i < node_count; ++i) { - non_trt_node_index.insert(nodes_index[i]); - } + std::unordered_set non_trt_node_index(node_index.begin(), node_index.end()); size_t id = 0; int subgraph_index = 0; for (const auto& group : supported_nodes_vector) { From 6fd50f0fa6579ad9eb2a62ac1503300c7b9ec7a8 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 23 Sep 2024 01:08:13 +0000 Subject: [PATCH 39/81] can create session with out tree trt ep now. 
Error:Name:'tensorrtEp_TRTKernel_graph_torch-jit-export10497453988321570186_10_10' Status Message: TensorRT EP failed to create engine from network --- .../core/session/onnxruntime_c_api.h | 2 +- onnxruntime/core/framework/provider_adapter.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- samples/c_test/test.cpp | 44 ++++++++++--------- samples/outTreeEp/out_tree_ep.cc | 4 +- samples/qnnEp/qnn_execution_provider.cc | 2 +- .../tensorRTEp/tensorrt_execution_provider.cc | 23 +++++----- .../tensorRTEp/tensorrt_execution_provider.h | 4 +- 8 files changed, 44 insertions(+), 39 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a5cc35e3ee556..ae29fe74da016 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -763,7 +763,7 @@ typedef struct OrtExecutionProvider { extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {} #endif void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); - OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info); + OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info); void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index a485e7be82433..57cd700debad3 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -92,7 +92,7 @@ class ExecutionProviderAdapter : public IExecutionProvider { std::vector cache; cache.resize(count); OrtNodeComputeInfo* cache_data = cache.data(); - OrtStatus* ret = ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &cache_data); + OrtStatus* ret = ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, cache_data); if (ret != nullptr) return ToStatus(ret); node_compute_funcs.reserve(count); for (size_t i = 0; i < count; i++) { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 8216e3d2c1490..7f628f4cb260d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2482,7 +2482,7 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, s ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); std::vector consumer_nodes = graph_viewer->GetConsumerNodes(input_name); - len = new size_t (consumer_nodes.size()); + *len = consumer_nodes.size(); *consumers = new const OrtNode* [*len]; for (size_t i = 0; i < consumer_nodes.size(); i++) (*consumers)[i] = reinterpret_cast(consumer_nodes[i]); diff --git a/samples/c_test/test.cpp 
b/samples/c_test/test.cpp index 1e97efbdaae52..fcd8bbc6d23a4 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -171,27 +171,29 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { OrtSession* session = nullptr; THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/faster_rcnn/faster_rcnn_R_50_FPN_1x.onnx", so, &session)); -// OrtMemoryInfo* memory_info = nullptr; -// THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); -// -// const int input_cnt = 3 * 800 * 1088; -// int64_t input_data[input_cnt]; -// for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()) / static_cast(RAND_MAX); // [0, 1) -// const size_t input_len = input_cnt * sizeof(float); -// const int64_t input_shape[] = {3, 800, 1088}; -// OrtValue* input_tensor = nullptr; -// THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); -// -// const char* input_names[] = {"image"}; -// const char* output_names[] = {"6379", "6381", "6383"}; - -// OrtValue* output_tensor = nullptr; -// THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, sizeof(output_names)/sizeof(output_names[0]), &output_tensor)); -// -// float* output_tensor_data = nullptr; -// THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data)); -// std::cout<<"Result:\n"; -// for (size_t i = 0; i < 4; i++) std::cout<CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + + const int input_cnt = 3 * 800 * 1088; + float* input_data = new float [input_cnt]; + for (int i = 0; i < input_cnt; i++) input_data[i] = static_cast(rand()) / static_cast(RAND_MAX); // [0, 1) + const size_t input_len = input_cnt * sizeof(float); + const int64_t input_shape[] = {3, 800, 1088}; + OrtValue* input_tensor = nullptr; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor)); + + const char* input_names[] = {"image"}; + const char* output_names[] = {"6379", "6381", "6383"}; + + size_t output_count = sizeof(output_names)/sizeof(output_names[0]); + std::vector output_tensors(output_count, nullptr); + //OrtValue* output_tensor = nullptr; + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, output_count, output_tensors.data())); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensors[0], (void**)&output_tensor_data)); + std::cout<<"Result:\n"; + for (size_t i = 0; i < 4; i++) std::cout< OrtStatusPtr { + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { OutTreeEp* p = static_cast(this_); this_->extra_param_for_compute_func = p; for (size_t i = 0; i < cnt; i++) { - node_compute_info[i]->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + node_compute_info[i].ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { const OrtValue* input = nullptr; 
api->KernelContext_GetInput(context, 0, &input); std::vector dim(1,4); diff --git a/samples/qnnEp/qnn_execution_provider.cc b/samples/qnnEp/qnn_execution_provider.cc index 68b59a4f73c3b..a586b3b3d203f 100644 --- a/samples/qnnEp/qnn_execution_provider.cc +++ b/samples/qnnEp/qnn_execution_provider.cc @@ -27,7 +27,7 @@ QNNExecutionProvider::QNNExecutionProvider(const char* ep_type, const ProviderOp OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { }; - OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { return nullptr; }; } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index d3ebb40031ac9..d7d668f245734 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -936,7 +936,10 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::unordered_map node_to_index_map; std::unordered_map index_to_node_map; std::unordered_map> input_to_nodes_map, node_to_outputs_map; - std::unordered_set non_trt_node_index(node_index.begin(), node_index.end()); + std::unordered_set non_trt_node_index; + for (size_t i = 0; i < node_count; ++i) { + non_trt_node_index.insert(nodes_index[i]); + } size_t id = 0; int subgraph_index = 0; for (const auto& group : supported_nodes_vector) { @@ -1481,7 +1484,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } }; - OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info) -> OrtStatusPtr { + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); TensorrtExecutionProvider* p = static_cast(this_); this_->extra_param_for_create_state_func = p; @@ -2086,7 +2089,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_funcs) { + OrtNodeComputeInfo* node_compute_funcs) { TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; @@ -2676,7 +2679,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } // Create function state - (*node_compute_funcs)->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); std::unique_ptr p = std::make_unique(); @@ -2700,12 +2703,12 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort }; // Release function state - (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) { + 
node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { delete static_cast(state); }; // Create compute function - (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { Ort::KernelContext ctx(context); TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); TensorrtFuncState* trt_state = reinterpret_cast(state); @@ -3251,7 +3254,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_funcs) { + OrtNodeComputeInfo* node_compute_funcs) { std::unique_ptr trt_engine; std::unique_ptr trt_context; std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index @@ -3334,7 +3337,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi output_info_[fused_node_name].push_back(output_types); // Create function state - (*node_compute_funcs)->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + node_compute_funcs->CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); std::unique_ptr p = std::make_unique(); *p = { context->AllocateFunc, @@ -3352,12 +3355,12 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi }; // Release function state - (*node_compute_funcs)->DestroyFunctionStateFunc = [](void* state) { + node_compute_funcs->DestroyFunctionStateFunc = [](void* state) { delete reinterpret_cast(state); }; // Create compute function - (*node_compute_funcs)->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { + node_compute_funcs->ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr { TensorrtExecutionProvider* this_ = reinterpret_cast(extra_param); TensorrtShortFuncState* trt_state = reinterpret_cast(state); Ort::KernelContext ctx(context); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 52eca02531743..a7ae47e5f43ec 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -344,13 +344,13 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { OrtStatusPtr CreateNodeComputeInfoFromPrecompiledEngine(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_funcs); + OrtNodeComputeInfo* node_compute_funcs); OrtStatusPtr CreateNodeComputeInfoFromGraph(const OrtGraphViewer* graph_body_viewer, const OrtNode* fused_node, std::unordered_map& input_map, std::unordered_map& output_map, - OrtNodeComputeInfo** node_compute_funcs); + OrtNodeComputeInfo* node_compute_funcs); bool IsGraphCaptureAllowed() const { return false; }; From 681585f7ff05e57932f4951bba451783233c9eab Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 23 Sep 2024 23:55:30 +0000 Subject: [PATCH 40/81] make 
trt_node_name_with_precision_ from string to map, to capture the corresponding value for compute function --- .../tensorRTEp/tensorrt_execution_provider.cc | 37 +++++++++++-------- .../tensorRTEp/tensorrt_execution_provider.h | 6 +-- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index d7d668f245734..05413e7b476a9 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1435,7 +1435,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } } - if (all_subgraphs_are_supported) { for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { @@ -2269,29 +2268,32 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } } + const char* node_name = nullptr; + api_->OrtNode_GetName(fused_node, &node_name); + // Load INT8 calibration table + std::unordered_map dynamic_range_map; if (int8_enable_ && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map_)) { + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); } } + dynamic_range_map_[node_name] = dynamic_range_map; // Set precision flags - const char* node_name = nullptr; - api_->OrtNode_GetName(fused_node, &node_name); - trt_node_name_with_precision_ = node_name; + std::string trt_node_name_with_precision(node_name); if (fp16_enable_ && int8_enable_) { trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - trt_node_name_with_precision_ += "_fp16_int8"; + trt_node_name_with_precision += "_fp16_int8"; //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; } else if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - trt_node_name_with_precision_ += "_fp16"; + trt_node_name_with_precision += "_fp16"; //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; } else if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); - trt_node_name_with_precision_ += "_int8"; + trt_node_name_with_precision += "_int8"; //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; } @@ -2311,10 +2313,11 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); trt_config->setDLACore(dla_core_); - trt_node_name_with_precision_ += "_dlacore" + std::to_string(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); } } } + trt_node_name_with_precision_[node_name] = trt_node_name_with_precision; // enable sparse weights if (sparsity_enable_) { @@ -2387,14 +2390,16 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort std::unique_ptr trt_context; std::string cache_path = ""; + std::string cache_suffix = ""; // Customize cache prefix if assigned if (!cache_prefix_.empty()) { // Generate cache suffix in case user would like to customize cache prefix - cache_suffix_ = "_" + GetCacheSuffix(node_name, trt_node_name_with_precision_); - cache_path = 
GetCachePath(cache_path_, cache_prefix_) + cache_suffix_; + cache_suffix = "_" + GetCacheSuffix(node_name, trt_node_name_with_precision); + cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; } else { - cache_path = GetCachePath(cache_path_, trt_node_name_with_precision_); + cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); } + cache_suffix_[node_name] = cache_suffix; std::string cache_hw_compat = "_sm" + compute_capability_; // Enable hardware compatility mode if assigned @@ -2485,7 +2490,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort #if defined(_MSC_VER) #pragma warning(pop) #endif - if (!SetDynamicRange(*trt_network, dynamic_range_map_)) { + if (!SetDynamicRange(*trt_network, dynamic_range_map)) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not set INT8 dynamic range for fused node: " + std::string(node_name)).c_str()); } } @@ -2692,12 +2697,12 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort &(this_->parsers_[context->node_name]), &(this_->engines_[context->node_name]), &(this_->contexts_[context->node_name]), &(this_->networks_[context->node_name]), this_->input_info_[context->node_name], this_->output_info_[context->node_name], this_->input_shape_ranges_[context->node_name], /*&tensorrt_mu_,*/ this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_, - this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_, + this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_[context->node_name], this_->engine_cache_enable_, this_->cache_path_, this_->runtime_.get(), this_->profiles_[context->node_name], - this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_, this_->engine_decryption_enable_, + this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_[context->node_name], this_->engine_decryption_enable_, this_->engine_decryption_, this_->engine_encryption_, this_->timing_cache_enable_, this_->global_cache_path_, this_->force_timing_cache_match_, this_->detailed_build_log_, this_->build_heuristics_enable_, this_->sparsity_enable_, this_->builder_optimization_level_, - this_->auxiliary_streams_, !(this_->tactic_sources_.empty()), tactics, this_->cuda_graph_enable_, this_->cache_prefix_, this_->cache_suffix_, this_->engine_hw_compatible_}; + this_->auxiliary_streams_, !(this_->tactic_sources_.empty()), tactics, this_->cuda_graph_enable_, this_->cache_prefix_, this_->cache_suffix_[context->node_name], this_->engine_hw_compatible_}; *state = p.release(); return 0; }; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index a7ae47e5f43ec..f35c8d4316e84 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -241,9 +241,9 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; static const OrtApi* api_; - std::string trt_node_name_with_precision_; - std::unordered_map dynamic_range_map_; - std::string cache_suffix_; + std::unordered_map trt_node_name_with_precision_; + std::unordered_map> dynamic_range_map_; + std::unordered_map cache_suffix_; private: mutable TensorrtExecutionProviderInfo info_; bool 
external_stream_ = false; From 7db20cb29083f0b5e5bb9c3da5da8decd43d44a2 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:19:23 +0800 Subject: [PATCH 41/81] fix redundant inputs and outputs in GetSubgraph (#22201) --- .../tensorRTEp/tensorrt_execution_provider.cc | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 05413e7b476a9..130f4ff063014 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1136,8 +1136,8 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr sub_graph->node_index = new size_t [sub_graph->node_index_len]; sub_graph->meta_def = new OrtMetaDef(); std::unordered_set erased; - std::vector inputs; - std::vector outputs; + std::unordered_map input_to_order; + std::unordered_map output_to_order; int input_order = 0; int output_order = 0; @@ -1162,14 +1162,14 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { - inputs.push_back(input_name); + input_to_order[input_name] = input_order++; continue; } size_t producer_index = 0; api_->OrtNode_GetIndex(producer, &producer_index); // If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { - inputs.push_back(input_name); + input_to_order[input_name] = input_order++; } } @@ -1188,14 +1188,14 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { - inputs.push_back(input_name); + input_to_order[input_name] = input_order++; continue; } size_t producer_index = 0; api_->OrtNode_GetIndex(producer, &producer_index); // If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { - inputs.push_back(input_name); + input_to_order[input_name] = input_order++; } } @@ -1206,7 +1206,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr api_->OrtNode_GetIthOutputName(node, j, &output_name); // If the output is the graph output, it is a subgraph output if (graph_output_names.find(output_name) != graph_output_names.end()) { - outputs.push_back(output_name); + output_to_order[output_name] = output_order++; continue; } size_t consumer_count = 0; @@ -1217,13 +1217,22 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr api_->OrtNode_GetIndex(consumers[k], &consumer_index); // If the consumer node is not in the subgraph, the output is a subgraph output if (node_set.find(consumer_index) == node_set.end()) { - outputs.push_back(output_name); + output_to_order[output_name] = output_order++; break; } } } } + //Sort inputs and outputs based on their order + std::multimap ordered_inputs, ordered_outputs; + for (const auto& input : input_to_order) { + ordered_inputs.insert(std::pair(input.second, input.first)); + } + for (const auto& output : output_to_order) { + ordered_outputs.insert(std::pair(output.second, output.first)); + } + // Generate unique kernel name for TRT subgraph std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); bool 
is_subgraph = false; @@ -1235,12 +1244,12 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); // Assign inputs and outputs to subgraph's meta_def - sub_graph->meta_def->input_len = inputs.size(); + sub_graph->meta_def->input_len = ordered_inputs.size(); sub_graph->meta_def->inputs = new char* [sub_graph->meta_def->input_len]; i = 0; - for (const auto& input : inputs) { - sub_graph->meta_def->inputs[i] = new char [input.length() + 1]; - strcpy(sub_graph->meta_def->inputs[i++], input.c_str()); + for (const auto& input : ordered_inputs) { + sub_graph->meta_def->inputs[i] = new char [input.second.length() + 1]; + strcpy(sub_graph->meta_def->inputs[i++], input.second.c_str()); } sub_graph->meta_def->initializer_len = initializers.size(); @@ -1251,12 +1260,12 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr strcpy(sub_graph->meta_def->constant_initializers[i++], initializer.c_str()); } - sub_graph->meta_def->output_len = outputs.size(); + sub_graph->meta_def->output_len = ordered_outputs.size(); sub_graph->meta_def->outputs = new char* [sub_graph->meta_def->output_len]; i = 0; - for (const auto& output : outputs) { - sub_graph->meta_def->outputs[i] = new char [output.length() + 1]; - strcpy(sub_graph->meta_def->outputs[i++], output.c_str()); + for (const auto& output : ordered_outputs) { + sub_graph->meta_def->outputs[i] = new char [output.second.length() + 1]; + strcpy(sub_graph->meta_def->outputs[i++], output.second.c_str()); } sub_graph->meta_def->domain = "com.microsoft"; @@ -1320,7 +1329,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const size_t subgraph_count = 0; const OrtGraphViewer** subgraphs = nullptr; api->OrtNode_GetSubgraphs(node, &subgraph_count, &subgraphs); - if (subgraph_count == 0) { + if (subgraph_count != 0) { bool all_subgraphs_are_supported = true; for (size_t i = 0; i < subgraph_count; i++) { // TRT EP should consider the empty subgraph is fully supported by TRT. 
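[Editor's note, not part of the patch series] The patches above move the TensorRT EP off its hand-built ModelProto and onto the OrtGraph_GetSubGraph / OrtGraph_SerializeToArray pair, with the fix just above restoring deterministic input/output ordering inside GetSubGraph. A minimal sketch of the resulting call pattern from a plug-in EP follows. It is a hypothetical helper, not code from the patches: `api`, `graph`, `candidate_indices`, and `candidate_count` stand in for the EP's real state, and it elides error reporting plus the buffer/viewer releases that the patches still mark TODO(leca).

// Editor's sketch (assumed helper; see the hedges above).
#include <cstddef>
#include "onnxruntime_c_api.h"

static size_t SerializeCandidateSubgraph(const OrtApi* api, const OrtGraphViewer* graph,
                                         const size_t* candidate_indices, size_t candidate_count,
                                         void** out_buf) {
  const OrtGraphViewer* sub_graph = nullptr;
  // Carves the listed nodes into a standalone Graph and runs Resolve() on it (OrtGraph_GetSubGraph).
  if (api->OrtGraph_GetSubGraph(graph, (int)candidate_count, candidate_indices, &sub_graph) != nullptr) return 0;
  // Serializes the subgraph into a malloc'd ONNX ModelProto buffer and returns its length.
  return api->OrtGraph_SerializeToArray(sub_graph, out_buf);
}

This is the same pair the GetSupportedList changes earlier in this series use before handing the buffer to the TensorRT parser's supportsModel().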
From ff782e0982ac4c8d5169624844d460c4fdac4a7a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 25 Sep 2024 23:28:19 +0000 Subject: [PATCH 42/81] RunTinyYolov3() --- samples/c_test/test.cpp | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index fcd8bbc6d23a4..e130d8943bae0 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -187,9 +187,42 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { size_t output_count = sizeof(output_names)/sizeof(output_names[0]); std::vector output_tensors(output_count, nullptr); - //OrtValue* output_tensor = nullptr; THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)&input_tensor, sizeof(input_names)/sizeof(input_names[0]), output_names, output_count, output_tensors.data())); + // This output will be nullptr +// float* output_tensor_data = nullptr; +// THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensors[0], (void**)&output_tensor_data)); +// std::cout<<"Result:\n"; +// for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session)); + + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + + std::vector input_tensors(2, nullptr); + const int input_cnt = 3 * 416 * 416; + float input_data[input_cnt]; + for (int i = 0; i < input_cnt; i++) input_data[i] = 0.501960813999176; + const size_t input_len = input_cnt * sizeof(float); + const int64_t input_shape[] = {1, 3, 416, 416}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[0])); + + float input2[2] = {375, 500}; + const size_t input2_len = 8; // 2 * sizeof(float) + const int64_t input2_shape[] = {1, 2}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1])); + + const char* input_names[] = {"input_1", "image_shape"}; + const char* output_names[] = {"6379", "6381", "6383"}; + + size_t output_count = sizeof(output_names)/sizeof(output_names[0]); + std::vector output_tensors(output_count, nullptr); + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)input_tensors.data(), sizeof(input_names)/sizeof(input_names[0]), output_names, output_count, output_tensors.data())); + float* output_tensor_data = nullptr; THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensors[0], (void**)&output_tensor_data)); std::cout<<"Result:\n"; @@ -222,6 +255,8 @@ int main(int argc, char *argv[]) { RunResnet18v1_7(g_ort, p_env, so); } else if (!strcmp(argv[2], "rcnn")) { RunFastRcnn(g_ort, p_env, so); + } else if (!strcmp(argv[2], "tyolo")) { + RunTinyYolov3(p_env, so); } g_ort->ReleaseEnv(p_env); From 1d7b2dfcb6f3f3e7a9d1fce1d5097aa8eb53db37 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:47:05 +0800 Subject: [PATCH 43/81] fix bugs for run tinyYolo (#22233) --- .../core/session/onnxruntime_c_api.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 20 +++++++++---------- onnxruntime/core/session/ort_apis.h | 2 +- samples/c_test/test.cpp | 2 +- .../tensorRTEp/tensorrt_execution_provider.cc | 6 ++++-- 
.../tensorrt_execution_provider_utils.h | 4 ++-- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index ae29fe74da016..e6c0ea485d38f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4777,7 +4777,7 @@ struct OrtApi { ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); - ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret); + ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7f628f4cb260d..9b2d50718677c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2434,9 +2434,9 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const Ort return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *ret = graph_viewer->IsSubgraph(); +ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret) { + const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); + *ret = graph_ptr->IsSubgraph(); return nullptr; } @@ -2610,7 +2610,7 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, std::vector(), graph_viewer->GetGraph().GetLogger()); auto& graph_build = model_build->MainGraph(); - // bool has_control_flow_op = false; + bool has_control_flow_op = false; std::vector subgraph_output_names; const std::vector& node_index = graph_viewer->GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); @@ -2646,10 +2646,10 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, } } - // TODO: handle control flow ops - // if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - // has_control_flow_op = true; - // } + std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; + if (control_flow_op_set.find(node->OpType()) != control_flow_op_set.end()) { + has_control_flow_op = true; + } // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization. // Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto. @@ -2678,12 +2678,12 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, // TODO:yang // Only if the newly built graph has control flow op as well as it has parent node, // it needs to handle outer scope values before calling graph.Resolve(). 
- // if (has_control_flow_op && graph.ParentNode()) { + if (has_control_flow_op && graph_viewer->ParentNode()) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); // BuildSubGraphContext(graph_build); // SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); // SetAllGraphInputs(graph_build); - // } + } common::Status status = graph_build.Resolve(); if (status != Status::OK()) return ToOrtStatus(status); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fcf4659021c83..0b16aae0fa253 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -551,7 +551,7 @@ ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); -ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index e130d8943bae0..c89e209b56d09 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -217,7 +217,7 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1])); const char* input_names[] = {"input_1", "image_shape"}; - const char* output_names[] = {"6379", "6381", "6383"}; + const char* output_names[] = {"yolonms_layer_1", "yolonms_layer_1:1", "yolonms_layer_1:2"}; size_t output_count = sizeof(output_names)/sizeof(output_names[0]); std::vector output_tensors(output_count, nullptr); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 130f4ff063014..28b15b3f73514 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1067,7 +1067,7 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* const OrtGraph* cur_graph = nullptr; api->OrtGraph_GetOrtGraph(graph, &cur_graph); bool is_subgraph = false; - api->OrtGraph_IsSubgraph(graph, &is_subgraph); + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); if (is_subgraph) { const OrtNode* node = nullptr; api->OrtGraph_GetParenNode(graph, &node); @@ -1235,8 +1235,10 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Generate unique kernel name for TRT subgraph std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); + const OrtGraph* cur_graph = nullptr; + api_->OrtGraph_GetOrtGraph(graph, &cur_graph); bool is_subgraph = false; - api_->OrtGraph_IsSubgraph(graph, &is_subgraph); + api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); const std::string graph_type = is_subgraph ? 
"subgraph" : "graph"; const char* graph_name = api_->OrtGraph_GetName(graph); std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index f0f0374865087..e9a9ff0cd46c1 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -274,12 +274,12 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { const OrtGraph* cur_graph = nullptr; api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); bool is_subgraph = false; - api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph); + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); while (is_subgraph) { const OrtGraph* parent_graph = nullptr; api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); cur_graph = parent_graph; - api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph); + api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); } const OrtGraph* main_graph = cur_graph; From a4079443a963603e1feb22c12fda309aeb3279f0 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 26 Sep 2024 23:42:11 +0000 Subject: [PATCH 44/81] sample code to separate graph C API to different files --- .../core/session/onnxruntime_c_api.h | 5 +++++ .../core/session/onnxruntime_c_api_ep.h | 7 +++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 8 ++++++++ .../core/session/onnxruntime_c_api_ep.cc | 18 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 2 ++ onnxruntime/core/session/ort_apis_ep.h | 6 ++++++ samples/c_test/test.cpp | 12 ++++++++---- .../tensorRTEp/tensorrt_execution_provider.cc | 3 +++ .../tensorRTEp/tensorrt_execution_provider.h | 2 +- 9 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 include/onnxruntime/core/session/onnxruntime_c_api_ep.h create mode 100644 onnxruntime/core/session/onnxruntime_c_api_ep.cc create mode 100644 onnxruntime/core/session/ort_apis_ep.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e6c0ea485d38f..b504dddcf62cf 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -673,6 +673,9 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtGraphApi; +typedef struct OrtGraphApi OrtGraphApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -4876,6 +4879,8 @@ struct OrtApi { ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_CLASS_RELEASE(TypeConstraints); + + const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION; }; // struct OrtApi /* diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h new file mode 100644 index 0000000000000..20a4860cab163 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -0,0 +1,7 @@ +#pragma once +#include "onnxruntime_c_api.h" + +struct OrtGraphApi { +ORT_API2_STATUS(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); +}; +typedef struct OrtGraphApi OrtGraphApi; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9b2d50718677c..7836c8024e76e 100644 --- 
a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -44,6 +44,8 @@ #include "core/framework/provider_factory_adapter.h" #include "core/framework/kernel_registry.h" #include "core/framework/ort_type_constraints.h" +#include "onnxruntime_c_api_ep.h" +#include "ort_apis_ep.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -2906,6 +2908,11 @@ ORT_API(void, OrtApis::ReleaseTypeConstraints, OrtTypeConstraints* type_constrai delete type_constraints; } +ORT_API(const OrtGraphApi*, OrtApis::GetGraphApi, uint32_t version) { + //if (version >= xx && version <= ORT_API_VERSION) + return OrtGraphApis::GetGraphApi(version); +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -3345,6 +3352,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::CreateOrtTypeConstraints, &OrtApis::AddTypeConstraint, &OrtApis::ReleaseTypeConstraints, + &OrtApis::GetGraphApi, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc new file mode 100644 index 0000000000000..5c89348188faf --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -0,0 +1,18 @@ +#include "core/session/onnxruntime_c_api_ep.h" +#include "ort_apis_ep.h" +#include "core/graph/graph_viewer.h" + +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + *out = graph_viewer->NumberOfNodes(); + return nullptr; +} + +static constexpr OrtGraphApi ort_graph_api = { + &OrtGraphApis::OrtGraph_PlaceHolder, +}; + +ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi, uint32_t) { + // No constraints on the API version yet. 
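+ // (If a constraint is added later, it would presumably mirror the commented-out check in OrtApis::GetGraphApi above: reject versions above ORT_API_VERSION and below whichever release first ships OrtGraphApi.)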
+ return &ort_graph_api; +} diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0b16aae0fa253..bf149d4daca4d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -644,4 +644,6 @@ ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_API(void, ReleaseTypeConstraints, _In_ OrtTypeConstraints* type_constraints); + +ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h new file mode 100644 index 0000000000000..ef0af223504f5 --- /dev/null +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -0,0 +1,6 @@ +#pragma once + +namespace OrtGraphApis { +ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); +ORT_API_STATUS_IMPL(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); +} diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index c89e209b56d09..362c329039581 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -196,9 +196,10 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { // for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session)); + if (!strcmp(model, "tyolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session)); + else if (!strcmp(model, "yolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/yolov3/yolov3.onnx", so, &session)); OrtMemoryInfo* memory_info = nullptr; THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); @@ -212,6 +213,9 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[0])); float input2[2] = {375, 500}; + if (!strcmp(model, "yolo")) { + input2[0] = 506, input2[1] = 640; + } const size_t input2_len = 8; // 2 * sizeof(float) const int64_t input2_shape[] = {1, 2}; THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1])); @@ -255,8 +259,8 @@ int main(int argc, char *argv[]) { RunResnet18v1_7(g_ort, p_env, so); } else if (!strcmp(argv[2], "rcnn")) { RunFastRcnn(g_ort, p_env, so); - } else if (!strcmp(argv[2], "tyolo")) { - RunTinyYolov3(p_env, so); + } else if (!strcmp(argv[2], "tyolo") || !strcmp(argv[2], "yolo")) { + RunTinyYolov3(p_env, so, argv[2]); } g_ort->ReleaseEnv(p_env); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 28b15b3f73514..9b676a0aee76b 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1280,6 +1280,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const TensorrtExecutionProvider* p = 
static_cast(this_); + const OrtGraphApi* g_ort_graph_api = api->GetGraphApi(ORT_API_VERSION); + int num_nodes = 0; + g_ort_graph_api->OrtGraph_PlaceHolder(graph, &num_nodes); // Get ModelPath const std::filesystem::path* model_path = nullptr; api->OrtGraph_GetModelPath(graph, (const void**)&model_path); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index f35c8d4316e84..56feecea84e4e 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -2,7 +2,7 @@ #include #include #include -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/provider_options.h" #include "tensorrt_execution_provider_info.h" #include "nv_includes.h" From f871b25d337d800813151bff2aeb2ef64ae3c6d6 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 2 Oct 2024 23:44:03 +0000 Subject: [PATCH 45/81] new test control_flow, error: ErrorMessage:Failed to find kernel for MemcpyFromHost(1) (node:'Memcpy' ep:'tensorrtEp'). Kernel not found --- .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt_execution_provider_helper.cc | 4 +- onnxruntime/core/session/onnxruntime_c_api.cc | 249 +++++++++++++++++- samples/c_test/test.cpp | 39 +++ .../tensorRTEp/tensorrt_execution_provider.cc | 10 +- 5 files changed, 292 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index b58e86237860c..6090b5a9ec277 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -546,7 +546,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { /** * The newly-built graph has not yet being resolved by Graph::Resolve(), so we can't leverage - * Graph::ResolveContext::IsOuterScopeValue(). We have to implement this fuction again. + * Graph::ResolveContext::IsOuterScopeValue(). We have to implement this function again. */ bool IsOuterScopeValue(const Graph& graph, const std::string& name) const; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 92fa101118506..6d68327ca2e21 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -34,7 +34,7 @@ std::string GetUniqueGraphName(const Graph& graph) { } // namespace // The newly-built graph has not yet being resolved by Graph::Resolve(), so we can't leverage -// Graph::ResolveContext::IsInputInitializerOrOutput(). We have to implement this fuction again. +// Graph::ResolveContext::IsInputInitializerOrOutput(). We have to implement this function again. bool TensorrtExecutionProvider::IsInputInitializerOrOutput(const Graph& graph, const std::string& name, bool check_ancestors) const { @@ -125,7 +125,7 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { } } -// Set outer scope values for subgraphs and add thoes values as top-level graph's inputs if needed. +// Set outer scope values for subgraphs and add those values as top-level graph's inputs if needed. 
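+// (Concretely: if an "If" node's then-branch reads a tensor "x" produced in the parent graph, "x" is an outer scope value of that subgraph; the rebuilt standalone graph has to re-declare it, either via Graph::AddOuterScopeNodeArg or by promoting it to an explicit input of the top-level graph before Resolve().)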
void TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(Graph& graph_build, const Graph& graph) const { // Iterate all the nodes and recurse into inner most subgraph first for both newly built graph and original graph diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7836c8024e76e..d14c0747f4426 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -31,6 +31,7 @@ #include "core/providers/get_execution_providers.h" #include "core/session/environment.h" #include "core/framework/callback.h" +#include "core/framework/murmurhash3.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" @@ -2594,6 +2595,244 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_DeserializeFromArray, const void* data, si return nullptr; } +struct SubGraphContext2 { + std::unordered_set output_args; + std::unordered_map inputs_and_initializers; + std::unordered_map manually_added_graph_inputs; +}; + +static std::string GetUniqueGraphName(const Graph& graph) { + HashValue model_hash = 0; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Hash all nodes' name + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + hash_str(node->Name()); + } + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + return graph.Name() + "_" + std::to_string(model_hash); +} + +static bool IsLocalValue(const Graph& graph, + const std::string& name, + const std::unordered_map>& subgraph_context_map) { + std::string unique_graph_name = GetUniqueGraphName(graph); + if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) { + return false; + } + SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); + return context->output_args.find(name) != context->output_args.cend() || + context->inputs_and_initializers.find(name) != context->inputs_and_initializers.cend(); +} + +static bool IsInputInitializerOrOutput(const Graph& graph, + const std::string& name, + bool check_ancestors, + const std::unordered_map>& subgraph_context_map) { + const Graph* parent_graph = nullptr; + return IsLocalValue(graph, name, subgraph_context_map) || + (check_ancestors && (parent_graph = graph.ParentGraph()) != nullptr && + IsInputInitializerOrOutput(*parent_graph, name, check_ancestors, subgraph_context_map)); +} + +static bool IsOuterScopeValue(const Graph& graph, + const std::string& name, + const std::unordered_map>& subgraph_context_map) { + const Graph* parent_graph = nullptr; + return (parent_graph = graph.ParentGraph()) != nullptr && + IsInputInitializerOrOutput(*parent_graph, name, true, subgraph_context_map); +} + +static void BuildSubGraphContext(const Graph& graph, std::unordered_map>& subgraph_context_map) { + // Iterate all the nodes and recurse into inner most subgraph first + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + + auto subgraph_map = node->GetAttributeNameToSubgraphMap(); + for (auto& entry : subgraph_map) { + const Graph* subgraph = entry.second; + BuildSubGraphContext(*subgraph, subgraph_context_map); + } + } + + std::string unique_graph_name = GetUniqueGraphName(graph); + + // Subgraph context has been built before, no 
need to do it again + if (subgraph_context_map.find(unique_graph_name) != subgraph_context_map.end()) { + return; + } + + subgraph_context_map.emplace(unique_graph_name, std::make_unique()); + SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); + + // Collect all nodes' outputs and nodes' names + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + + for (const auto& output : node->OutputDefs()) { + context->output_args.insert(output->Name()); + } + } + + // Go through all nodes' inputs + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + + for (const auto& input : node->InputDefs()) { + if (context->output_args.find(input->Name()) != context->output_args.end()) { + continue; + } + // This input arg is not the output of another node so must come from either a graph input or an initializer. + context->inputs_and_initializers[input->Name()] = input; + } + } +} + +static void SetGraphOuterScopeValuesAndInputs(Graph& graph_build, + const Graph& graph, + std::unordered_map>& subgraph_context_map) { + // Iterate all the nodes and recurse into the innermost subgraph first for both the newly built graph and the original graph + for (int i = 0; i < graph_build.MaxNodeIndex(); ++i) { + auto graph_build_node = graph_build.GetNode(i); + if (graph_build_node == nullptr) { + continue; + } + + auto graph_build_map = graph_build_node->GetAttributeNameToMutableSubgraphMap(); + std::unordered_map> subgraph_map; + const Node* graph_node = nullptr; + + // Find the corresponding original graph node's subgraphs + for (int j = 0; j < graph.MaxNodeIndex(); ++j) { + if (graph.GetNode(j) && graph.GetNode(j)->Name() == graph_build_node->Name()) { + graph_node = graph.GetNode(j); + subgraph_map = graph_node->GetAttributeNameToSubgraphMap(); + break; + } + } + + for (auto& entry : graph_build_map) { + auto attr_name = entry.first; + Graph* subgraph_build = entry.second; + if (subgraph_map.find(attr_name) != subgraph_map.end()) { + // recurse into subgraph + const Graph* subgraph = subgraph_map.at(attr_name); + SetGraphOuterScopeValuesAndInputs(*subgraph_build, *subgraph, subgraph_context_map); + } + } + } + + // Start from the innermost subgraph first and check whether its outer scope values exist in the + // newly built graph. If not, we need to add those outer scope values as explicit inputs to the top-level + // of the newly built graph. + if (graph_build.ParentNode()) { + auto top_level_graph = &graph_build; + while (top_level_graph->MutableParentGraph()) { + top_level_graph = top_level_graph->MutableParentGraph(); + } + std::string unique_graph_name = GetUniqueGraphName(*top_level_graph); + if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) { + return; + } + + SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); + + // Iterate all the implicit inputs to set outer scope values for the newly built subgraph + for (const auto& input : graph.ParentNode()->ImplicitInputDefs()) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name(); + + // The node arg in the parent node's implicit inputs could be used in the parent node's other subgraph, for example + // "If" op has two subgraphs. So we need to make sure that the node arg is used in the current subgraph only.
+ // (GetNodeArg searches for a specific node arg among all node args in the graph) + if (graph_build.GetNodeArg(input->Name())) { + graph_build.AddOuterScopeNodeArg(input->Name()); +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " is used in this subgraph"; + + if (context && + (context->manually_added_graph_inputs.find(input->Name()) != context->manually_added_graph_inputs.end())) { +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " has already been added as an explicit input to the graph"; + continue; + } + + // Handle the case where this outer scope value does not exist at any outer scope level of the + // newly built graph (the newly built graph is a subgraph of the original graph). We need to add + // the outer scope value as an explicit input to the top-level of the newly built graph. + if (!IsOuterScopeValue(graph_build, input->Name(), subgraph_context_map)) { + const auto& name = input->Name(); + auto graph_inputs_including_initializers = top_level_graph->GetInputsIncludingInitializers(); + auto added_graph_input = std::find_if(graph_inputs_including_initializers.begin(), + graph_inputs_including_initializers.end(), + [&name](const NodeArg* entry) { return entry->Name() == name; }); + + if (added_graph_input == graph_inputs_including_initializers.end()) { + if (context) { + auto type_proto = std::make_unique(); + type_proto->CopyFrom(*(input->TypeAsProto())); + auto& n_input = top_level_graph->GetOrCreateNodeArg(name, type_proto.get()); + context->manually_added_graph_inputs[n_input.Name()] = &n_input; +// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph"; + } + } + } + } + } + } +} + +static void SetAllGraphInputs(Graph& graph, std::unordered_map>& subgraph_context_map) { + // If ORT TRT doesn't manually set graph inputs in TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(), + // Graph::Resolve() will help set graph inputs in Graph::SetGraphInputsOutputs(), so no need to set graph inputs here.
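+ // (The early return below is the other half of that contract: when manually_added_graph_inputs is empty, SetGraphOuterScopeValuesAndInputs added nothing, and Graph::Resolve() can work out the graph inputs on its own.)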
+ std::string unique_graph_name = GetUniqueGraphName(graph); + if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end() || + subgraph_context_map[unique_graph_name].get()->manually_added_graph_inputs.size() == 0) { + return; + } + + SubGraphContext2* context = subgraph_context_map[unique_graph_name].get(); + std::vector graph_inputs_including_initializers; + std::unordered_set graph_inputs_including_initializers_set; + + for (const auto& entry : context->inputs_and_initializers) { + graph_inputs_including_initializers.push_back(entry.second); + graph_inputs_including_initializers_set.insert(entry.first); + } + + for (const auto& entry : context->manually_added_graph_inputs) { + if (graph_inputs_including_initializers_set.find(entry.first) == graph_inputs_including_initializers_set.end()) { + graph_inputs_including_initializers.push_back(entry.second); + graph_inputs_including_initializers_set.insert(entry.first); + } + } + + for (const auto& node_arg : graph.GetInputsIncludingInitializers()) { + if (graph_inputs_including_initializers_set.find(node_arg->Name()) == graph_inputs_including_initializers_set.end()) { + graph_inputs_including_initializers.push_back(node_arg); + graph_inputs_including_initializers_set.insert(node_arg->Name()); + } + } + + graph.SetInputs(graph_inputs_including_initializers); +} + ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); // Get parent graph output names @@ -2680,11 +2919,13 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, // TODO:yang // Only if the newly built graph has control flow op as well as it has parent node, // it needs to handle outer scope values before calling graph.Resolve(). + // TODO(leca): Is local variable enough? Do we need to make it EP class variable? 
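+ // (A local map appears sufficient for correctness: the context is only consulted between the BuildSubGraphContext and SetAllGraphInputs calls below, so nothing needs to outlive this OrtGraph_GetSubGraph call.)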
+ std::unordered_map> subgraph_context_map; if (has_control_flow_op && graph_viewer->ParentNode()) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); - // BuildSubGraphContext(graph_build); - // SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); - // SetAllGraphInputs(graph_build); + BuildSubGraphContext(graph_build, subgraph_context_map); + SetGraphOuterScopeValuesAndInputs(graph_build, graph_viewer->GetGraph(), subgraph_context_map); + SetAllGraphInputs(graph_build, subgraph_context_map); } common::Status status = graph_build.Resolve(); @@ -2878,7 +3119,7 @@ ORT_API(float, OrtApis::OrtNode_GetAttributeFloat, const OrtNode* node, const ch ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs) { const ::onnxruntime::Node* n = reinterpret_cast(node); std::vector> subg = n->GetSubgraphs(); - len = new size_t (subg.size()); + *len = subg.size(); *subgraphs = new const OrtGraphViewer* [*len]; for (size_t i = 0; i < subg.size(); i++) { const ::onnxruntime::GraphViewer* graph_viewer = new const ::onnxruntime::GraphViewer(*subg[i]); diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 362c329039581..db40463b5975c 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -233,6 +233,43 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so, const char* model) { for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/control_flow/control_flow_model.onnx", so, &session)); + + OrtMemoryInfo* memory_info = nullptr; + THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); + + std::vector input_tensors(3, nullptr); + const int input_cnt = 2; + float input_data[input_cnt]; + for (int i = 0; i < input_cnt; i++) input_data[i] = 1; + const size_t input_len = input_cnt * sizeof(float); + const int64_t input_shape[] = {1, 2}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[0])); + + float input2[2] = {0.36252614855766296, 0.030415434390306473}; + const size_t input2_len = 8; // 2 * sizeof(float) + const int64_t input2_shape[] = {1, 2}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1])); + + float input3 = 0.5945659279823303; + const int64_t input3_shape[] = {1}; + THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, &input3, 4, input3_shape, sizeof(input3_shape)/sizeof(input3_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[2])); + + const char* input_names[] = {"x1.opt", "x2", "x3"}; + const char* output_names[] = {"y"}; + + size_t output_count = sizeof(output_names)/sizeof(output_names[0]); + std::vector output_tensors(output_count, nullptr); + THROW_ON_ERROR(g_ort->Run(session, nullptr, input_names, (const OrtValue* const*)input_tensors.data(), sizeof(input_names)/sizeof(input_names[0]), output_names, output_count, output_tensors.data())); + + float* output_tensor_data = nullptr; + THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensors[0], (void**)&output_tensor_data)); +// std::cout<<"Result:\n"; +// for (size_t i = 0; i < 4; i++) std::cout<ReleaseEnv(p_env); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc 
b/samples/tensorRTEp/tensorrt_execution_provider.cc index 9b676a0aee76b..be78f52d6df4c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1083,7 +1083,8 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* // Check whether all the nodes of the graph are assigned to specific ep bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - std::vector nodes_vector(api->OrtGraph_NumberOfNodes(graph)); + const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph); + std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); size_t node_count = 0; const size_t* nodes_index = nullptr; @@ -1093,12 +1094,11 @@ bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewe api->OrtGraph_GetOrtNode(graph, nodes_index[index], &node); const char* node_ep_type; api->OrtNode_GetExecutionProviderType(node, &node_ep_type); - if (!strcmp(node_ep_type, provider_type.c_str())) { + if (strcmp(node_ep_type, provider_type.c_str())) { return false; } } - return true; - + return number_of_ort_nodes != 0; } // Check whether all the nodes of subgraph are supported @@ -1430,7 +1430,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], "TensorrtExecutionProvider")) { + else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], "tensorrtEp")) { all_subgraphs_are_supported = true; break; } From e84f00c90de857a9237a87941374e8c3584cadad Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 3 Oct 2024 23:45:25 +0000 Subject: [PATCH 46/81] control flow model works --- .../core/session/onnxruntime_c_api.h | 7 +++ samples/c_test/test.cpp | 49 +++++++++++++++- .../tensorRTEp/tensorrt_execution_provider.cc | 57 +++++++++++++++++++ 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b504dddcf62cf..48786dc26396b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4909,6 +4909,13 @@ typedef enum OrtCustomOpInputOutputCharacteristic { * the implementor of the custom op. */ struct OrtCustomOp { +#ifdef __cplusplus + // TODO(leca): initialize all member function pointers to nullptr? 
+ OrtCustomOp() : CreateKernel{nullptr}, KernelCompute{nullptr}, KernelDestroy{nullptr}, GetInputCharacteristic{nullptr}, + GetOutputCharacteristic{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicOutputMinArity{nullptr}, + GetStartVersion{nullptr}, GetEndVersion{nullptr}, GetMayInplace{nullptr}, ReleaseMayInplace{nullptr}, + GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {} +#endif uint32_t version; // Must be initialized to ORT_API_VERSION // This callback creates the kernel, which is a user defined diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index db40463b5975c..b6f676d50daef 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -222,6 +222,11 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so, const char* model) { const char* input_names[] = {"input_1", "image_shape"}; const char* output_names[] = {"yolonms_layer_1", "yolonms_layer_1:1", "yolonms_layer_1:2"}; + if (!strcmp(model, "yolo")) { + output_names[0] = "yolonms_layer_1/ExpandDims_1:0"; + output_names[1] = "yolonms_layer_1/ExpandDims_3:0"; + output_names[2] = "yolonms_layer_1/concat_2:0"; + } size_t output_count = sizeof(output_names)/sizeof(output_names[0]); std::vector output_tensors(output_count, nullptr); @@ -266,8 +271,8 @@ void RunControlFlow(OrtEnv* p_env, OrtSessionOptions* so) { float* output_tensor_data = nullptr; THROW_ON_ERROR(g_ort->GetTensorMutableData(output_tensors[0], (void**)&output_tensor_data)); -// std::cout<<"Result:\n"; -// for (size_t i = 0; i < 4; i++) std::cout<CreateSessionOptions(&so)); + // sanity tests +// if (argc == 1) { +// std::cout<<"Compile based EP, relu:\n"; +// TestCompileBasedEp(g_ort, p_env, so); +// RunRelu(g_ort, p_env, so); +// +// std::cout<<"Kernel based EP, relu:\n"; +// TestKernelBasedEp(g_ort, p_env, so); +// RunRelu(g_ort, p_env, so); +// +// std::cout<<"TRT, relu:\n"; +// TestTensorRTEp(g_ort, p_env, so); +// RunRelu(g_ort, p_env, so); +// +// std::cout<<"out tree TRT + In tree cuda, relu:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunRelu(g_ort, p_env, so); +// +// std::cout<<"out tree TRT + In tree cuda, resnet:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunResnet18v1_7(g_ort, p_env, so); +// +// std::cout<<"out tree TRT + In tree cuda, fast rcnn:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunFastRcnn(g_ort, p_env, so); +// +// std::cout<<"out tree TRT + In tree cuda, tiny yoloV3:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunTinyYolov3(p_env, so, "tyolo"); +// +// std::cout<<"out tree TRT + In tree cuda, yoloV3:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunTinyYolov3(p_env, so, "yolo"); +// +// std::cout<<"out tree TRT + In tree cuda, control flow:\n"; +// TestTensorRTAndCudaEp(g_ort, p_env, so); +// RunControlFlow(p_env, so); +// return 0; +// } + if (strcmp(argv[1], "c") == 0) { TestCompileBasedEp(g_ort, p_env, so); } else if (strcmp(argv[1], "k") == 0) { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index be78f52d6df4c..fcdfd494f081c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -27,6 +27,52 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } namespace onnxruntime { +//static const std::string + +struct MemcpyFromHost : OrtCustomOp { + MemcpyFromHost() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::GetName = [](const struct OrtCustomOp* op) { return "MemcpyFromHost"; }; + 
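// This kernel resolves the failure recorded in the previous commit's subject line +// ("Failed to find kernel for MemcpyFromHost(1) (node:'Memcpy' ep:'tensorrtEp')"): +// ORT inserts Memcpy nodes around host-resident values consumed on this EP's GPU +// device, and those nodes need a kernel registered under the EP's type name. +// KernelComputeV2 below performs the host-to-device copy on the kernel context's +// GPU compute stream. +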
OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return "tensorrtEp"; }; + OrtCustomOp::CreateKernelV2 = [](const struct OrtCustomOp* op, const OrtApi* api, const OrtKernelInfo* info, void** kernel) -> OrtStatusPtr { + return nullptr; + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + void* stream = nullptr; + api->KernelContext_GetGPUComputeStream(context, &stream); + + const OrtValue* input = nullptr; + api->KernelContext_GetInput(context, 0, &input); + OrtTensorTypeAndShapeInfo* shape_info; + api->GetTensorTypeAndShape(input, &shape_info); + size_t dim_count = 0; + api->GetDimensionsCount(shape_info, &dim_count); + std::vector dim(dim_count, 0); + api->GetDimensions(shape_info, dim.data(), dim_count); + + OrtValue* output = nullptr; + api->KernelContext_GetOutput(context, 0, dim.data(), dim.size(), &output); + + void* input_raw = nullptr, *output_raw = nullptr; + api->GetTensorMutableData(const_cast(input), &input_raw); + api->GetTensorMutableData(output, &output_raw); + + size_t count = dim[0]; + for (size_t i = 1; i < dim_count; i++) count *= dim[i]; + cudaMemcpyAsync(output_raw, input_raw, count * sizeof(float) , cudaMemcpyHostToDevice, static_cast(stream)); // TODO(leca): other data type + + return nullptr; + }; + OrtCustomOp::GetInputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; + OrtCustomOp::GetOutputTypeCount = [](const struct OrtCustomOp* op) -> size_t { return 1; }; + OrtCustomOp::GetInputMemoryType = [](const struct OrtCustomOp* op, size_t index) { return OrtMemType::OrtMemTypeCPUInput; }; + OrtCustomOp::GetInputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + OrtCustomOp::GetOutputType = [](const struct OrtCustomOp* op, size_t index) { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; // TODO(leca): other data type + OrtCustomOp::GetStartVersion = [](const struct OrtCustomOp* op) { return 1; }; + } +}; + template using IAllocatorUniquePtr = std::unique_ptr>; const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); @@ -1598,6 +1644,17 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const return stream; }; + OrtExecutionProvider::RegisterKernels = [](OrtKernelRegistry* kernel_registry) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + + OrtTypeConstraints* type_constraints = nullptr; + api->CreateOrtTypeConstraints(&type_constraints); + api->AddTypeConstraint(type_constraints, "T", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); // TODO(leca): other data type + OrtCustomOp* op = new MemcpyFromHost(); + api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints); + api->ReleaseTypeConstraints(type_constraints); + }; + info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info); device_id_ = info_.device_id; api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); From 5b2de226ef7c4a22c7c9fb8d7d2dc99f2bf75cc9 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 7 Oct 2024 14:32:49 +0000 Subject: [PATCH 47/81] API refactor --- .../core/framework/ort_type_constraints.h | 2 +- .../onnxruntime/core/session/environment.h | 1 + .../core/session/onnxruntime_c_api.h | 191 +---- .../core/session/onnxruntime_c_api_ep.h 
| 180 ++++- onnxruntime/core/framework/provider_adapter.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 763 ------------------ .../core/session/onnxruntime_c_api_ep.cc | 721 ++++++++++++++++- onnxruntime/core/session/ort_apis.h | 98 --- onnxruntime/core/session/ort_apis_ep.h | 98 ++- onnxruntime/core/session/ort_env.cc | 1 + onnxruntime/core/session/ort_env.h | 1 + samples/c_test/sanityTests.sh | 28 + samples/c_test/test.cpp | 2 +- samples/outTreeEp/out_tree_ep.cc | 21 +- samples/outTreeEp/out_tree_ep.h | 2 +- samples/outTreeEp_kernel/kernel_ep.h | 2 +- samples/qnnEp/qnn_execution_provider.h | 2 +- samples/tensorRTEp/onnx_ctx_model_helper.cc | 40 +- samples/tensorRTEp/onnx_ctx_model_helper.h | 4 +- .../tensorRTEp/tensorrt_execution_provider.cc | 279 +++---- .../tensorRTEp/tensorrt_execution_provider.h | 1 + .../tensorrt_execution_provider_utils.h | 33 +- 22 files changed, 1184 insertions(+), 1288 deletions(-) create mode 100755 samples/c_test/sanityTests.sh diff --git a/include/onnxruntime/core/framework/ort_type_constraints.h b/include/onnxruntime/core/framework/ort_type_constraints.h index 1224e56d58fb9..dfc418ca27510 100644 --- a/include/onnxruntime/core/framework/ort_type_constraints.h +++ b/include/onnxruntime/core/framework/ort_type_constraints.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include #include #include diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 08a3730827835..c05d9768d9b0b 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -10,6 +10,7 @@ #include "core/platform/threadpool.h" #include "core/common/logging/logging.h" #include "core/framework/allocator.h" +#include "core/session/onnxruntime_c_api_ep.h" struct OrtThreadingOptions; namespace onnxruntime { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 48786dc26396b..87e5867793b61 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -304,11 +304,9 @@ ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); -ORT_RUNTIME_CLASS(ExecutionProvider); -ORT_RUNTIME_CLASS(ExecutionProviderFactory); -ORT_RUNTIME_CLASS(Node); -ORT_RUNTIME_CLASS(Graph); -ORT_RUNTIME_CLASS(GraphViewer); +ORT_RUNTIME_CLASS(KernelInfo); +ORT_RUNTIME_CLASS(KernelContext); +ORT_RUNTIME_CLASS(CustomOp); ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(TypeConstraints); ORT_RUNTIME_CLASS(Device); @@ -372,13 +370,6 @@ typedef enum OrtLanguageProjection { ORT_PROJECTION_NODEJS = 6, } OrtLanguageProjection; -struct OrtKernelInfo; -typedef struct OrtKernelInfo OrtKernelInfo; -struct OrtKernelContext; -typedef struct OrtKernelContext OrtKernelContext; -struct OrtCustomOp; -typedef struct OrtCustomOp OrtCustomOp; - typedef enum OrtAllocatorType { OrtInvalidAllocator = -1, OrtDeviceAllocator = 0, @@ -707,81 +698,6 @@ typedef struct OrtApiBase OrtApiBase; */ ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; -typedef struct OrtCreateStream { - int device_type; - void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*); -} OrtCreateStream; - -typedef struct OrtMetaDef { - char* name; - char* domain; - int since_version; - - char** inputs; - size_t input_len; - char** outputs; - 
size_t output_len; - char** constant_initializers; - size_t initializer_len; - - char* doc_string; -} OrtMetaDef; - -typedef struct OrtIndexedSubGraph { - OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? - size_t* node_index; - size_t node_index_len; -} OrtIndexedSubGraph; - -typedef struct OrtComputeContext { - void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); - void(ORT_API_CALL* DestroyFunc)(void*, void*); - void* allocator_handle; - const char* node_name; -} OrtComputeContext; - -typedef struct OrtNodeComputeInfo { - int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void*, void**); - OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, void*, const OrtApi*, OrtKernelContext*); - void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); -} OrtNodeComputeInfo; - -typedef struct OrtTensorRef { // TODO(leca): OrtValueInfoRef inside OrtTensorRef? - int64_t* shape; - size_t shape_len; - ONNXTensorElementDataType data_type; - const char* data; - size_t data_len; -} OrtTensorRef; - -typedef struct OrtValueInfoRef { - int64_t* shape; - size_t shape_len; - ONNXTensorElementDataType data_type; -} OrtValueInfoRef; - -typedef struct OrtExecutionProvider { -#ifdef __cplusplus - OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr}, - extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {} -#endif - void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); - OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info); - void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); - bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); - OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); - int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators); - const char* type; - OrtCreateStream* create_stream; - const OrtDevice* default_device; - void* extra_param_for_create_state_func; - void* extra_param_for_compute_func; -} OrtExecutionProvider; - -typedef struct OrtExecutionProviderFactory { - OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); -} OrtExecutionProviderFactory; - /** \brief Thread work loop function * * Onnxruntime will provide the working loop on custom thread creation @@ -4403,9 +4319,6 @@ struct OrtApi { */ const char*(ORT_API_CALL* GetBuildInfoString)(void); - /// \name OrtROCMProviderOptions - /// @{ - /** \brief Create an OrtROCMProviderOptions * * \param[out] out Newly created ::OrtROCMProviderOptions. 
Must be released with OrtApi::ReleaseROCMProviderOptions @@ -4774,104 +4687,6 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); - - ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); - - ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); - - ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); - - ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); - - ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); - - ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); - - ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names); - - ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); - - ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ? 
- - ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer); - - int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out); - - size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); - - bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); - - size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; // TODO(leca): review and discuss - - ORT_API2_STATUS(OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer**); // TODO(leca): review and discuss - - ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss - - ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); - - ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); - - ORT_API2_STATUS(OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain); - - ORT_API2_STATUS(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version); - - ORT_API2_STATUS(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type); - - ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); - - ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size); - - ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); - - ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); - - ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); - - ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); - - ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); - - ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index); - - size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names); - - ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size); - - int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; // AttributeProto_AttributeType - - ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count); - - ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size); - - ORT_API2_STATUS(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size); - - ORT_API2_STATUS(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size); - - ORT_API2_STATUS(OrtNode_GetAttributeIthInt, const OrtNode* node, 
const char* key, int i, _Out_ int64_t* ints); - - ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats); - - ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs); - - const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs); - ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 20a4860cab163..87d742911bef3 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -1,7 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "onnxruntime_c_api.h" +ORT_RUNTIME_CLASS(ExecutionProvider); +ORT_RUNTIME_CLASS(ExecutionProviderFactory); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(GraphViewer); + +typedef struct OrtCreateStream { + int device_type; + void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*); +} OrtCreateStream; + +typedef struct OrtMetaDef { + char* name; + char* domain; + int since_version; + + char** inputs; + size_t input_len; + char** outputs; + size_t output_len; + char** constant_initializers; + size_t initializer_len; + + char* doc_string; +} OrtMetaDef; + +typedef struct OrtIndexedSubGraph { + OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? + size_t* node_index; + size_t node_index_len; +} OrtIndexedSubGraph; + +typedef struct OrtComputeContext { + void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); + void(ORT_API_CALL* DestroyFunc)(void*, void*); + void* allocator_handle; + const char* node_name; +} OrtComputeContext; + +typedef struct OrtNodeComputeInfo { + int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void*, void**); + OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, void*, const OrtApi*, OrtKernelContext*); + void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); +} OrtNodeComputeInfo; + +typedef struct OrtTensorRef { // TODO(leca): OrtValueInfoRef inside OrtTensorRef? 
+ int64_t* shape; + size_t shape_len; + ONNXTensorElementDataType data_type; + const char* data; + size_t data_len; +} OrtTensorRef; + +typedef struct OrtValueInfoRef { + int64_t* shape; + size_t shape_len; + ONNXTensorElementDataType data_type; +} OrtValueInfoRef; + +typedef struct OrtExecutionProvider { +#ifdef __cplusplus + OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr}, + extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {} +#endif + void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); + OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info); + void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); + bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); + OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); + int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators); + const char* type; + OrtCreateStream* create_stream; + const OrtDevice* default_device; + void* extra_param_for_create_state_func; + void* extra_param_for_compute_func; +} OrtExecutionProvider; + +typedef struct OrtExecutionProviderFactory { + OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); +} OrtExecutionProviderFactory; + struct OrtGraphApi { -ORT_API2_STATUS(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); +const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +bool(ORT_API_CALL* OrtGraph_IsConstantInitializer)(const OrtGraphViewer* graph, const char* name, bool check_outer_scope)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +size_t(ORT_API_CALL* OrtGraph_GetNodesIndexInTopologicalOrder)(const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order); + +bool(ORT_API_CALL* OrtGraph_IsSubgraph)(const OrtGraph* graph); + +const OrtGraph*(ORT_API_CALL* OrtGraph_GetParentGraph)(const OrtGraph* graph); + +const OrtNode*(ORT_API_CALL* OrtGraph_GetParenNode)(const OrtGraphViewer* graph); + +const void*(ORT_API_CALL* OrtGraph_GetModelPath)(const OrtGraphViewer* graph); + +const OrtGraph*(ORT_API_CALL* OrtGraph_GetOrtGraph)(const OrtGraphViewer* graph_viewer); + +size_t(ORT_API_CALL* OrtGraph_GetInputsIncludingInitializers)(const OrtGraphViewer* graph, _Outptr_ const char*** input_names); + +const OrtNode*(ORT_API_CALL* OrtGraph_GetOrtNode)(const OrtGraphViewer* graph, size_t node_index); + +size_t(ORT_API_CALL* OrtGraph_GetNodesConsumingInput)(const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ? 
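+
+// Note on ownership: const char* values returned by these accessors point into strings owned
+// by the underlying Graph/GraphViewer and are only valid while the graph is alive; arrays
+// returned via _Outptr_ parameters (e.g. input_names, consumers) are allocated with new[] by
+// the implementation and are currently never freed (see the release TODOs in
+// onnxruntime_c_api_ep.cc).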
+ +const OrtNode*(ORT_API_CALL* OrtGraph_GetNodeProducingOutput)(const OrtGraphViewer* graph, const char* output_name); + +int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +int(ORT_API_CALL* OrtGraph_MaxNodeIndex)(const OrtGraphViewer* graph); + +size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); + +bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); + +size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; // TODO(leca): review and discuss + +ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss + +const char*(ORT_API_CALL* OrtNode_GetName)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetDescription)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetDomain)(const OrtNode* node); + +int(ORT_API_CALL* OrtNode_SinceVersion)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetExecutionProviderType)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetOpType)(const OrtNode* node); + +size_t(ORT_API_CALL* OrtNode_GetImplicitInputSize)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetIthImplicitInputName)(const OrtNode* node, size_t i); + +size_t(ORT_API_CALL* OrtNode_GetInputSize)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetIthInputName)(const OrtNode* node, size_t i); + +size_t(ORT_API_CALL* OrtNode_GetOutputSize)(const OrtNode* node); + +const char*(ORT_API_CALL* OrtNode_GetIthOutputName)(const OrtNode* node, size_t i); + +size_t(ORT_API_CALL* OrtNode_GetIndex)(const OrtNode* node); + +size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names); + +size_t(ORT_API_CALL* OrtNode_GetAttributeSize)(const OrtNode* node); + +int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; // AttributeProto_AttributeType + +size_t(ORT_API_CALL* OrtNode_GetAttributeKeyCount)(const OrtNode* node, const char* key); + +int(ORT_API_CALL* OrtNode_GetAttributeIntSize)(const OrtNode* node, const char* key); + +int(ORT_API_CALL* OrtNode_GetAttributeFloatSize)(const OrtNode* node, const char* key); + +int(ORT_API_CALL* OrtNode_GetAttributeStringSize)(const OrtNode* node, const char* key); + +int64_t(ORT_API_CALL* OrtNode_GetAttributeIthInt)(const OrtNode* node, const char* key, int i); + +float(ORT_API_CALL* OrtNode_GetAttributeIthFloat)(const OrtNode* node, const char* key, int i); + +const char*(ORT_API_CALL* OrtNode_GetAttributeIthStr)(const OrtNode* node, const char* key, int i); + +const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + +float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION 
ORT_ALL_ARGS_NONNULL; + +size_t(ORT_API_CALL* OrtNode_GetSubgraphs)(const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); }; typedef struct OrtGraphApi OrtGraphApi; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 57cd700debad3..7f5582da84b33 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/compute_capability.h" #include "core/framework/error_code_helper.h" #include "core/framework/kernel_registry.h" diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d14c0747f4426..e18b61618ffeb 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -21,17 +21,13 @@ #include "core/common/status.h" #include "core/common/safeint.h" #include "core/graph/constants.h" -#include "core/graph/model.h" #include "core/graph/graph.h" -#include "core/graph/graph_proto_serializer.h" -#include "core/graph/graph_viewer.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ort_value.h" #include "core/providers/get_execution_providers.h" #include "core/session/environment.h" #include "core/framework/callback.h" -#include "core/framework/murmurhash3.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" @@ -2418,716 +2414,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtS return nullptr; } -ORT_API(const char*, OrtApis::OrtGraph_GetName, const OrtGraphViewer* graph) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->Name().c_str(); -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *ret = graph_viewer->IsConstantInitializer(name, check_outer_scope); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(static_cast(execution_order)); - *len = nodes.size(); - *nodes_index_in_topological_order = nodes.data(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret) { - const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - *ret = graph_ptr->IsSubgraph(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph) { - const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - *parent_graph = reinterpret_cast(graph_ptr->ParentGraph()); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *parent_node = reinterpret_cast(graph_viewer->ParentNode()); - return nullptr; -} - 
-ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *path = reinterpret_cast(&graph_viewer->ModelPath()); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph) { - const ::onnxruntime::GraphViewer* graph_viewer_ptr = reinterpret_cast(graph_viewer); - *graph = reinterpret_cast(&graph_viewer_ptr->GetGraph()); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - const auto& inputs = graph_viewer->GetInputsIncludingInitializers(); - *num_inputs = inputs.size(); - *input_names = new const char*[*num_inputs]; - for (size_t i = 0; i < *num_inputs; i++) (*input_names)[i] = inputs[i]->Name().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *node = reinterpret_cast(graph_viewer->GetNode(node_index)); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - std::vector consumer_nodes = graph_viewer->GetConsumerNodes(input_name); - *len = consumer_nodes.size(); - *consumers = new const OrtNode* [*len]; - for (size_t i = 0; i < consumer_nodes.size(); i++) (*consumers)[i] = reinterpret_cast(consumer_nodes[i]); - - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *producer = reinterpret_cast(graph_viewer->GetProducerNode(output_name)); - return nullptr; -} - -ORT_API(int, OrtApis::OrtGraph_NumberOfNodes, const OrtGraphViewer* graph) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->NumberOfNodes(); -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *out = graph_viewer->MaxNodeIndex(); - return nullptr; -} - -ORT_API(size_t, OrtApis::OrtGraph_GetOutputSize, const OrtGraphViewer* graph) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->GetOutputs().size(); -} - -ORT_API(const char*, OrtApis::OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->GetOutputs()[i]->Name().c_str(); -} - -ORT_API(int32_t, OrtApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* graph, size_t i) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type(); -} - -ORT_API(bool, OrtApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out) { - const ::onnxruntime::GraphViewer* graph_viewer = 
reinterpret_cast(graph); - const onnx::TensorProto* initializer = nullptr; - if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) return false; - *out = new OrtTensorRef(); // TODO(leca): release - (*out)->shape_len = initializer->dims_size(); - (*out)->shape = new int64_t [initializer->dims_size()]; - for (size_t i = 0; i < (*out)->shape_len; i++) { - ((*out)->shape)[i] = initializer->dims(i); - } - - (*out)->data_type = static_cast(initializer->data_type()); - // see utils::ConvertRawDataInTensorProto() - switch (initializer->data_type()) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - (*out)->data_len = initializer->float_data_size(); - (*out)->data = reinterpret_cast(initializer->float_data().data()); - break; - } - return true; -} - -static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* type) { // onnxruntime\core\optimizer\transpose_optimization\ort_optimizer_api_impl.cc - if (!type || !utils::HasTensorType(*type) || !utils::HasElementType(*type)) return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - - return static_cast(type->tensor_type().elem_type()); -} - -ORT_API(bool, OrtApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - const NodeArg* node_arg = graph_viewer->GetNodeArg(name); - - *out = new OrtValueInfoRef(); // TODO(leca): release - const onnx::TypeProto* type = node_arg->TypeAsProto(); - (*out)->data_type = GetDataTypeFromTypeProto(type); - const auto& dims = utils::TryGetShape(*type)->dim(); - (*out)->shape_len = dims.size(); - (*out)->shape = new int64_t [(*out)->shape_len]; - for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[i]) ? 
dims[i].dim_value() : -1; - - return true; -} - -ORT_API(size_t, OrtApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), -#if defined(ORT_MINIMAL_BUILD) - IOnnxRuntimeOpSchemaRegistryList(), -#else - IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), -#endif - graph_viewer->DomainToVersionMap(), std::vector(), graph_viewer->GetGraph().GetLogger()); - onnx::ModelProto model_proto = model.ToProto(); - GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED); - size_t ret = model_proto.ByteSizeLong(); - *data = malloc(ret); // TODO(leca): release - model_proto.SerializeToArray(*data, ret); - return ret; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer** ret) { - onnx::ModelProto model_proto; - if (!model_proto.ParseFromArray(data, len)) return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Parse model proto from array returns false"); - std::shared_ptr model; - Status status = Model::Load(std::move(model_proto), model, nullptr, logging::LoggingManager::DefaultLogger()); - if (status != Status::OK()) return ToOrtStatus(status); - std::unique_ptr graph_viewer = std::make_unique(model->MainGraph()); - *ret = reinterpret_cast(graph_viewer.release()); // TODO(leca): release from the caller - return nullptr; -} - -struct SubGraphContext2 { - std::unordered_set output_args; - std::unordered_map inputs_and_initializers; - std::unordered_map manually_added_graph_inputs; -}; - -static std::string GetUniqueGraphName(const Graph& graph) { - HashValue model_hash = 0; - uint32_t hash[4] = {0, 0, 0, 0}; - - auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); - }; - - // Hash all nodes' name - for (int i = 0; i < graph.MaxNodeIndex(); ++i) { - auto node = graph.GetNode(i); - if (node == nullptr) { - continue; - } - hash_str(node->Name()); - } - - model_hash = hash[0] | (uint64_t(hash[1]) << 32); - - return graph.Name() + "_" + std::to_string(model_hash); -} - -static bool IsLocalValue(const Graph& graph, - const std::string& name, - const std::unordered_map>& subgraph_context_map) { - std::string unique_graph_name = GetUniqueGraphName(graph); - if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) { - return false; - } - SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); - return context->output_args.find(name) != context->output_args.cend() || - context->inputs_and_initializers.find(name) != context->inputs_and_initializers.cend(); -} - -static bool IsInputInitializerOrOutput(const Graph& graph, - const std::string& name, - bool check_ancestors, - const std::unordered_map>& subgraph_context_map) { - const Graph* parent_graph = nullptr; - return IsLocalValue(graph, name, subgraph_context_map) || - (check_ancestors && (parent_graph = graph.ParentGraph()) != nullptr && - IsInputInitializerOrOutput(*parent_graph, name, check_ancestors, subgraph_context_map)); -} - -static bool IsOuterScopeValue(const Graph& graph, - const std::string& name, - const std::unordered_map>& subgraph_context_map) { - const Graph* parent_graph = nullptr; - return (parent_graph = graph.ParentGraph()) != nullptr && - IsInputInitializerOrOutput(*parent_graph, name, true, subgraph_context_map); 
-} - -static void BuildSubGraphContext(const Graph& graph, std::unordered_map>& subgraph_context_map) { - // Iterate all the nodes and recurse into inner most subgraph first - for (int i = 0; i < graph.MaxNodeIndex(); ++i) { - auto node = graph.GetNode(i); - if (node == nullptr) { - continue; - } - - auto subgraph_map = node->GetAttributeNameToSubgraphMap(); - for (auto& entry : subgraph_map) { - const Graph* subgraph = entry.second; - BuildSubGraphContext(*subgraph, subgraph_context_map); - } - } - - std::string unique_graph_name = GetUniqueGraphName(graph); - - // Subgraph context has been built before, no need to do it again - if (subgraph_context_map.find(unique_graph_name) != subgraph_context_map.end()) { - return; - } - - subgraph_context_map.emplace(unique_graph_name, std::make_unique()); - SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); - - // Collect all nodes' outputs and nodes' name - for (int i = 0; i < graph.MaxNodeIndex(); ++i) { - auto node = graph.GetNode(i); - if (node == nullptr) { - continue; - } - - for (const auto& output : node->OutputDefs()) { - context->output_args.insert(output->Name()); - } - } - - // Go thru all node's inputs - for (int i = 0; i < graph.MaxNodeIndex(); ++i) { - auto node = graph.GetNode(i); - if (node == nullptr) { - continue; - } - - for (const auto& input : node->InputDefs()) { - if (context->output_args.find(input->Name()) != context->output_args.end()) { - continue; - } - // This input arg is not the output of another node so must come from either a graph input or an initializer. - context->inputs_and_initializers[input->Name()] = input; - } - } -} - -static void SetGraphOuterScopeValuesAndInputs(Graph& graph_build, - const Graph& graph, - std::unordered_map>& subgraph_context_map) { - // Iterate all the nodes and recurse into inner most subgraph first for both newly built graph and original graph - for (int i = 0; i < graph_build.MaxNodeIndex(); ++i) { - auto graph_build_node = graph_build.GetNode(i); - if (graph_build_node == nullptr) { - continue; - } - - auto graph_build_map = graph_build_node->GetAttributeNameToMutableSubgraphMap(); - std::unordered_map> subgraph_map; - const Node* graph_node = nullptr; - - // Find corresponding original graph node's subgraphs - for (int j = 0; j < graph.MaxNodeIndex(); ++j) { - if (graph.GetNode(j) && graph.GetNode(j)->Name() == graph_build_node->Name()) { - graph_node = graph.GetNode(j); - subgraph_map = graph_node->GetAttributeNameToSubgraphMap(); - break; - } - } - - for (auto& entry : graph_build_map) { - auto attr_name = entry.first; - Graph* subgraph_build = entry.second; - if (subgraph_map.find(attr_name) != subgraph_map.end()) { - // recurse into subgraph - const Graph* subgraph = subgraph_map.at(attr_name); - SetGraphOuterScopeValuesAndInputs(*subgraph_build, *subgraph, subgraph_context_map); - } - } - } - - // Start from the inner most subgraph first and check whether its outer scope values are existed in the - // newly built graph. If not, we need to add those outer scope values as explicit inputs to the top-level - // of newly built graph. 
- if (graph_build.ParentNode()) { - auto top_level_graph = &graph_build; - while (top_level_graph->MutableParentGraph()) { - top_level_graph = top_level_graph->MutableParentGraph(); - } - std::string unique_graph_name = GetUniqueGraphName(*top_level_graph); - if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) { - return; - } - - SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); - - // Iterate all the implicit inputs to set outer scope value for the newly built subgraph - for (const auto& input : graph.ParentNode()->ImplicitInputDefs()) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name(); - - // The node arg in parent node's implicit inputs could be used for parent node's other subgraph, for example - // "If" op has two subgraphs. So we need to make sure that the node arg is used in current subgraph only. - // (GetNodeArg searches for specific node arg in all node args in the graph) - if (graph_build.GetNodeArg(input->Name())) { - graph_build.AddOuterScopeNodeArg(input->Name()); -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " is used in this subgraph"; - - if (context && - (context->manually_added_graph_inputs.find(input->Name()) != context->manually_added_graph_inputs.end())) { -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " is already been added as an explicit input to graph"; - continue; - } - - // Handle the case where this outer scope value is not existed in any outer scope levels of the - // newly built graph (the newly built graph is the subgraph of the original graph). Need to add - // the outer scope value as an explicit input to the top-level of newly built graph. - if (!IsOuterScopeValue(graph_build, input->Name(), subgraph_context_map)) { - const auto& name = input->Name(); - auto graph_inputs_including_initializers = top_level_graph->GetInputsIncludingInitializers(); - auto added_graph_input = std::find_if(graph_inputs_including_initializers.begin(), - graph_inputs_including_initializers.end(), - [&name](const NodeArg* entry) { return entry->Name() == name; }); - - if (added_graph_input == graph_inputs_including_initializers.end()) { - if (context) { - auto type_proto = std::make_unique(); - type_proto->CopyFrom(*(input->TypeAsProto())); - auto& n_input = top_level_graph->GetOrCreateNodeArg(name, type_proto.get()); - context->manually_added_graph_inputs[n_input.Name()] = &n_input; -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph"; - } - } - } - } - } - } -} - -static void SetAllGraphInputs(Graph& graph, std::unordered_map>& subgraph_context_map) { - // If ORT TRT doesn't manully set graph input in TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(), - // Graph::Resolve() will help set graph inputs in Graph::SetGraphInputsOutputs(), so no need to set graph inputs here. 
- std::string unique_graph_name = GetUniqueGraphName(graph); - if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end() || - subgraph_context_map[unique_graph_name].get()->manually_added_graph_inputs.size() == 0) { - return; - } - - SubGraphContext2* context = subgraph_context_map[unique_graph_name].get(); - std::vector graph_inputs_including_initializers; - std::unordered_set graph_inputs_including_initializers_set; - - for (const auto& entry : context->inputs_and_initializers) { - graph_inputs_including_initializers.push_back(entry.second); - graph_inputs_including_initializers_set.insert(entry.first); - } - - for (const auto& entry : context->manually_added_graph_inputs) { - if (graph_inputs_including_initializers_set.find(entry.first) == graph_inputs_including_initializers_set.end()) { - graph_inputs_including_initializers.push_back(entry.second); - graph_inputs_including_initializers_set.insert(entry.first); - } - } - - for (const auto& node_arg : graph.GetInputsIncludingInitializers()) { - if (graph_inputs_including_initializers_set.find(node_arg->Name()) == graph_inputs_including_initializers_set.end()) { - graph_inputs_including_initializers.push_back(node_arg); - graph_inputs_including_initializers_set.insert(node_arg->Name()); - } - } - - graph.SetInputs(graph_inputs_including_initializers); -} - -ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) { - const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - // Get parent graph output names - std::unordered_set graph_output_names; - for (const auto* output_arg : graph_viewer->GetOutputs()) { - graph_output_names.insert(output_arg->Name()); - } - // TODO(leca): cannot use unique_ptr here, otherwise when this function exits, sub_graph_viewer->graph_->graph_proto_, which is from model_build->model_proto_, will be nullptr. - // Pay special attention when Graph object is releasing. We need to release model_build seperately then. 
- Model* model_build = new Model (graph_viewer->Name(), true, ModelMetaData(), PathString(), -#if !defined(ORT_MINIMAL_BUILD) - IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), -#else - IOnnxRuntimeOpSchemaRegistryList(), graph_viewer->DomainToVersionMap(), -#endif // ORT_MINIMAL_BUILD - std::vector(), graph_viewer->GetGraph().GetLogger()); - - auto& graph_build = model_build->MainGraph(); - bool has_control_flow_op = false; - - std::vector subgraph_output_names; - const std::vector& node_index = graph_viewer->GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); - for(int i = 0; i < node_num; i++) { - const auto& node = graph_viewer->GetNode(node_index[node_indices[i]]); - std::vector inputs, outputs; - for (auto input : node->InputDefs()) { - auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); - inputs.push_back(&n_input); - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } - } - for (auto input : node->ImplicitInputDefs()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } - } - for (auto output : node->OutputDefs()) { - auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); - outputs.push_back(&n_output); - const auto name = output->Name(); - if (graph_output_names.find(name) != graph_output_names.end()) { - subgraph_output_names.push_back(name); - } - } - - std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; - if (control_flow_op_set.find(node->OpType()) != control_flow_op_set.end()) { - has_control_flow_op = true; - } - - // If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization. - // Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto. - if (node->GetAttributes().size() > 0) { - auto node_proto = std::make_unique(); - // we need to update any GraphProto attributes for subgraphs so that any changes made by things - // such as the optimizers are captured. otherwise we can end up saving an invalid graph. - node->ToProto(*node_proto, /* update_subgraphs */ true); - const int num_attributes = node_proto->attribute_size(); - NodeAttributes node_attributes; - node_attributes.reserve(num_attributes); - - for (int i = 0; i < num_attributes; ++i) { - auto& attr = node_proto->attribute(i); - node_attributes.emplace(attr.name(), attr); - } - - // The GraphProto attributes are the updated ones. - graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node_attributes, node->Domain()); - } else { - // The GraphProto attributes are the original ones. 
- graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); - } - } - - // TODO:yang - // Only if the newly built graph has control flow op as well as it has parent node, - // it needs to handle outer scope values before calling graph.Resolve(). - // TODO(leca): Is local variable enough? Do we need to make it EP class variable? - std::unordered_map> subgraph_context_map; - if (has_control_flow_op && graph_viewer->ParentNode()) { - // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); - BuildSubGraphContext(graph_build, subgraph_context_map); - SetGraphOuterScopeValuesAndInputs(graph_build, graph_viewer->GetGraph(), subgraph_context_map); - SetAllGraphInputs(graph_build, subgraph_context_map); - } - - common::Status status = graph_build.Resolve(); - if (status != Status::OK()) return ToOrtStatus(status); - - // Add parent graph output to the subgraph - int i = 0; - std::vector subgraph_outputs; - subgraph_outputs.resize(subgraph_output_names.size()); - for (auto& name : subgraph_output_names) { - auto output_arg = graph_viewer->GetNodeArg(name); - auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - subgraph_outputs[i] = &subgraph_output_arg; - ++i; - } - auto& graph_build_outputs = graph_build.GetOutputs(); - subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); - graph_build.SetOutputs(graph_build_outputs); - status = graph_build.Resolve(); - if (status != Status::OK()) return ToOrtStatus(status); - - auto sub_graph_viewer = std::make_unique(graph_build); - *subgraph = reinterpret_cast(sub_graph_viewer.release()); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetName, const OrtNode* node, _Out_ const char** name) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *name = n->Name().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *description = n->Description().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *domain = n->Domain().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *since_version = n->SinceVersion(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *ep_type = n->GetExecutionProviderType().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *op_type = n->OpType().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *input_size = n->ImplicitInputDefs().size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - 
assert(i < n->ImplicitInputDefs().size()); - *ith_input_name = n->ImplicitInputDefs()[i]->Name().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *input_size = n->InputDefs().size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - assert(i < n->InputDefs().size()); - *ith_input_name = n->InputDefs()[i]->Name().c_str(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *output_size = n->OutputDefs().size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - assert(i < n->OutputDefs().size()); - if (n->OutputDefs()[i]->Exists()){ - *ith_output_name = n->OutputDefs()[i]->Name().c_str(); - } else { - *ith_output_name = nullptr; - } - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *index = n->Index(); - return nullptr; -} - -ORT_API(size_t, OrtApis::OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - size_t ret = n->GetAttributes().size(); - *names = new const char* [ret]; - int i = 0; - for (const auto& [k, v] : n->GetAttributes()) { - (*names)[i++] = k.c_str(); - } - return ret; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *attr_size = n->GetAttributes().size(); - return nullptr; -} - -ORT_API(int, OrtApis::OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return static_cast(n->GetAttributes().at(attribute).type()); -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *count = n->GetAttributes().count(key); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *int_size = n->GetAttributes().at(key).ints_size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *float_size = n->GetAttributes().at(key).floats_size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *str_size = n->GetAttributes().at(key).strings_size(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *ints = n->GetAttributes().at(key).ints(i); - return nullptr; -} - 
-ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *floats = n->GetAttributes().at(key).floats(i); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - *strs = n->GetAttributes().at(key).strings(i).c_str(); - return nullptr; -} - -ORT_API(const char*, OrtApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).s().c_str(); -} - -ORT_API(int64_t, OrtApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).i(); -} - -ORT_API(float, OrtApis::OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).f(); -} - -ORT_API_STATUS_IMPL(OrtApis::OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs) { - const ::onnxruntime::Node* n = reinterpret_cast(node); - std::vector> subg = n->GetSubgraphs(); - *len = subg.size(); - *subgraphs = new const OrtGraphViewer* [*len]; - for (size_t i = 0; i < subg.size(); i++) { - const ::onnxruntime::GraphViewer* graph_viewer = new const ::onnxruntime::GraphViewer(*subg[i]); - (*subgraphs)[i] = reinterpret_cast(graph_viewer); - } - return nullptr; -} - ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints) { KernelRegistry* kr = reinterpret_cast(kernel_registry); KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op, type_constraints); @@ -3540,55 +2826,6 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::RegisterOrtExecutionProviderLibrary, &OrtApis::SessionOptionsAppendOrtExecutionProvider, - &OrtApis::OrtGraph_GetName, - &OrtApis::OrtGraph_IsConstantInitializer, - &OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, - &OrtApis::OrtGraph_IsSubgraph, - &OrtApis::OrtGraph_GetParentGraph, - &OrtApis::OrtGraph_GetParenNode, - &OrtApis::OrtGraph_GetModelPath, - &OrtApis::OrtGraph_GetOrtGraph, - &OrtApis::OrtGraph_GetInputsIncludingInitializers, - &OrtApis::OrtGraph_GetOrtNode, - &OrtApis::OrtGraph_GetNodesConsumingInput, - &OrtApis::OrtGraph_GetNodeProducingOutput, - &OrtApis::OrtGraph_NumberOfNodes, - &OrtApis::OrtGraph_MaxNodeIndex, - &OrtApis::OrtGraph_GetOutputSize, - &OrtApis::OrtGraph_GetIthOutputName, - &OrtApis::OrtGraph_GetIthOutputElemType, - &OrtApis::OrtGraph_GetInitializerTensor, - &OrtApis::OrtGraph_GetValueInfo, - &OrtApis::OrtGraph_SerializeToArray, - &OrtApis::OrtGraph_DeserializeFromArray, - &OrtApis::OrtGraph_GetSubGraph, - &OrtApis::OrtNode_GetName, - &OrtApis::OrtNode_GetDescription, - &OrtApis::OrtNode_GetDomain, - &OrtApis::OrtNode_SinceVersion, - &OrtApis::OrtNode_GetExecutionProviderType, - &OrtApis::OrtNode_GetOpType, - &OrtApis::OrtNode_GetImplicitInputSize, - &OrtApis::OrtNode_GetIthImplicitInputName, - &OrtApis::OrtNode_GetInputSize, - &OrtApis::OrtNode_GetIthInputName, - &OrtApis::OrtNode_GetOutputSize, - &OrtApis::OrtNode_GetIthOutputName, - &OrtApis::OrtNode_GetIndex, - &OrtApis::OrtNode_GetAttributeNames, - &OrtApis::OrtNode_GetAttributeSize, - 
&OrtApis::OrtNode_GetAttributeType, - &OrtApis::OrtNode_GetAttributeKeyCount, - &OrtApis::OrtNode_GetAttributeIntSize, - &OrtApis::OrtNode_GetAttributeFloatSize, - &OrtApis::OrtNode_GetAttributeStringSize, - &OrtApis::OrtNode_GetAttributeIthInt, - &OrtApis::OrtNode_GetAttributeIthFloat, - &OrtApis::OrtNode_GetAttributeIthStr, - &OrtApis::OrtNode_GetAttributeStr, - &OrtApis::OrtNode_GetAttributeInt, - &OrtApis::OrtNode_GetAttributeFloat, - &OrtApis::OrtNode_GetSubgraphs, &OrtApis::OrtKernelRegistry_RegisterKernel, &OrtApis::CreateOrtTypeConstraints, &OrtApis::AddTypeConstraint, diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 5c89348188faf..9f4ca8093cdb3 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -1,15 +1,730 @@ #include "core/session/onnxruntime_c_api_ep.h" #include "ort_apis_ep.h" +#include "core/graph/graph_proto_serializer.h" #include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/murmurhash3.h" +#include "core/framework/session_options.h" +#include "core/framework/tensorprotoutils.h" +#include "core/session/ort_apis.h" -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out) { +using namespace onnxruntime; + +ORT_API(const char*, OrtGraphApis::OrtGraph_GetName, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->Name().c_str(); +} + +ORT_API(bool, OrtGraphApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->IsConstantInitializer(name, check_outer_scope); +} + +ORT_API(size_t, OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(static_cast(execution_order)); + *nodes_index_in_topological_order = nodes.data(); + return nodes.size(); +} + +ORT_API(bool, OrtGraphApis::OrtGraph_IsSubgraph, const OrtGraph* graph) { + const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); + return graph_ptr->IsSubgraph(); +} + +ORT_API(const OrtGraph*, OrtGraphApis::OrtGraph_GetParentGraph, const OrtGraph* graph) { + const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); + return reinterpret_cast(graph_ptr->ParentGraph()); +} + +ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return reinterpret_cast(graph_viewer->ParentNode()); +} + +ORT_API(const void*, OrtGraphApis::OrtGraph_GetModelPath, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return reinterpret_cast(&graph_viewer->ModelPath()); +} + +ORT_API(const OrtGraph*, OrtGraphApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer) { + const ::onnxruntime::GraphViewer* graph_viewer_ptr = reinterpret_cast(graph_viewer); + return reinterpret_cast(&graph_viewer_ptr->GetGraph()); +} + +ORT_API(size_t, OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const 
char*** input_names) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const auto& inputs = graph_viewer->GetInputsIncludingInitializers(); + size_t ret = inputs.size(); + *input_names = new const char*[ret]; + for (size_t i = 0; i < ret; i++) (*input_names)[i] = inputs[i]->Name().c_str(); + return ret; +} + +ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return reinterpret_cast(graph_viewer->GetNode(node_index)); +} + +ORT_API(size_t, OrtGraphApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + std::vector consumer_nodes = graph_viewer->GetConsumerNodes(input_name); + size_t ret = consumer_nodes.size(); + *consumers = new const OrtNode* [ret]; + for (size_t i = 0; i < ret; i++) (*consumers)[i] = reinterpret_cast(consumer_nodes[i]); + + return ret; +} + +ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return reinterpret_cast(graph_viewer->GetProducerNode(output_name)); +} + +ORT_API(int, OrtGraphApis::OrtGraph_NumberOfNodes, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->NumberOfNodes(); +} + +ORT_API(int, OrtGraphApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->MaxNodeIndex(); +} + +ORT_API(size_t, OrtGraphApis::OrtGraph_GetOutputSize, const OrtGraphViewer* graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->GetOutputs().size(); +} + +ORT_API(const char*, OrtGraphApis::OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - *out = graph_viewer->NumberOfNodes(); + return graph_viewer->GetOutputs()[i]->Name().c_str(); +} + +ORT_API(int32_t, OrtGraphApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* graph, size_t i) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type(); +} + +ORT_API(bool, OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const onnx::TensorProto* initializer = nullptr; + if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) return false; + *out = new OrtTensorRef(); // TODO(leca): release + (*out)->shape_len = initializer->dims_size(); + (*out)->shape = new int64_t [initializer->dims_size()]; + for (size_t i = 0; i < (*out)->shape_len; i++) { + ((*out)->shape)[i] = initializer->dims(i); + } + + (*out)->data_type = static_cast(initializer->data_type()); + // see utils::ConvertRawDataInTensorProto() + switch (initializer->data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + (*out)->data_len = initializer->float_data_size(); + (*out)->data = reinterpret_cast(initializer->float_data().data()); + break; + } + return true; +} + +static ONNXTensorElementDataType 
GetDataTypeFromTypeProto(const onnx::TypeProto* type) { // onnxruntime\core\optimizer\transpose_optimization\ort_optimizer_api_impl.cc + if (!type || !utils::HasTensorType(*type) || !utils::HasElementType(*type)) return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + return static_cast(type->tensor_type().elem_type()); +} + +ORT_API(bool, OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + const NodeArg* node_arg = graph_viewer->GetNodeArg(name); + + *out = new OrtValueInfoRef(); // TODO(leca): release + const onnx::TypeProto* type = node_arg->TypeAsProto(); + (*out)->data_type = GetDataTypeFromTypeProto(type); + const auto& dims = utils::TryGetShape(*type)->dim(); + (*out)->shape_len = dims.size(); + (*out)->shape = new int64_t [(*out)->shape_len]; + for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[i]) ? dims[i].dim_value() : -1; + + return true; +} + +ORT_API(size_t, OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), +#if defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaRegistryList(), +#else + IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), +#endif + graph_viewer->DomainToVersionMap(), std::vector(), graph_viewer->GetGraph().GetLogger()); + onnx::ModelProto model_proto = model.ToProto(); + GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED); + size_t ret = model_proto.ByteSizeLong(); + *data = malloc(ret); // TODO(leca): release + model_proto.SerializeToArray(*data, ret); + return ret; +} + +struct SubGraphContext2 { + std::unordered_set output_args; + std::unordered_map inputs_and_initializers; + std::unordered_map manually_added_graph_inputs; +}; + +static std::string GetUniqueGraphName(const Graph& graph) { + HashValue model_hash = 0; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Hash all nodes' name + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + hash_str(node->Name()); + } + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + return graph.Name() + "_" + std::to_string(model_hash); +} + +static bool IsLocalValue(const Graph& graph, + const std::string& name, + const std::unordered_map>& subgraph_context_map) { + std::string unique_graph_name = GetUniqueGraphName(graph); + if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) { + return false; + } + SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get(); + return context->output_args.find(name) != context->output_args.cend() || + context->inputs_and_initializers.find(name) != context->inputs_and_initializers.cend(); +} + +static bool IsInputInitializerOrOutput(const Graph& graph, + const std::string& name, + bool check_ancestors, + const std::unordered_map>& subgraph_context_map) { + const Graph* parent_graph = nullptr; + return IsLocalValue(graph, name, subgraph_context_map) || + (check_ancestors && (parent_graph = graph.ParentGraph()) != nullptr && + IsInputInitializerOrOutput(*parent_graph, name, 
check_ancestors, subgraph_context_map));
+}
+
+static bool IsOuterScopeValue(const Graph& graph,
+                              const std::string& name,
+                              const std::unordered_map<std::string, std::unique_ptr<SubGraphContext2>>& subgraph_context_map) {
+  const Graph* parent_graph = nullptr;
+  return (parent_graph = graph.ParentGraph()) != nullptr &&
+         IsInputInitializerOrOutput(*parent_graph, name, true, subgraph_context_map);
+}
+
+static void BuildSubGraphContext(const Graph& graph, std::unordered_map<std::string, std::unique_ptr<SubGraphContext2>>& subgraph_context_map) {
+  // Iterate all the nodes and recurse into the innermost subgraph first
+  for (int i = 0; i < graph.MaxNodeIndex(); ++i) {
+    auto node = graph.GetNode(i);
+    if (node == nullptr) {
+      continue;
+    }
+
+    auto subgraph_map = node->GetAttributeNameToSubgraphMap();
+    for (auto& entry : subgraph_map) {
+      const Graph* subgraph = entry.second;
+      BuildSubGraphContext(*subgraph, subgraph_context_map);
+    }
+  }
+
+  std::string unique_graph_name = GetUniqueGraphName(graph);
+
+  // Subgraph context has been built before, no need to do it again
+  if (subgraph_context_map.find(unique_graph_name) != subgraph_context_map.end()) {
+    return;
+  }
+
+  subgraph_context_map.emplace(unique_graph_name, std::make_unique<SubGraphContext2>());
+  SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get();
+
+  // Collect all nodes' outputs and nodes' names
+  for (int i = 0; i < graph.MaxNodeIndex(); ++i) {
+    auto node = graph.GetNode(i);
+    if (node == nullptr) {
+      continue;
+    }
+
+    for (const auto& output : node->OutputDefs()) {
+      context->output_args.insert(output->Name());
+    }
+  }
+
+  // Go through all nodes' inputs
+  for (int i = 0; i < graph.MaxNodeIndex(); ++i) {
+    auto node = graph.GetNode(i);
+    if (node == nullptr) {
+      continue;
+    }
+
+    for (const auto& input : node->InputDefs()) {
+      if (context->output_args.find(input->Name()) != context->output_args.end()) {
+        continue;
+      }
+      // This input arg is not the output of another node so must come from either a graph input or an initializer.
+      context->inputs_and_initializers[input->Name()] = input;
+    }
+  }
+}
+
+static void SetGraphOuterScopeValuesAndInputs(Graph& graph_build,
+                                              const Graph& graph,
+                                              std::unordered_map<std::string, std::unique_ptr<SubGraphContext2>>& subgraph_context_map) {
+  // Iterate all the nodes and recurse into the innermost subgraph first for both the newly built graph and the original graph
+  for (int i = 0; i < graph_build.MaxNodeIndex(); ++i) {
+    auto graph_build_node = graph_build.GetNode(i);
+    if (graph_build_node == nullptr) {
+      continue;
+    }
+
+    auto graph_build_map = graph_build_node->GetAttributeNameToMutableSubgraphMap();
+    std::unordered_map<std::string, gsl::not_null<const Graph*>> subgraph_map;
+    const Node* graph_node = nullptr;
+
+    // Find the corresponding original graph node's subgraphs
+    for (int j = 0; j < graph.MaxNodeIndex(); ++j) {
+      if (graph.GetNode(j) && graph.GetNode(j)->Name() == graph_build_node->Name()) {
+        graph_node = graph.GetNode(j);
+        subgraph_map = graph_node->GetAttributeNameToSubgraphMap();
+        break;
+      }
+    }
+
+    for (auto& entry : graph_build_map) {
+      auto attr_name = entry.first;
+      Graph* subgraph_build = entry.second;
+      if (subgraph_map.find(attr_name) != subgraph_map.end()) {
+        // recurse into subgraph
+        const Graph* subgraph = subgraph_map.at(attr_name);
+        SetGraphOuterScopeValuesAndInputs(*subgraph_build, *subgraph, subgraph_context_map);
+      }
+    }
+  }
+
+  // Start from the innermost subgraph and check whether its outer-scope values exist in the
+  // newly built graph. If not, we need to add those outer-scope values as explicit inputs to the top level
+  // of the newly built graph.
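+  // Only the top-level graph's SubGraphContext2 tracks manually added inputs, so walk up to the
+  // root of the newly built graph before consulting the map.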
+  if (graph_build.ParentNode()) {
+    auto top_level_graph = &graph_build;
+    while (top_level_graph->MutableParentGraph()) {
+      top_level_graph = top_level_graph->MutableParentGraph();
+    }
+    std::string unique_graph_name = GetUniqueGraphName(*top_level_graph);
+    if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end()) {
+      return;
+    }
+
+    SubGraphContext2* context = subgraph_context_map.at(unique_graph_name).get();
+
+    // Iterate all the implicit inputs to set outer-scope values for the newly built subgraph
+    for (const auto& input : graph.ParentNode()->ImplicitInputDefs()) {
+//      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name();
+
+      // The node arg in the parent node's implicit inputs could be used by the parent node's other subgraphs, for example
+      // the "If" op has two subgraphs. So we need to make sure that the node arg is used in the current subgraph only.
+      // (GetNodeArg searches for the specific node arg among all node args in the graph)
+      if (graph_build.GetNodeArg(input->Name())) {
+        graph_build.AddOuterScopeNodeArg(input->Name());
+//        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " is used in this subgraph";
+
+        if (context &&
+            (context->manually_added_graph_inputs.find(input->Name()) != context->manually_added_graph_inputs.end())) {
+//          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << input->Name() << " has already been added as an explicit input to the graph";
+          continue;
+        }
+
+        // Handle the case where this outer-scope value does not exist in any outer scope level of the
+        // newly built graph (the newly built graph is a subgraph of the original graph). We need to add
+        // the outer-scope value as an explicit input to the top level of the newly built graph.
+        if (!IsOuterScopeValue(graph_build, input->Name(), subgraph_context_map)) {
+          const auto& name = input->Name();
+          auto graph_inputs_including_initializers = top_level_graph->GetInputsIncludingInitializers();
+          auto added_graph_input = std::find_if(graph_inputs_including_initializers.begin(),
+                                                graph_inputs_including_initializers.end(),
+                                                [&name](const NodeArg* entry) { return entry->Name() == name; });
+
+          if (added_graph_input == graph_inputs_including_initializers.end()) {
+            if (context) {
+              auto type_proto = std::make_unique<ONNX_NAMESPACE::TypeProto>();
+              type_proto->CopyFrom(*(input->TypeAsProto()));
+              auto& n_input = top_level_graph->GetOrCreateNodeArg(name, type_proto.get());
+              context->manually_added_graph_inputs[n_input.Name()] = &n_input;
+//              LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph";
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+static void SetAllGraphInputs(Graph& graph, std::unordered_map<std::string, std::unique_ptr<SubGraphContext2>>& subgraph_context_map) {
+  // If ORT TRT doesn't manually set the graph inputs in TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(),
+  // Graph::Resolve() will help set graph inputs in Graph::SetGraphInputsOutputs(), so there is no need to set graph inputs here.
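+  // Inputs only need to be set manually when SetGraphOuterScopeValuesAndInputs() recorded
+  // manually_added_graph_inputs for this graph; otherwise we return early and let
+  // Graph::Resolve() derive the inputs.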
+  std::string unique_graph_name = GetUniqueGraphName(graph);
+  if (subgraph_context_map.find(unique_graph_name) == subgraph_context_map.end() ||
+      subgraph_context_map[unique_graph_name].get()->manually_added_graph_inputs.size() == 0) {
+    return;
+  }
+
+  SubGraphContext2* context = subgraph_context_map[unique_graph_name].get();
+  std::vector<const NodeArg*> graph_inputs_including_initializers;
+  std::unordered_set<std::string> graph_inputs_including_initializers_set;
+
+  for (const auto& entry : context->inputs_and_initializers) {
+    graph_inputs_including_initializers.push_back(entry.second);
+    graph_inputs_including_initializers_set.insert(entry.first);
+  }
+
+  for (const auto& entry : context->manually_added_graph_inputs) {
+    if (graph_inputs_including_initializers_set.find(entry.first) == graph_inputs_including_initializers_set.end()) {
+      graph_inputs_including_initializers.push_back(entry.second);
+      graph_inputs_including_initializers_set.insert(entry.first);
+    }
+  }
+
+  for (const auto& node_arg : graph.GetInputsIncludingInitializers()) {
+    if (graph_inputs_including_initializers_set.find(node_arg->Name()) == graph_inputs_including_initializers_set.end()) {
+      graph_inputs_including_initializers.push_back(node_arg);
+      graph_inputs_including_initializers_set.insert(node_arg->Name());
+    }
+  }
+
+  graph.SetInputs(graph_inputs_including_initializers);
+}
+
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  // Get parent graph output names
+  std::unordered_set<std::string> graph_output_names;
+  for (const auto* output_arg : graph_viewer->GetOutputs()) {
+    graph_output_names.insert(output_arg->Name());
+  }
+  // TODO(leca): We cannot use a unique_ptr here, otherwise when this function exits, sub_graph_viewer->graph_->graph_proto_, which comes from model_build->model_proto_, will be a nullptr.
+  // Pay special attention when the Graph object is being released: model_build then needs to be released separately.
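+  // Editorial sketch of the lifetime issue described above (simplified, names hypothetical):
+  // a GraphViewer only references the Graph owned by its Model, so
+  //
+  //   auto model = std::make_unique<Model>(/*...*/);                  // owns MainGraph() and its GraphProto
+  //   auto viewer = std::make_unique<GraphViewer>(model->MainGraph());
+  //   return viewer;                                                  // model destroyed here -> viewer dangles
+  //
+  // would hand the caller a dangling viewer. Hence the raw `new` below: whoever releases the
+  // returned OrtGraphViewer must release the backing Model as well.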
+  Model* model_build = new Model(graph_viewer->Name(), true, ModelMetaData(), PathString(),
+#if !defined(ORT_MINIMAL_BUILD)
+                                 IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(),
+#else
+                                 IOnnxRuntimeOpSchemaRegistryList(), graph_viewer->DomainToVersionMap(),
+#endif  // ORT_MINIMAL_BUILD
+                                 std::vector<ONNX_NAMESPACE::FunctionProto>(), graph_viewer->GetGraph().GetLogger());
+
+  auto& graph_build = model_build->MainGraph();
+  bool has_control_flow_op = false;
+
+  std::vector<std::string> subgraph_output_names;
+  const std::vector<NodeIndex>& node_index = graph_viewer->GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
+  for (int i = 0; i < node_num; i++) {
+    const auto& node = graph_viewer->GetNode(node_index[node_indices[i]]);
+    std::vector<NodeArg*> inputs, outputs;
+    for (auto input : node->InputDefs()) {
+      auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
+      inputs.push_back(&n_input);
+      const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
+      if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) {
+        const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
+        if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) {
+          graph_build.AddInitializedTensor(*(initializer));
+        }
+      }
+    }
+    for (auto input : node->ImplicitInputDefs()) {
+      const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
+      if (graph_viewer->GetInitializedTensor(input->Name(), initializer)) {
+        const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
+        if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) {
+          graph_build.AddInitializedTensor(*(initializer));
+        }
+      }
+    }
+    for (auto output : node->OutputDefs()) {
+      auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
+      outputs.push_back(&n_output);
+      const auto name = output->Name();
+      if (graph_output_names.find(name) != graph_output_names.end()) {
+        subgraph_output_names.push_back(name);
+      }
+    }
+
+    std::unordered_set<std::string> control_flow_op_set = {"If", "Loop", "Scan"};
+    if (control_flow_op_set.find(node->OpType()) != control_flow_op_set.end()) {
+      has_control_flow_op = true;
+    }
+
+    // If the node has subgraphs, it's possible that the ORT graph of a subgraph and the GraphProto in the node attributes are out of sync because of graph optimizations.
+    // Therefore, we need to force the GraphProto attributes to be updated in order to get a valid GraphProto.
+    if (node->GetAttributes().size() > 0) {
+      auto node_proto = std::make_unique<ONNX_NAMESPACE::NodeProto>();
+      // We need to update any GraphProto attributes for subgraphs so that any changes made by things
+      // such as the optimizers are captured. Otherwise we can end up saving an invalid graph.
+      node->ToProto(*node_proto, /* update_subgraphs */ true);
+      const int num_attributes = node_proto->attribute_size();
+      NodeAttributes node_attributes;
+      node_attributes.reserve(num_attributes);
+
+      for (int i = 0; i < num_attributes; ++i) {
+        auto& attr = node_proto->attribute(i);
+        node_attributes.emplace(attr.name(), attr);
+      }
+
+      // The GraphProto attributes are the updated ones.
+      graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node_attributes, node->Domain());
+    } else {
+      // The GraphProto attributes are the original ones.
+      graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
+    }
+  }
+
+  // TODO:yang
+  // Outer scope values only need to be handled before calling graph.Resolve() when the newly built
+  // graph contains a control flow op and also has a parent node.
+  // TODO(leca): Is a local variable enough, or do we need to make it an EP class variable?
+  std::unordered_map<std::string, std::unique_ptr<SubGraphContext2>> subgraph_context_map;
+  if (has_control_flow_op && graph_viewer->ParentNode()) {
+//  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name();
+    BuildSubGraphContext(graph_build, subgraph_context_map);
+    SetGraphOuterScopeValuesAndInputs(graph_build, graph_viewer->GetGraph(), subgraph_context_map);
+    SetAllGraphInputs(graph_build, subgraph_context_map);
+  }
+
+  common::Status status = graph_build.Resolve();
+  if (status != Status::OK()) return onnxruntime::ToOrtStatus(status);
+
+  // Add the parent graph outputs to the subgraph
+  int i = 0;
+  std::vector<const NodeArg*> subgraph_outputs;
+  subgraph_outputs.resize(subgraph_output_names.size());
+  for (auto& name : subgraph_output_names) {
+    auto output_arg = graph_viewer->GetNodeArg(name);
+    auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
+    subgraph_outputs[i] = &subgraph_output_arg;
+    ++i;
+  }
+  auto& graph_build_outputs = graph_build.GetOutputs();
+  subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
+  graph_build.SetOutputs(graph_build_outputs);
+  status = graph_build.Resolve();
+  if (status != Status::OK()) return onnxruntime::ToOrtStatus(status);
+
+  auto sub_graph_viewer = std::make_unique<::onnxruntime::GraphViewer>(graph_build);
+  *subgraph = reinterpret_cast<const OrtGraphViewer*>(sub_graph_viewer.release());
+  return nullptr;
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetName, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->Name().c_str();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetDescription, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->Description().c_str();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetDomain, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->Domain().c_str();
+}
+
+ORT_API(int, OrtGraphApis::OrtNode_SinceVersion, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->SinceVersion();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetExecutionProviderType, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetExecutionProviderType().c_str();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetOpType, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->OpType().c_str();
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetImplicitInputSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->ImplicitInputDefs().size();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->ImplicitInputDefs().size());
+  return n->ImplicitInputDefs()[i]->Name().c_str();
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetInputSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->InputDefs().size();
+}
+
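+// Editorial note (not part of the original change): the const char* getters below return pointers
+// into std::string objects owned by the underlying onnxruntime::Node, so the returned strings are
+// only valid while the graph that owns the node is alive and unmodified.
+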
+ORT_API(const char*, OrtGraphApis::OrtNode_GetIthInputName, const OrtNode* node, size_t i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->InputDefs().size());
+  return n->InputDefs()[i]->Name().c_str();
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetOutputSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->OutputDefs().size();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  assert(i < n->OutputDefs().size());
+  if (n->OutputDefs()[i]->Exists()) return n->OutputDefs()[i]->Name().c_str();
+  return nullptr;
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetIndex, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->Index();
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  size_t ret = n->GetAttributes().size();
+  *names = new const char* [ret];
+  int i = 0;
+  for (const auto& [k, v] : n->GetAttributes()) {
+    (*names)[i++] = k.c_str();
+  }
+  return ret;
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeSize, const OrtNode* node) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().size();
+}
+
+ORT_API(int, OrtGraphApis::OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return static_cast<int>(n->GetAttributes().at(attribute).type());
+}
+
+ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().count(key);
+}
+
+ORT_API(int, OrtGraphApis::OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).ints_size();
+}
+
+ORT_API(int, OrtGraphApis::OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).floats_size();
+}
+
+ORT_API(int, OrtGraphApis::OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).strings_size();
+}
+
+ORT_API(int64_t, OrtGraphApis::OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).ints(i);
+}
+
+ORT_API(float, OrtGraphApis::OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).floats(i);
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).strings(i).c_str();
+}
+
+ORT_API(const char*, OrtGraphApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).s().c_str();
+}
+
+ORT_API(int64_t, OrtGraphApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key) {
+  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
+  return n->GetAttributes().at(key).i();
+}
+
+ORT_API(float, OrtGraphApis::OrtNode_GetAttributeFloat,
const OrtNode* node, const char* key) { + const ::onnxruntime::Node* n = reinterpret_cast(node); + return n->GetAttributes().at(key).f(); +} + +ORT_API(size_t, OrtGraphApis::OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs) { + const ::onnxruntime::Node* n = reinterpret_cast(node); + std::vector> subg = n->GetSubgraphs(); + size_t ret = subg.size(); + *subgraphs = new const OrtGraphViewer* [ret]; + for (size_t i = 0; i < ret; i++) { + const ::onnxruntime::GraphViewer* graph_viewer = new const ::onnxruntime::GraphViewer(*subg[i]); + (*subgraphs)[i] = reinterpret_cast(graph_viewer); + } + return ret; +} + static constexpr OrtGraphApi ort_graph_api = { - &OrtGraphApis::OrtGraph_PlaceHolder, + &OrtGraphApis::OrtGraph_GetName, + &OrtGraphApis::OrtGraph_IsConstantInitializer, + &OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, + &OrtGraphApis::OrtGraph_IsSubgraph, + &OrtGraphApis::OrtGraph_GetParentGraph, + &OrtGraphApis::OrtGraph_GetParenNode, + &OrtGraphApis::OrtGraph_GetModelPath, + &OrtGraphApis::OrtGraph_GetOrtGraph, + &OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, + &OrtGraphApis::OrtGraph_GetOrtNode, + &OrtGraphApis::OrtGraph_GetNodesConsumingInput, + &OrtGraphApis::OrtGraph_GetNodeProducingOutput, + &OrtGraphApis::OrtGraph_NumberOfNodes, + &OrtGraphApis::OrtGraph_MaxNodeIndex, + &OrtGraphApis::OrtGraph_GetOutputSize, + &OrtGraphApis::OrtGraph_GetIthOutputName, + &OrtGraphApis::OrtGraph_GetIthOutputElemType, + &OrtGraphApis::OrtGraph_GetInitializerTensor, + &OrtGraphApis::OrtGraph_GetValueInfo, + &OrtGraphApis::OrtGraph_SerializeToArray, + &OrtGraphApis::OrtGraph_GetSubGraph, + &OrtGraphApis::OrtNode_GetName, + &OrtGraphApis::OrtNode_GetDescription, + &OrtGraphApis::OrtNode_GetDomain, + &OrtGraphApis::OrtNode_SinceVersion, + &OrtGraphApis::OrtNode_GetExecutionProviderType, + &OrtGraphApis::OrtNode_GetOpType, + &OrtGraphApis::OrtNode_GetImplicitInputSize, + &OrtGraphApis::OrtNode_GetIthImplicitInputName, + &OrtGraphApis::OrtNode_GetInputSize, + &OrtGraphApis::OrtNode_GetIthInputName, + &OrtGraphApis::OrtNode_GetOutputSize, + &OrtGraphApis::OrtNode_GetIthOutputName, + &OrtGraphApis::OrtNode_GetIndex, + &OrtGraphApis::OrtNode_GetAttributeNames, + &OrtGraphApis::OrtNode_GetAttributeSize, + &OrtGraphApis::OrtNode_GetAttributeType, + &OrtGraphApis::OrtNode_GetAttributeKeyCount, + &OrtGraphApis::OrtNode_GetAttributeIntSize, + &OrtGraphApis::OrtNode_GetAttributeFloatSize, + &OrtGraphApis::OrtNode_GetAttributeStringSize, + &OrtGraphApis::OrtNode_GetAttributeIthInt, + &OrtGraphApis::OrtNode_GetAttributeIthFloat, + &OrtGraphApis::OrtNode_GetAttributeIthStr, + &OrtGraphApis::OrtNode_GetAttributeStr, + &OrtGraphApis::OrtNode_GetAttributeInt, + &OrtGraphApis::OrtNode_GetAttributeFloat, + &OrtGraphApis::OrtNode_GetSubgraphs, }; ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi, uint32_t) { diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index bf149d4daca4d..5ce145daf3fe9 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -539,104 +539,6 @@ ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* l ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); -ORT_API(const char*, OrtGraph_GetName, const OrtGraphViewer*) 
ORT_ALL_ARGS_NONNULL; - -ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); - -ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); - -ORT_API_STATUS_IMPL(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); - -ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); - -ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); - -ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); - -ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); - -ORT_API_STATUS_IMPL(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names); - -ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); - -ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers); - -ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer); - -ORT_API(int, OrtGraph_NumberOfNodes, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; - -ORT_API_STATUS_IMPL(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out); - -ORT_API(size_t, OrtGraph_GetOutputSize, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; - -ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; - -ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; - -ORT_API(bool, OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); - -ORT_API(bool, OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); - -ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** data); - -ORT_API_STATUS_IMPL(OrtGraph_DeserializeFromArray, const void* data, size_t len, _Outptr_ OrtGraphViewer**); - -ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); - -ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); - -ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); - -ORT_API_STATUS_IMPL(OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain); - -ORT_API_STATUS_IMPL(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version); - -ORT_API_STATUS_IMPL(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type); - -ORT_API_STATUS_IMPL(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); - -ORT_API_STATUS_IMPL(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size); - -ORT_API_STATUS_IMPL(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); - -ORT_API_STATUS_IMPL(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); - -ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const 
char** ith_input_name); - -ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); - -ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); - -ORT_API_STATUS_IMPL(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index); - -ORT_API(size_t, OrtNode_GetAttributeNames, const OrtNode* node, const char*** names); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size); - -ORT_API(int, OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) ORT_ALL_ARGS_NONNULL; - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats); - -ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs); - -ORT_API(const char*, OrtNode_GetAttributeStr, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; - -ORT_API(int64_t, OrtNode_GetAttributeInt, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; - -ORT_API(float, OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; - -ORT_API_STATUS_IMPL(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs); - ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index ef0af223504f5..d82b5e9742e43 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -2,5 +2,101 @@ namespace OrtGraphApis { ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); -ORT_API_STATUS_IMPL(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); + +ORT_API(const char*, OrtGraph_GetName, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; + +ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope)ORT_ALL_ARGS_NONNULL; + +ORT_API(size_t, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order); + +ORT_API(bool, OrtGraph_IsSubgraph, const OrtGraph* graph); + +ORT_API(const OrtGraph*, OrtGraph_GetParentGraph, const OrtGraph* graph); + +ORT_API(const OrtNode*, OrtGraph_GetParenNode, const OrtGraphViewer* graph); + +ORT_API(const void*, OrtGraph_GetModelPath, const OrtGraphViewer* graph); + +ORT_API(const OrtGraph*, OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer); + +ORT_API(size_t, OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names); + +ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index); + 
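+// Usage sketch (editorial illustration, not part of this header): an out-of-tree EP typically
+// fetches this API table once and then walks the graph with the getters declared here, e.g.
+//
+//   const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+//   const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION);
+//   const size_t* node_index = nullptr;
+//   size_t node_count = graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &node_index);
+//   for (size_t i = 0; i < node_count; i++) {
+//     const OrtNode* node = graph_api->OrtGraph_GetOrtNode(graph, node_index[i]);
+//     const char* op_type = graph_api->OrtNode_GetOpType(node);
+//     // decide whether the EP can claim this node...
+//   }
+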
+ORT_API(size_t, OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); + +ORT_API(const OrtNode*, OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name); + +ORT_API(int, OrtGraph_NumberOfNodes, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; + +ORT_API(int, OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph); + +ORT_API(size_t, OrtGraph_GetOutputSize, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; + +ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; + +ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; + +ORT_API(bool, OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); + +ORT_API(bool, OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); + +ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** data); + +ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); + +ORT_API(const char*, OrtNode_GetName, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetDescription, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetDomain, const OrtNode* node); + +ORT_API(int, OrtNode_SinceVersion, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetExecutionProviderType, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node); + +ORT_API(size_t, OrtNode_GetImplicitInputSize, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i); + +ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i); + +ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node); + +ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i); + +ORT_API(size_t, OrtNode_GetIndex, const OrtNode* node); + +ORT_API(size_t, OrtNode_GetAttributeNames, const OrtNode* node, const char*** names); + +ORT_API(size_t, OrtNode_GetAttributeSize, const OrtNode* node); + +ORT_API(int, OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) ORT_ALL_ARGS_NONNULL; + +ORT_API(size_t, OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key); + +ORT_API(int, OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key); + +ORT_API(int, OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key); + +ORT_API(int, OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key); + +ORT_API(int64_t, OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i); + +ORT_API(float, OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i); + +ORT_API(const char*, OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i); + +ORT_API(const char*, OrtNode_GetAttributeStr, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; + +ORT_API(int64_t, OrtNode_GetAttributeInt, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; + +ORT_API(float, OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; + +ORT_API(size_t, OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); + } diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 
188b276d12a6d..4b15cdcb88351 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -13,6 +13,7 @@ #include "core/common/logging/logging.h" #include "core/framework/provider_shutdown.h" #include "core/platform/logging/make_platform_default_log_sink.h" +#include "core/session/onnxruntime_c_api_ep.h" using namespace onnxruntime; using namespace onnxruntime::logging; diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 31a31b21ef54c..dd0a87c44515d 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -13,6 +13,7 @@ namespace onnxruntime { class Environment; } +struct OrtExecutionProviderFactory; struct OrtEnv { public: diff --git a/samples/c_test/sanityTests.sh b/samples/c_test/sanityTests.sh new file mode 100755 index 0000000000000..020d49489df96 --- /dev/null +++ b/samples/c_test/sanityTests.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +echo 'Compile based EP, relu:' +./TestOutTreeEp c relu + +echo 'Kernel based EP, relu:' +./TestOutTreeEp k relu + +echo 'TRT EP, relu:' +./TestOutTreeEp t relu + +echo 'out tree TRT + In tree cuda, relu:' +./TestOutTreeEp tc relu + +echo 'out tree TRT + In tree cuda, resnet:' +./TestOutTreeEp tc resnet + +echo 'out tree TRT + In tree cuda, fast rcnn:' +./TestOutTreeEp tc rcnn + +echo 'out tree TRT + In tree cuda, tiny yolov3:' +./TestOutTreeEp tc tyolo + +echo 'out tree TRT + In tree cuda, yolov3:' +./TestOutTreeEp tc yolo + +echo 'out tree TRT + In tree cuda, control flow:' +./TestOutTreeEp tc cf diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index b6f676d50daef..34f4fce0b0102 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -1,4 +1,4 @@ -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include #include diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index d9643a6f88309..679babd03874b 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -8,15 +8,13 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe type = ep_type; OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + const OrtGraphApi* ort_graph_api = api->GetGraphApi(ORT_API_VERSION); std::vector cache; - size_t nodes_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &nodes_count, &nodes_index); + size_t nodes_count = ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &nodes_index); for (size_t i = 0; i < nodes_count; i++) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph, nodes_index[i], &node); - const char* node_op_type; - api->OrtNode_GetOpType(node, &node_op_type); + const OrtNode* node = ort_graph_api->OrtGraph_GetOrtNode(graph, nodes_index[i]); + const char* node_op_type = ort_graph_api->OrtNode_GetOpType(node); if (!strcmp(node_op_type, "Relu")) { OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); subgraph->node_index_len = 1; @@ -26,20 +24,17 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe subgraph->meta_def = new OrtMetaDef(); subgraph->meta_def->name = "Relu_subgraph"; subgraph->meta_def->input_len = 0; - api->OrtNode_GetInputSize(node, &(subgraph->meta_def->input_len)); + subgraph->meta_def->input_len = 
ort_graph_api->OrtNode_GetInputSize(node); subgraph->meta_def->inputs = new char* [subgraph->meta_def->input_len]; for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { - const char* input_j = nullptr; - api->OrtNode_GetIthInputName(node, j, &input_j); + const char* input_j = ort_graph_api->OrtNode_GetIthInputName(node, j); subgraph->meta_def->inputs[j] = const_cast(input_j); } - subgraph->meta_def->output_len = 0; - api->OrtNode_GetOutputSize(node, &(subgraph->meta_def->output_len)); + subgraph->meta_def->output_len = ort_graph_api->OrtNode_GetOutputSize(node); subgraph->meta_def->outputs = new char* [subgraph->meta_def->output_len]; for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { - const char* output_j = nullptr; - api->OrtNode_GetIthOutputName(node, j, &output_j); + const char* output_j = ort_graph_api->OrtNode_GetIthOutputName(node, j); subgraph->meta_def->outputs[j] = const_cast(output_j); } diff --git a/samples/outTreeEp/out_tree_ep.h b/samples/outTreeEp/out_tree_ep.h index cd4b49cabd2c6..cd09625a40cca 100644 --- a/samples/outTreeEp/out_tree_ep.h +++ b/samples/outTreeEp/out_tree_ep.h @@ -1,5 +1,5 @@ #pragma once -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include #ifdef _WIN32 diff --git a/samples/outTreeEp_kernel/kernel_ep.h b/samples/outTreeEp_kernel/kernel_ep.h index 85c0bec6c302e..734b0cd9a6b98 100644 --- a/samples/outTreeEp_kernel/kernel_ep.h +++ b/samples/outTreeEp_kernel/kernel_ep.h @@ -1,5 +1,5 @@ #pragma once -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include #ifdef _WIN32 diff --git a/samples/qnnEp/qnn_execution_provider.h b/samples/qnnEp/qnn_execution_provider.h index 91fa97158de0b..083d7bb72c733 100644 --- a/samples/qnnEp/qnn_execution_provider.h +++ b/samples/qnnEp/qnn_execution_provider.h @@ -1,5 +1,5 @@ #pragma once -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/provider_options.h" #include diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index afb88345675cb..66d99faa51b09 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -8,16 +8,14 @@ namespace onnxruntime { bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - int maxNodeIndex = 0; - api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex); + const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); + int maxNodeIndex = graph_api->OrtGraph_MaxNodeIndex(graph_viewer); for (int i = 0; i < maxNodeIndex; ++i) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph_viewer, i, &node); + const OrtNode* node = graph_api->OrtGraph_GetOrtNode(graph_viewer, i); if (node == nullptr) { continue; } - const char* opType = nullptr; - api->OrtNode_GetOpType(node, &opType); + const char* opType = graph_api->OrtNode_GetOpType(node); if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { return true; } @@ -118,13 +116,12 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView if (!ValidateEPCtxNode(graph_viewer)) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node"); } - const OrtNode* node = nullptr; - api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0); - const int64_t embed_mode = api_->OrtNode_GetAttributeInt(node, 
EMBED_MODE.c_str()); + const int64_t embed_mode = graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str()); if (embed_mode) { // Get engine from byte stream. - const std::string& context_binary(api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); + const std::string& context_binary(graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; @@ -133,7 +130,7 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView } } else { // Get engine from cache file. - std::string cache_path(api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); + std::string cache_path(graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); // For security purpose, in the case of running context model, TRT EP won't allow // engine cache path to be the relative path like "../file_path" or the absolute path. @@ -185,7 +182,7 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); if (weight_stripped_engine_refit_) { - const std::string onnx_model_filename(api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str())); + const std::string onnx_model_filename(graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str())); std::string weight_stripped_engine_cache = engine_cache_path.string(); auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, @@ -203,18 +200,15 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView } bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) { - assert(api_->OrtGraph_NumberOfNodes(graph_viewer) == 1); - const OrtNode* node = nullptr; - api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); - const char* opType = nullptr; - api_->OrtNode_GetOpType(node, &opType); + assert(graph_api_->OrtGraph_NumberOfNodes(graph_viewer) == 1); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0); + const char* opType = graph_api_->OrtNode_GetOpType(node); assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0); - size_t key_count = 0; - api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count); + size_t key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str()); // Show the warning if compute capability is not matched if (key_count > 0) { - const char* model_compute_capability = api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str()); + const char* model_compute_capability = graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str()); // Verify if engine was compiled with ampere+ hardware compatibility enabled if (strcmp(model_compute_capability, "80+") == 0) { // if (std::stoi(compute_capability_) < 80) { @@ -228,12 +222,12 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_vi } // "embed_mode" attr and "ep_cache_context" attr should be present - api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count); + key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str()); assert(key_count > 0); - api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count); + 
key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str()); assert(key_count > 0); - const int64_t embed_mode = api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str()); + const int64_t embed_mode = graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str()); if (embed_mode == 1) { // engine binary data // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.h b/samples/tensorRTEp/onnx_ctx_model_helper.h index c90574ebd4bae..60dcea9164b0a 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.h +++ b/samples/tensorRTEp/onnx_ctx_model_helper.h @@ -6,7 +6,7 @@ #include #include #include -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "nv_includes.h" namespace onnxruntime { @@ -46,6 +46,7 @@ class TensorRTCacheModelHandler { onnx_model_folder_path_(onnx_model_folder_path), detailed_build_log_(detailed_build_log) { api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); + graph_api_ = api_->GetGraphApi(ORT_API_VERSION); } bool ValidateEPCtxNode(const OrtGraphViewer* graph_viewer); @@ -60,5 +61,6 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path_; bool detailed_build_log_; const OrtApi* api_; + const OrtGraphApi* graph_api_; }; // TRTCacheModelHandler } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index fcdfd494f081c..6e6b5ce6b0d7a 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -27,13 +27,13 @@ void CUDA_RETURN_IF_ERROR(cudaError_t res) { if (res != cudaSuccess) abort(); } namespace onnxruntime { -//static const std::string +static const std::string tensorrtEp = "tensorrtEp"; struct MemcpyFromHost : OrtCustomOp { MemcpyFromHost() { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const struct OrtCustomOp* op) { return "MemcpyFromHost"; }; - OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return "tensorrtEp"; }; + OrtCustomOp::GetExecutionProviderType = [](const struct OrtCustomOp* op) { return tensorrtEp.c_str(); }; OrtCustomOp::CreateKernelV2 = [](const struct OrtCustomOp* op, const OrtApi* api, const OrtKernelInfo* info, void** kernel) -> OrtStatusPtr { return nullptr; }; @@ -76,6 +76,7 @@ struct MemcpyFromHost : OrtCustomOp { template using IAllocatorUniquePtr = std::unique_ptr>; const OrtApi* TensorrtExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +const OrtGraphApi* TensorrtExecutionProvider::graph_api_ = TensorrtExecutionProvider::api_->GetGraphApi(ORT_API_VERSION); // Check if cycle exists in the graph after partitioning bool FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { @@ -972,10 +973,8 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, // Detect and remove cycles from supported node list bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles) const { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - size_t node_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_count, &nodes_index); + size_t node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); bool trt_cycle = true, cycle_detected = false; while (trt_cycle) { trt_cycle = false; @@ -1019,37 
+1018,29 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Add non TensorRT nodes to the maps for (const auto& index : non_trt_node_index) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph, index, &node); - const char* node_name_char = nullptr; - api->OrtNode_GetName(node, &node_name_char); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, index); + const char* node_name_char = graph_api_->OrtNode_GetName(node); const std::string node_name(node_name_char); if (node_to_index_map.find(node_name) == node_to_index_map.end()) { index_to_node_map[id] = node_name; node_to_index_map[node_name] = id++; } - size_t input_count = 0; - api->OrtNode_GetInputSize(node, &input_count); + size_t input_count = graph_api_->OrtNode_GetInputSize(node); for (size_t i = 0; i < input_count; ++i) { - const char* input_name_char = nullptr; - api->OrtNode_GetIthInputName(node, i, &input_name_char); + const char* input_name_char = graph_api_->OrtNode_GetIthInputName(node, i); input_to_nodes_map[std::string(input_name_char)].insert(node_name); } - size_t implicit_input_count = 0; - api->OrtNode_GetImplicitInputSize(node, &implicit_input_count); + size_t implicit_input_count = graph_api_->OrtNode_GetImplicitInputSize(node); for (size_t i = 0; i < implicit_input_count; ++i) { - const char* input_name_char = nullptr; - api->OrtNode_GetIthImplicitInputName(node, i, &input_name_char); + const char* input_name_char = graph_api_->OrtNode_GetIthImplicitInputName(node, i); input_to_nodes_map[std::string(input_name_char)].insert(node_name); } - size_t output_count = 0; - api->OrtNode_GetOutputSize(node, &output_count); + size_t output_count = graph_api_->OrtNode_GetOutputSize(node); for (size_t i = 0; i < output_count; ++i) { - const char* output_name_char = nullptr; - api->OrtNode_GetIthOutputName(node, i, &output_name_char); + const char* output_name_char = graph_api_->OrtNode_GetIthOutputName(node, i); node_to_outputs_map[node_name].insert(std::string(output_name_char)); } } @@ -1109,16 +1100,11 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Check the graph is the subgraph of control flow op bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - const OrtGraph* cur_graph = nullptr; - api->OrtGraph_GetOrtGraph(graph, &cur_graph); - bool is_subgraph = false; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + const OrtGraph* cur_graph = graph_api_->OrtGraph_GetOrtGraph(graph); + bool is_subgraph = graph_api_->OrtGraph_IsSubgraph(cur_graph); if (is_subgraph) { - const OrtNode* node = nullptr; - api->OrtGraph_GetParenNode(graph, &node); - const char* node_op_type; - api->OrtNode_GetOpType(node, &node_op_type); + const OrtNode* node = graph_api_->OrtGraph_GetParenNode(graph); + const char* node_op_type = graph_api_->OrtNode_GetOpType(node); if (control_flow_op_set_.find(std::string(node_op_type)) != control_flow_op_set_.end()) { return true; } @@ -1128,18 +1114,14 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* // Check whether all the nodes of the graph are assigned to specific ep bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph); + const int number_of_ort_nodes = 
graph_api_->OrtGraph_NumberOfNodes(graph); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - size_t node_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_count, &nodes_index); + size_t node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); for (const auto& index : nodes_vector) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph, nodes_index[index], &node); - const char* node_ep_type; - api->OrtNode_GetExecutionProviderType(node, &node_ep_type); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index]); + const char* node_ep_type = graph_api_->OrtNode_GetExecutionProviderType(node); if (strcmp(node_ep_type, provider_type.c_str())) { return false; } @@ -1160,9 +1142,8 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t su } std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const { - size_t nodes_count = 0; const size_t* node_index = nullptr; - api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &node_index); + size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index); std::unordered_set node_set; node_set.reserve(graph_nodes_index.first.size()); for (const auto& index : graph_nodes_index.first) { @@ -1171,9 +1152,9 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Get parent graph output names std::unordered_set graph_output_names; - size_t graph_output_size = api_->OrtGraph_GetOutputSize(graph); + size_t graph_output_size = graph_api_->OrtGraph_GetOutputSize(graph); for (size_t i = 0; i < graph_output_size; i++) { - graph_output_names.insert(api_->OrtGraph_GetIthOutputName(graph, i)); + graph_output_names.insert(graph_api_->OrtGraph_GetIthOutputName(graph, i)); } // Find inputs and outputs of the subgraph @@ -1191,76 +1172,59 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr int i = 0; for (const auto& index : graph_nodes_index.first) { sub_graph->node_index[i++] = node_index[index]; - const OrtNode* node = nullptr; - api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); - size_t input_size = 0; - api_->OrtNode_GetInputSize(node, &input_size); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, node_index[index]); + size_t input_size = graph_api_->OrtNode_GetInputSize(node); for (size_t j = 0; j < input_size; j++) { - const char* input_name = nullptr; - api_->OrtNode_GetIthInputName(node, j, &input_name); - bool is_constant_initializer = false; - api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_constant_initializer); - if (is_constant_initializer) { + const char* input_name = graph_api_->OrtNode_GetIthInputName(node, j); + if (graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true)) { initializers.push_back(input_name); continue; } - const OrtNode* producer = nullptr; - api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); + const OrtNode* producer = graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { input_to_order[input_name] = input_order++; continue; } - size_t producer_index = 0; - api_->OrtNode_GetIndex(producer, &producer_index); + size_t producer_index = graph_api_->OrtNode_GetIndex(producer); 
// If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { input_to_order[input_name] = input_order++; } } - size_t implicit_input_size = 0; - api_->OrtNode_GetImplicitInputSize(node, &implicit_input_size); + size_t implicit_input_size = graph_api_->OrtNode_GetImplicitInputSize(node); for (size_t j = 0; j < implicit_input_size; j++) { - const char* input_name = nullptr; - api_->OrtNode_GetIthImplicitInputName(node, j, &input_name); - bool is_constant_initializer = false; - api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_constant_initializer); - if (is_constant_initializer) { + const char* input_name = graph_api_->OrtNode_GetIthImplicitInputName(node, j); + if (graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true)) { initializers.push_back(input_name); continue; } - const OrtNode* producer = nullptr; - api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); + const OrtNode* producer = graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { input_to_order[input_name] = input_order++; continue; } - size_t producer_index = 0; - api_->OrtNode_GetIndex(producer, &producer_index); + size_t producer_index = graph_api_->OrtNode_GetIndex(producer); // If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { input_to_order[input_name] = input_order++; } } - size_t output_size = 0; - api_->OrtNode_GetOutputSize(node, &output_size); + size_t output_size = graph_api_->OrtNode_GetOutputSize(node); for (size_t j = 0; j < output_size; j++) { - const char* output_name = nullptr; - api_->OrtNode_GetIthOutputName(node, j, &output_name); + const char* output_name = graph_api_->OrtNode_GetIthOutputName(node, j); // If the output is the graph output, it is a subgraph output if (graph_output_names.find(output_name) != graph_output_names.end()) { output_to_order[output_name] = output_order++; continue; } - size_t consumer_count = 0; const OrtNode** consumers = nullptr; - api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumer_count, &consumers); + size_t consumer_count = graph_api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumers); for (size_t k = 0; k < consumer_count; k++) { - size_t consumer_index = 0; - api_->OrtNode_GetIndex(consumers[k], &consumer_index); + size_t consumer_index = graph_api_->OrtNode_GetIndex(consumers[k]); // If the consumer node is not in the subgraph, the output is a subgraph output if (node_set.find(consumer_index) == node_set.end()) { output_to_order[output_name] = output_order++; @@ -1281,12 +1245,10 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Generate unique kernel name for TRT subgraph std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); - const OrtGraph* cur_graph = nullptr; - api_->OrtGraph_GetOrtGraph(graph, &cur_graph); - bool is_subgraph = false; - api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + const OrtGraph* cur_graph = graph_api_->OrtGraph_GetOrtGraph(graph); + bool is_subgraph = graph_api_->OrtGraph_IsSubgraph(cur_graph); const std::string graph_type = is_subgraph ? 
"subgraph" : "graph"; - const char* graph_name = api_->OrtGraph_GetName(graph); + const char* graph_name = graph_api_->OrtGraph_GetName(graph); std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; sub_graph->meta_def->name = new char [meta_def_name.length() + 1]; strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); @@ -1324,14 +1286,9 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& ep_info) : OrtExecutionProvider() { OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const TensorrtExecutionProvider* p = static_cast(this_); - const OrtGraphApi* g_ort_graph_api = api->GetGraphApi(ORT_API_VERSION); - int num_nodes = 0; - g_ort_graph_api->OrtGraph_PlaceHolder(graph, &num_nodes); // Get ModelPath - const std::filesystem::path* model_path = nullptr; - api->OrtGraph_GetModelPath(graph, (const void**)&model_path); + const std::filesystem::path* model_path = static_cast(graph_api_->OrtGraph_GetModelPath(graph)); const auto& path_string = model_path->string(); #ifdef _WIN32 std::strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); @@ -1340,7 +1297,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const #endif p->model_path_[sizeof(p->model_path_) - 1] = '\0'; - if (api->OrtGraph_NumberOfNodes(graph) == 1 && GraphHasCtxNode(graph)) { + if (graph_api_->OrtGraph_NumberOfNodes(graph) == 1 && GraphHasCtxNode(graph)) { SubGraph_t supported_node_vector = {{0}, true}; std::unique_ptr sub_graph = p->GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); *cnt = 1; @@ -1353,19 +1310,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const HashValue model_hash = TRTGenerateId(graph); // Get supported node list from TensorRT parser - const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph); + const int number_of_ort_nodes = graph_api_->OrtGraph_NumberOfNodes(graph); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); std::vector filtered_nodes_vector; - size_t nodes_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &nodes_index); + size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); for (const auto& index : nodes_vector) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph, nodes_index[index], &node); - const char* node_op_type; - api->OrtNode_GetOpType(node, &node_op_type); + const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index]); + const char* node_op_type = graph_api_->OrtNode_GetOpType(node); /* If current node is control flow op, we take different approach based on following four cases: * @@ -1377,17 +1331,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. 
*/ if (p->control_flow_op_set_.find(std::string(node_op_type)) != p->control_flow_op_set_.end()) { - size_t subgraph_count = 0; const OrtGraphViewer** subgraphs = nullptr; - api->OrtNode_GetSubgraphs(node, &subgraph_count, &subgraphs); + size_t subgraph_count = graph_api_->OrtNode_GetSubgraphs(node, &subgraphs); if (subgraph_count != 0) { bool all_subgraphs_are_supported = true; for (size_t i = 0; i < subgraph_count; i++) { // TRT EP should consider the empty subgraph is fully supported by TRT. - if (api->OrtGraph_NumberOfNodes(subgraphs[i]) == 0) { + if (graph_api_->OrtGraph_NumberOfNodes(subgraphs[i]) == 0) { continue; } - if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], "tensorrtEp")) { + if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { all_subgraphs_are_supported = false; break; } @@ -1445,25 +1398,20 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. - const OrtNode* parent_node = nullptr; - api->OrtGraph_GetParenNode(graph, &parent_node); - const char* parent_node_op_type = nullptr; - api->OrtNode_GetOpType(parent_node, &parent_node_op_type); + const OrtNode* parent_node = graph_api_->OrtGraph_GetParenNode(graph); + const char* parent_node_op_type = graph_api_->OrtNode_GetOpType(parent_node); if (strcmp(parent_node_op_type, "If") == 0) { all_subgraphs_are_supported = false; SubGraphCollection_t subgraph_supported_nodes_vector; - size_t subgraph_count = 0; const OrtGraphViewer** subgraphs = nullptr; - api->OrtNode_GetSubgraphs(parent_node, &subgraph_count, &subgraphs); - const OrtGraph* origin_graph = nullptr; - api->OrtGraph_GetOrtGraph(graph, &origin_graph); + size_t subgraph_count = graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs); + const OrtGraph* origin_graph = graph_api_->OrtGraph_GetOrtGraph(graph); for (size_t i = 0; i < subgraph_count; i++) { - const OrtGraph* subgraph = nullptr; - api->OrtGraph_GetOrtGraph(subgraphs[i], &subgraph); + const OrtGraph* subgraph = graph_api_->OrtGraph_GetOrtGraph(subgraphs[i]); if (subgraph == origin_graph) { continue; } - const int number_of_ort_subgraph_nodes = api->OrtGraph_NumberOfNodes(subgraphs[i]); + const int number_of_ort_subgraph_nodes = graph_api_->OrtGraph_NumberOfNodes(subgraphs[i]); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; @@ -1476,7 +1424,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. 
- else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], "tensorrtEp")) { + else if (p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { all_subgraphs_are_supported = true; break; } @@ -1544,25 +1492,20 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); TensorrtExecutionProvider* p = static_cast(this_); this_->extra_param_for_create_state_func = p; this_->extra_param_for_compute_func = p; for (size_t j = 0; j < cnt; j++) { std::unordered_map input_map, output_map; - size_t input_size = 0; - api->OrtNode_GetInputSize(node[j], &input_size); + size_t input_size = graph_api_->OrtNode_GetInputSize(node[j]); for (size_t i = 0; i < input_size; i++) { - const char* ith_input_name = nullptr; - api->OrtNode_GetIthInputName(node[j], i, &ith_input_name); + const char* ith_input_name = graph_api_->OrtNode_GetIthInputName(node[j], i); input_map[ith_input_name] = i; } - size_t output_size = 0; - api->OrtNode_GetOutputSize(node[j], &output_size); + size_t output_size = graph_api_->OrtNode_GetOutputSize(node[j]); for (size_t i = 0; i < output_size; i++) { - const char* ith_output_name = nullptr; - api->OrtNode_GetIthOutputName(node[j], i, &ith_output_name); + const char* ith_output_name = graph_api_->OrtNode_GetIthOutputName(node[j], i); if (ith_output_name != nullptr) { output_map[ith_output_name] = i; } @@ -1574,7 +1517,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } else { ret = p->CreateNodeComputeInfoFromGraph(graph[j], node[j], input_map, output_map, &node_compute_info[j]); } - if (ret != nullptr) return api->CreateStatus(api->GetErrorCode(ret), api->GetErrorMessage(ret)); + if (ret != nullptr) return api_->CreateStatus(api_->GetErrorCode(ret), api_->GetErrorMessage(ret)); } return nullptr; }; @@ -1582,11 +1525,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const OrtExecutionProvider::CanCopy = [](const OrtDevice* source, const OrtDevice* target) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtMemoryInfoDeviceType source_device_type, target_device_type; - api->DeviceGetDeviceType(source, &source_device_type); - api->DeviceGetDeviceType(target, &target_device_type); + api_->DeviceGetDeviceType(source, &source_device_type); + api_->DeviceGetDeviceType(target, &target_device_type); OrtMemoryType source_mem_type, target_mem_type; - api->DeviceGetMemoryType(source, &source_mem_type); - api->DeviceGetMemoryType(target, &target_mem_type); + api_->DeviceGetMemoryType(source, &source_mem_type); + api_->DeviceGetMemoryType(target, &target_mem_type); return source_device_type == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU || source_mem_type == OrtMemoryType::OrtMemoryType_CUDA_PINNED || @@ -1648,11 +1591,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtTypeConstraints* type_constraints = nullptr; - api->CreateOrtTypeConstraints(&type_constraints); - api->AddTypeConstraint(type_constraints, "T", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); // TODO(leca): other data type + api_->CreateOrtTypeConstraints(&type_constraints); + api_->AddTypeConstraint(type_constraints, "T", 
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); // TODO(leca): other data type OrtCustomOp* op = new MemcpyFromHost(); - api->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints); - api->ReleaseTypeConstraints(type_constraints); + api_->OrtKernelRegistry_RegisterKernel(kernel_registry, op, type_constraints); + api_->ReleaseTypeConstraints(type_constraints); }; info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info); @@ -2082,7 +2025,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { ProviderOptions options; for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i]; - std::unique_ptr ret = std::make_unique("tensorrtEp", std::move(options)); + std::unique_ptr ret = std::make_unique(tensorrtEp.c_str(), std::move(options)); return ret.release(); }; } @@ -2171,7 +2114,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); void* buf_data = nullptr; - size_t buf_size = api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data); + size_t buf_size = graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data); trt_parser->parse(buf_data, buf_size, model_path_); trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); @@ -2339,8 +2282,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } } - const char* node_name = nullptr; - api_->OrtNode_GetName(fused_node, &node_name); + const char* node_name = graph_api_->OrtNode_GetName(fused_node); // Load INT8 calibration table std::unordered_map dynamic_range_map; @@ -2714,7 +2656,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort if (iter != output_map.end()) { output_indexes[output_name] = iter->second; } - output_types[output_name] = api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); + output_types[output_name] = graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); } // Save TRT engine, other TRT objects and input/output info to map @@ -2816,14 +2758,14 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort std::unordered_set input_names; OrtMemoryInfo* mem_info = nullptr; - api->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); + api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); if (this_->alloc_ == nullptr) { - Ort::ThrowOnError(api->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); + Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); } OrtAllocator* alloc = this_->alloc_; void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache @@ -2886,7 +2828,7 @@ OrtStatusPtr 
TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); if (!(*(trt_state->engine))) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); } //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; trt_engine = trt_state->engine->get(); @@ -2898,11 +2840,11 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // Decrypt engine size_t engine_size = 0; if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not get engine buffer size"); } std::unique_ptr engine_buf{new char[engine_size]}; if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); } // Deserialize engine // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc @@ -2910,7 +2852,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort trt_state->engine->reset(); *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); if (!(*(trt_state->engine))) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path).c_str()); } //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; trt_engine = trt_state->engine->get(); @@ -2929,7 +2871,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort if (shape_ranges.find(input_name) != shape_ranges.end()) { auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, shape_tensor_values, shape_tensor_values_int64, stream, &engine_update); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); } } } @@ -2956,7 +2898,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort #pragma warning(pop) #endif if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } } @@ -3032,7 +2974,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); 
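// Rebuild the timing cache from the serialized bytes loaded above; a null result from createTimingCache means the blob could not be parsed, and that is reported as an EP failure just below.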
timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); if (timing_cache == nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not create timing cache: " + timing_cache_path).c_str()); } trt_config->setTimingCache(*timing_cache, this_->force_timing_cache_match_); if (this_->detailed_build_log_) { @@ -3057,12 +2999,12 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort serialized_engine = std::unique_ptr( trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); if (!serialized_engine) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create engine from network."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create engine from network."); } *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (!(*(trt_state->engine))) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to deserialize engine."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to deserialize engine."); } if (this_->detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); @@ -3070,7 +3012,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } } if (!(*(trt_state->engine))) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP Failed to Build Engine."); } trt_engine = trt_state->engine->get(); if (trt_state->engine_cache_enable) { @@ -3083,7 +3025,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. 
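// engine_encryption is an optional, user-supplied callback; when it is present, a failed call is surfaced as an EP-level error rather than being silently ignored.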
if (trt_state->engine_encryption != nullptr) { if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine encryption function encrypt"); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not call engine encryption function encrypt"); } //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; } else { @@ -3101,7 +3043,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort auto timing_cache = trt_config->getTimingCache(); std::unique_ptr timingCacheHostData{timing_cache->serialize()}; if (timingCacheHostData == nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not serialize timing cache: " + timing_cache_path).c_str()); } saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); if (this_->detailed_build_log_) { @@ -3125,7 +3067,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort true /* serialize refitted engine to disk */, this_->detailed_build_log_); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } } @@ -3144,7 +3086,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort trt_state->engine->get()->createExecutionContext()); } if (!(*(trt_state->context))) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create context."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); } @@ -3180,7 +3122,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } @@ -3213,7 +3155,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } @@ -3244,7 +3186,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP execution context enqueue failed."); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP execution context enqueue failed."); } /* @@ -3285,7 +3227,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, 
output_index, output_type, stream); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } else { auto& output_tensor = output_tensors[i]; @@ -3375,8 +3317,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } - const char* fused_node_name = nullptr; - api_->OrtNode_GetName(fused_node, &fused_node_name); + const char* fused_node_name = graph_api_->OrtNode_GetName(fused_node); if (!trt_context) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); @@ -3400,9 +3341,9 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi } // Create output to type map - size_t graph_output_size = api_->OrtGraph_GetOutputSize(graph_body_viewer); + size_t graph_output_size = graph_api_->OrtGraph_GetOutputSize(graph_body_viewer); for (size_t i = 0; i < graph_output_size; i++) { - output_types[api_->OrtGraph_GetIthOutputName(graph_body_viewer, i)] = api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); + output_types[graph_api_->OrtGraph_GetIthOutputName(graph_body_viewer, i)] = graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); } // Save TRT engine, TRT context and input/output info to map @@ -3457,14 +3398,14 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input OrtMemoryInfo* mem_info = nullptr; - api->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); + api_->CreateMemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, this_->device_id_, OrtMemType::OrtMemTypeDefault, &mem_info); if (this_->alloc_ == nullptr) { - Ort::ThrowOnError(api->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); + Ort::ThrowOnError(api_->KernelContext_GetAllocator(context, mem_info, &(this_->alloc_))); } OrtAllocator* alloc = this_->alloc_; void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + Ort::ThrowOnError(api_->KernelContext_GetGPUComputeStream(context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); // Get input and output binding names @@ -3495,7 +3436,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi OrtStatusPtr status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } @@ -3528,7 +3469,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi OrtStatusPtr status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); } } @@ -3559,7 +3500,7 @@ OrtStatusPtr 
TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return api->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed."); + return api_->CreateStatus(OrtErrorCode::ORT_FAIL, "TensorRT EP execution context enqueue failed."); } /* @@ -3600,7 +3541,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi } OrtStatusPtr status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != nullptr) { - return api->CreateStatus(OrtErrorCode::ORT_FAIL, api->GetErrorMessage(status)); + return api_->CreateStatus(OrtErrorCode::ORT_FAIL, api_->GetErrorMessage(status)); } } else { auto& output_tensor = output_tensors[i]; @@ -3652,9 +3593,8 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } iterations++; - size_t nodes_count = 0; const size_t* node_index = nullptr; - api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_count, &node_index); + size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index); for (const auto& group : nodes_vector_input) { // Construct subgraph if (!group.first.empty()) { @@ -3663,10 +3603,10 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } else { const OrtGraphViewer* sub_graph_viewer = nullptr; - api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); + graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); void* buf_data = nullptr; - size_t buf_size = api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data); + size_t buf_size = graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data); // Get supported node list recursively SubGraphCollection_t parser_nodes_list; @@ -3690,9 +3630,8 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect #endif SubGraphCollection_t next_nodes_list; - size_t subgraph_node_count = 0; const size_t* subgraph_node_index = nullptr; - api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_count, &subgraph_node_index); + size_t subgraph_node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index); next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 56feecea84e4e..255a8d411e014 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -241,6 +241,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; static const OrtApi* api_; + static const OrtGraphApi* graph_api_; std::unordered_map trt_node_name_with_precision_; std::unordered_map> dynamic_range_map_; std::unordered_map cache_suffix_; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index e9a9ff0cd46c1..124d85657e222 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ 
b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -271,15 +271,12 @@ std::string GetTimingCachePath(const std::string& root, std::string& compute_cap HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { HashValue model_hash = 0; const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - const OrtGraph* cur_graph = nullptr; - api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); - bool is_subgraph = false; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); + const OrtGraph* cur_graph = graph_api->OrtGraph_GetOrtGraph(graph_viewer); + bool is_subgraph = graph_api->OrtGraph_IsSubgraph(cur_graph); while (is_subgraph) { - const OrtGraph* parent_graph = nullptr; - api->OrtGraph_GetParentGraph(cur_graph, &parent_graph); - cur_graph = parent_graph; - api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + cur_graph = graph_api->OrtGraph_GetParentGraph(cur_graph); + is_subgraph = graph_api->OrtGraph_IsSubgraph(cur_graph); } const OrtGraph* main_graph = cur_graph; @@ -289,8 +286,7 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); }; - const std::filesystem::path* model_path = nullptr; - api->OrtGraph_GetModelPath(graph_viewer, (const void**)&model_path); + const std::filesystem::path* model_path = static_cast(graph_api->OrtGraph_GetModelPath(graph_viewer)); // Use the model's file name instead of the entire path to avoid cache regeneration if path changes if (model_path->has_filename()) { @@ -312,27 +308,22 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { // fingerprint current graph by hashing graph inputs // const std::vector& input_names = nullptr; const char** input_names = nullptr; - size_t input_count = 0; - api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_count, &input_names); + size_t input_count = graph_api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_names); for (size_t i = 0; i < input_count; ++i) { hash_str(input_names[i]); } // hashing output of each node - const int number_of_ort_nodes = api->OrtGraph_NumberOfNodes(graph_viewer); + const int number_of_ort_nodes = graph_api->OrtGraph_NumberOfNodes(graph_viewer); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - size_t nodes_count = 0; const size_t* nodes_index = nullptr; - api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_count, &nodes_index); + size_t nodes_count = graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_index); for (const auto& index : nodes_vector) { - const OrtNode* node = nullptr; - api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); - size_t output_size = 0; - api->OrtNode_GetOutputSize(node, &output_size); + const OrtNode* node = graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index]); + size_t output_size = graph_api->OrtNode_GetOutputSize(node); for (size_t i = 0; i < output_size; ++i) { - const char* output_name = nullptr; - api->OrtNode_GetIthOutputName(node, i, &output_name); + const char* output_name = graph_api->OrtNode_GetIthOutputName(node, i); if (output_name != nullptr) { hash_str(output_name); } From b1f8e2a9cdbbc643a254878d8979a8711970b34d Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 14 Oct 2024 22:15:45 +0000 Subject: [PATCH 48/81] Python API --- onnxruntime/__init__.py | 1 + .../onnxruntime_inference_collection.py | 3 +- 
.../python/onnxruntime_pybind_state.cc | 34 +++++++++++++++++-- .../test/python/onnxruntime_test_plugin_ep.py | 13 +++++++ samples/outTreeEp/CMakeLists.txt | 3 +- 5 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/test/python/onnxruntime_test_plugin_ep.py diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 944740a4ccad8..c96bf331706e0 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -46,6 +46,7 @@ from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 + from onnxruntime.capi._pybind_state import register_plugin_execution_provider_library # noqa: F401 import_capi_exception = None except Exception as e: diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index ecae280e92ae5..09fe20ebd8a5b 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -63,9 +63,10 @@ def check_and_normalize_provider_args( return [], [] provider_name_to_options = collections.OrderedDict() + plugin_eps = C.get_available_plugin_providers() def set_provider_options(name, options): - if name not in available_provider_names: + if name not in plugin_eps and name not in available_provider_names: warnings.warn( "Specified provider '{}' is not in available provider names." "Available providers: '{}'".format(name, ", ".join(available_provider_names)) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e13285c60e69f..92164728dd6c0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -17,6 +17,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" +#include "core/framework/provider_factory_adapter.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -29,6 +30,7 @@ #include "core/session/IOBinding.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "core/session/provider_bridge_ort.h" #ifdef ENABLE_ATEN @@ -1187,6 +1189,19 @@ std::unique_ptr CreateExecutionProviderInstance( ->CreateProvider(); #endif } else { + OrtExecutionProviderFactory* plugin_ep_factory = GetEnv()->GetOrtExecutionProviderFactory(type); + if (plugin_ep_factory != nullptr) { + std::vector keys, values; + const auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + for (const auto& [k, v] : it->second) { + keys.push_back(k.c_str()); + values.push_back(v.c_str()); + } + } + onnxruntime::ExecutionProviderFactoryAdapter ep_factory_adapter(plugin_ep_factory, keys.data(), values.data(), keys.size()); + return ep_factory_adapter.CreateProvider(); + } // check whether it is a dynamic load EP: const auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { @@ -1217,8 +1232,6 @@ std::unique_ptr CreateExecutionProviderInstance( */ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector& provider_types, const ProviderOptionsMap& provider_options_map) { - 
ORT_UNUSED_PARAMETER(provider_options_map); - for (const std::string& type : provider_types) { auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map); if (ep) @@ -1317,6 +1330,11 @@ static void LogDeprecationWarning( } #endif +// TODO(leca): when will this variable be unset? It is saved in Environment thus should be cross-session, which means +// once the session ends, the plugin ep should still be left in the Environment +// Should implement Environment::RemovePluginEp() which will be invoked in ~EnvInitializer(), and also clear plugin_execution_providers there +static std::unordered_set plugin_execution_providers; + void addGlobalMethods(py::module& m) { m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance."); m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer."); @@ -1465,6 +1483,18 @@ void addGlobalMethods(py::module& m) { contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); }); #endif + m.def("register_plugin_execution_provider_library", [](const char* provider_type, const char* library_path) -> void { + void* handle = nullptr; + OrtPybindThrowIfError(Env::Default().LoadDynamicLibrary(ToPathString(library_path), false, &handle)); + if (handle) { + OrtExecutionProviderFactory* (*symbol)(); + OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol)); + auto env = GetEnv(); + env->InsertCustomEp(provider_type, symbol()); + plugin_execution_providers.insert(std::string(provider_type)); + } + }); + m.def("get_available_plugin_providers", []() -> std::unordered_set { return plugin_execution_providers; }); } void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { diff --git a/onnxruntime/test/python/onnxruntime_test_plugin_ep.py b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py new file mode 100644 index 0000000000000..3ab4e67953c84 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py @@ -0,0 +1,13 @@ +import onnxruntime as ort +import numpy + +ort.register_plugin_execution_provider_library("outTreeEp", "/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so") + +sess_options = ort.SessionOptions() +sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +#session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=[("CPUExecutionProvider")]) +#session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) +session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["outTreeEp", "CPUExecutionProvider"], provider_options=[{"int_property":"3", "str_property":"strvalue"}, {}]) + +y = session.run(None, {'x': numpy.array([-3.0, 5.0, -2.0, 4.0]).astype(numpy.float32)}) +print(y) diff --git a/samples/outTreeEp/CMakeLists.txt b/samples/outTreeEp/CMakeLists.txt index d4193f6f8fffa..21d416c5e098f 100644 --- a/samples/outTreeEp/CMakeLists.txt +++ b/samples/outTreeEp/CMakeLists.txt @@ -9,4 +9,5 @@ add_library(outTreeEp SHARED out_tree_ep.cc) target_include_directories(outTreeEp PUBLIC "../../include/onnxruntime") # looks we need this in Win as in Windows you cannot build shared library with undefined symbol -#target_link_libraries(outTreeEp PUBLIC 
"/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so") +# link ORT for Python API. Otherwise there will be error(undefined symbol: OrtGetApiBase) when loading this shared library +target_link_libraries(outTreeEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so") From 7acaaab5726b623d46ce315609efb5696585eff6 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 15 Oct 2024 19:41:35 +0800 Subject: [PATCH 49/81] fix memory leak (#22444) --- .../core/session/onnxruntime_c_api_ep.h | 5 +++++ onnxruntime/core/framework/provider_adapter.h | 2 ++ .../core/session/onnxruntime_c_api_ep.cc | 17 +++++++++++++++++ onnxruntime/core/session/ort_apis_ep.h | 4 ++++ samples/outTreeEp/out_tree_ep.cc | 11 +++++++++++ .../tensorRTEp/tensorrt_execution_provider.cc | 14 ++++++++++++++ 6 files changed, 53 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 87d742911bef3..ce3982c86ace2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -74,6 +74,7 @@ typedef struct OrtExecutionProvider { bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators); + void(ORT_API_CALL* ReleaseIndexedSubGraphs)(OrtIndexedSubGraph** indexed_sub_graphs, size_t num_sub_graph); const char* type; OrtCreateStream* create_stream; const OrtDevice* default_device; @@ -128,6 +129,8 @@ size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ voi ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss +ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); + const char*(ORT_API_CALL* OrtNode_GetName)(const OrtNode* node); const char*(ORT_API_CALL* OrtNode_GetDescription)(const OrtNode* node); @@ -181,5 +184,7 @@ int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)N float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; size_t(ORT_API_CALL* OrtNode_GetSubgraphs)(const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); + +ORT_API2_STATUS(OrtFreeMem, void* p); }; typedef struct OrtGraphApi OrtGraphApi; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 7f5582da84b33..1a8ebd15ac27a 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -76,6 +76,8 @@ class ExecutionProviderAdapter : public IExecutionProvider { ret.push_back(std::make_unique(std::move(sb))); } + + if (indexed_subgraph && ep_impl_->ReleaseIndexedSubGraphs) ep_impl_->ReleaseIndexedSubGraphs(indexed_subgraph, cnt); return ret; } diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 9f4ca8093cdb3..909c48d547b8a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -524,6 +524,14 @@ 
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraphViewer* graph) { + if (graph) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + delete graph_viewer; + } + return nullptr; +} + ORT_API(const char*, OrtGraphApis::OrtNode_GetName, const OrtNode* node) { const ::onnxruntime::Node* n = reinterpret_cast(node); return n->Name().c_str(); @@ -676,6 +684,13 @@ ORT_API(size_t, OrtGraphApis::OrtNode_GetSubgraphs, const OrtNode* node, _Outptr return ret; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtFreeMem, void* p) { + if (p) { + free(p); + } + return nullptr; +} + static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetName, &OrtGraphApis::OrtGraph_IsConstantInitializer, @@ -698,6 +713,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetValueInfo, &OrtGraphApis::OrtGraph_SerializeToArray, &OrtGraphApis::OrtGraph_GetSubGraph, + &OrtGraphApis::OrtGraph_ReleaseGraph, &OrtGraphApis::OrtNode_GetName, &OrtGraphApis::OrtNode_GetDescription, &OrtGraphApis::OrtNode_GetDomain, @@ -725,6 +741,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtNode_GetAttributeInt, &OrtGraphApis::OrtNode_GetAttributeFloat, &OrtGraphApis::OrtNode_GetSubgraphs, + &OrtGraphApis::OrtFreeMem, }; ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi, uint32_t) { diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index d82b5e9742e43..f34d1461fd1ef 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -45,6 +45,8 @@ ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** d ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); + ORT_API(const char*, OrtNode_GetName, const OrtNode* node); ORT_API(const char*, OrtNode_GetDescription, const OrtNode* node); @@ -99,4 +101,6 @@ ORT_API(float, OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) ORT_API(size_t, OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); +ORT_API_STATUS_IMPL(OrtFreeMem, void* p); + } diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index 679babd03874b..adbf686866cfb 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -78,6 +78,17 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe } return nullptr; }; + + OrtExecutionProvider::ReleaseIndexedSubGraphs = [](OrtIndexedSubGraph** indexed_sub_graphs, size_t num_sub_graph) { + if (indexed_sub_graphs == nullptr) return; + for (size_t i = 0; i < num_sub_graph; i++) { + OrtIndexedSubGraph* sub_graph = indexed_sub_graphs[i]; + delete[] sub_graph->node_index; + delete sub_graph->meta_def; + delete sub_graph; + } + delete[] indexed_sub_graphs; + }; } OutTreeEpFactory::OutTreeEpFactory() { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 6e6b5ce6b0d7a..2bed0827611f7 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1578,6 +1578,17 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const return ret; }; + 
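// Same ownership contract as the outTreeEp change above: every OrtIndexedSubGraph returned from GetCapability, along with its node_index array and meta_def, is allocated by the EP, so the EP frees them here.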
OrtExecutionProvider::ReleaseIndexedSubGraphs = [](OrtIndexedSubGraph** indexed_sub_graphs, size_t num_sub_graph) { + if (indexed_sub_graphs == nullptr) return; + for (size_t i = 0; i < num_sub_graph; i++) { + OrtIndexedSubGraph* sub_graph = indexed_sub_graphs[i]; + delete[] sub_graph->node_index; + delete sub_graph->meta_def; + delete sub_graph; + } + delete[] indexed_sub_graphs; + }; + type = ep_type; create_stream = new OrtCreateStream(); create_stream->device_type = 1; // GPU @@ -2117,6 +2128,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort size_t buf_size = graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data); trt_parser->parse(buf_data, buf_size, model_path_); trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + graph_api_->OrtFreeMem(buf_data); // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow if (fp16_enable_ && layer_norm_fp32_fallback_) { @@ -3625,6 +3637,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect #pragma warning(disable : 4996) #endif trt_parser->supportsModel(buf_data, buf_size, parser_nodes_list, model_path_); + graph_api_->OrtFreeMem(buf_data); #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -3639,6 +3652,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } nodes_list_output.push_back(next_nodes_list[i]); } + graph_api_->OrtGraph_ReleaseGraph(sub_graph_viewer); } } } From d150a033dd82fafe4b10897732d91278f7075a55 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:19:38 +0800 Subject: [PATCH 50/81] refactor all functions in onnxruntime_c_api_ep with status as return (#22481) --- .../core/session/onnxruntime_c_api_ep.h | 94 +++---- .../core/session/onnxruntime_c_api_ep.cc | 265 +++++++++++------- onnxruntime/core/session/ort_apis_ep.h | 94 +++---- samples/outTreeEp/out_tree_ep.cc | 19 +- samples/tensorRTEp/onnx_ctx_model_helper.cc | 50 ++-- .../tensorRTEp/tensorrt_execution_provider.cc | 204 +++++++++----- .../tensorrt_execution_provider_utils.h | 32 ++- 7 files changed, 456 insertions(+), 302 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index ce3982c86ace2..f2ae20332cf7c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -87,103 +87,103 @@ typedef struct OrtExecutionProviderFactory { } OrtExecutionProviderFactory; struct OrtGraphApi { -const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_GetName, const OrtGraphViewer* graph, _Outptr_ const char** out); -bool(ORT_API_CALL* OrtGraph_IsConstantInitializer)(const OrtGraphViewer* graph, const char* name, bool check_outer_scope)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* out); -size_t(ORT_API_CALL* OrtGraph_GetNodesIndexInTopologicalOrder)(const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order); +ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order, _Out_ size_t* num_nodes); -bool(ORT_API_CALL* OrtGraph_IsSubgraph)(const OrtGraph* graph); 
+ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out); -const OrtGraph*(ORT_API_CALL* OrtGraph_GetParentGraph)(const OrtGraph* graph); +ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); -const OrtNode*(ORT_API_CALL* OrtGraph_GetParenNode)(const OrtGraphViewer* graph); +ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); -const void*(ORT_API_CALL* OrtGraph_GetModelPath)(const OrtGraphViewer* graph); +ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path); -const OrtGraph*(ORT_API_CALL* OrtGraph_GetOrtGraph)(const OrtGraphViewer* graph_viewer); +ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); -size_t(ORT_API_CALL* OrtGraph_GetInputsIncludingInitializers)(const OrtGraphViewer* graph, _Outptr_ const char*** input_names); +ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); -const OrtNode*(ORT_API_CALL* OrtGraph_GetOrtNode)(const OrtGraphViewer* graph, size_t node_index); +ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); -size_t(ORT_API_CALL* OrtGraph_GetNodesConsumingInput)(const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ? +ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); // TODO(leca): ValueConsumers::comprehensive ? -const OrtNode*(ORT_API_CALL* OrtGraph_GetNodeProducingOutput)(const OrtGraphViewer* graph, const char* output_name); +ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node); -int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_NumberOfNodes, const OrtGraphViewer* graph, _Out_ int* num_nodes); -int(ORT_API_CALL* OrtGraph_MaxNodeIndex)(const OrtGraphViewer* graph); +ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* max_node_index); -size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_GetOutputSize, const OrtGraphViewer* graph, _Out_ size_t* output_len); -const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i, _Outptr_ const char** out); -int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, _Out_ int32_t* out); -bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); +ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**, _Out_ bool* ret); -bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); +ORT_API2_STATUS(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* 
ret); -size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; // TODO(leca): review and discuss +ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); // TODO(leca): review and discuss ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); -const char*(ORT_API_CALL* OrtNode_GetName)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out); -const char*(ORT_API_CALL* OrtNode_GetDescription)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Outptr_ const char** out); -const char*(ORT_API_CALL* OrtNode_GetDomain)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetDomain, const OrtNode* node, _Outptr_ const char** out); -int(ORT_API_CALL* OrtNode_SinceVersion)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* out); -const char*(ORT_API_CALL* OrtNode_GetExecutionProviderType)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** out); -const char*(ORT_API_CALL* OrtNode_GetOpType)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Outptr_ const char** out); -size_t(ORT_API_CALL* OrtNode_GetImplicitInputSize)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* out); -const char*(ORT_API_CALL* OrtNode_GetIthImplicitInputName)(const OrtNode* node, size_t i); +ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -size_t(ORT_API_CALL* OrtNode_GetInputSize)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* out); -const char*(ORT_API_CALL* OrtNode_GetIthInputName)(const OrtNode* node, size_t i); +ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -size_t(ORT_API_CALL* OrtNode_GetOutputSize)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out); -const char*(ORT_API_CALL* OrtNode_GetIthOutputName)(const OrtNode* node, size_t i); +ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -size_t(ORT_API_CALL* OrtNode_GetIndex)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* out); -size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names); +ORT_API2_STATUS(OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names, _Out_ size_t* num); -size_t(ORT_API_CALL* OrtNode_GetAttributeSize)(const OrtNode* node); +ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* out); -int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; // AttributeProto_AttributeType +ORT_API2_STATUS(OrtNode_GetAttributeType, const OrtNode* node, const char* attribute, _Out_ int* out); // AttributeProto_AttributeType -size_t(ORT_API_CALL* OrtNode_GetAttributeKeyCount)(const OrtNode* node, const char* key); +ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* out); -int(ORT_API_CALL* 
OrtNode_GetAttributeIntSize)(const OrtNode* node, const char* key); +ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* out); -int(ORT_API_CALL* OrtNode_GetAttributeFloatSize)(const OrtNode* node, const char* key); +ORT_API2_STATUS(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* out); -int(ORT_API_CALL* OrtNode_GetAttributeStringSize)(const OrtNode* node, const char* key); +ORT_API2_STATUS(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* out); -int64_t(ORT_API_CALL* OrtNode_GetAttributeIthInt)(const OrtNode* node, const char* key, int i); +ORT_API2_STATUS(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* out); -float(ORT_API_CALL* OrtNode_GetAttributeIthFloat)(const OrtNode* node, const char* key, int i); +ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* out); -const char*(ORT_API_CALL* OrtNode_GetAttributeIthStr)(const OrtNode* node, const char* key, int i); +ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out); -const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out); -int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out); -float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; +ORT_API2_STATUS(OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out); -size_t(ORT_API_CALL* OrtNode_GetSubgraphs)(const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); +ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs, _Out_ size_t* num_subgraphs); ORT_API2_STATUS(OrtFreeMem, void* p); }; diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 909c48d547b8a..4fbc4b38fffd1 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -11,106 +11,124 @@ using namespace onnxruntime; -ORT_API(const char*, OrtGraphApis::OrtGraph_GetName, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetName, const OrtGraphViewer* graph, _Out_ const char** out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->Name().c_str(); + *out = graph_viewer->Name().c_str(); + return nullptr; } -ORT_API(bool, OrtGraphApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->IsConstantInitializer(name, check_outer_scope); + *out = graph_viewer->IsConstantInitializer(name, check_outer_scope); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order) { 
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order, _Out_ size_t* num_nodes) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const std::vector& nodes = graph_viewer->GetNodesInTopologicalOrder(static_cast(execution_order)); *nodes_index_in_topological_order = nodes.data(); - return nodes.size(); + *num_nodes = nodes.size(); + return nullptr; } -ORT_API(bool, OrtGraphApis::OrtGraph_IsSubgraph, const OrtGraph* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out) { const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - return graph_ptr->IsSubgraph(); + *out = graph_ptr->IsSubgraph(); + return nullptr; } -ORT_API(const OrtGraph*, OrtGraphApis::OrtGraph_GetParentGraph, const OrtGraph* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph) { const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - return reinterpret_cast(graph_ptr->ParentGraph()); + *parent_graph = reinterpret_cast(graph_ptr->ParentGraph()); + return nullptr; } -ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return reinterpret_cast(graph_viewer->ParentNode()); + *parent_node = reinterpret_cast(graph_viewer->ParentNode()); + return nullptr; } -ORT_API(const void*, OrtGraphApis::OrtGraph_GetModelPath, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return reinterpret_cast(&graph_viewer->ModelPath()); + *model_path = reinterpret_cast(&graph_viewer->ModelPath()); + return nullptr; } -ORT_API(const OrtGraph*, OrtGraphApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph) { const ::onnxruntime::GraphViewer* graph_viewer_ptr = reinterpret_cast(graph_viewer); - return reinterpret_cast(&graph_viewer_ptr->GetGraph()); + *graph = reinterpret_cast(&graph_viewer_ptr->GetGraph()); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const auto& inputs = graph_viewer->GetInputsIncludingInitializers(); - size_t ret = inputs.size(); - *input_names = new const char*[ret]; - for (size_t i = 0; i < ret; i++) (*input_names)[i] = inputs[i]->Name().c_str(); - return ret; + *input_len = inputs.size(); + *input_names = new const char*[*input_len]; + for (size_t i = 0; i < *input_len; i++) (*input_names)[i] = inputs[i]->Name().c_str(); + return nullptr; } -ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t 
node_index, _Outptr_ const OrtNode** node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return reinterpret_cast(graph_viewer->GetNode(node_index)); + *node = reinterpret_cast(graph_viewer->GetNode(node_index)); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); std::vector consumer_nodes = graph_viewer->GetConsumerNodes(input_name); - size_t ret = consumer_nodes.size(); - *consumers = new const OrtNode* [ret]; - for (size_t i = 0; i < ret; i++) (*consumers)[i] = reinterpret_cast(consumer_nodes[i]); + *num_consumers = consumer_nodes.size(); + *consumers = new const OrtNode* [*num_consumers]; + for (size_t i = 0; i < *num_consumers; i++) (*consumers)[i] = reinterpret_cast(consumer_nodes[i]); - return ret; + return nullptr; } -ORT_API(const OrtNode*, OrtGraphApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return reinterpret_cast(graph_viewer->GetProducerNode(output_name)); + *node = reinterpret_cast(graph_viewer->GetProducerNode(output_name)); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtGraph_NumberOfNodes, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_NumberOfNodes, const OrtGraphViewer* graph, _Out_ int* num_nodes) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->NumberOfNodes(); + *num_nodes = graph_viewer->NumberOfNodes(); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* max_node_index) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->MaxNodeIndex(); + *max_node_index = graph_viewer->MaxNodeIndex(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtGraph_GetOutputSize, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOutputSize, const OrtGraphViewer* graph, _Out_ size_t* output_len) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->GetOutputs().size(); + *output_len = graph_viewer->GetOutputs().size(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i, _Outptr_ const char** out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - return graph_viewer->GetOutputs()[i]->Name().c_str(); + *out = graph_viewer->GetOutputs()[i]->Name().c_str(); + return nullptr; } -ORT_API(int32_t, OrtGraphApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* graph, size_t i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetIthOutputElemType, const OrtGraphViewer* graph, size_t i, _Out_ int32_t* out) { const ::onnxruntime::GraphViewer* graph_viewer = 
reinterpret_cast(graph); - return graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type(); + *out = graph_viewer->GetOutputs()[i]->TypeAsProto()->tensor_type().elem_type(); + return nullptr; } -ORT_API(bool, OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out, _Out_ bool* ret) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const onnx::TensorProto* initializer = nullptr; - if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) return false; + if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) { + *out = nullptr; + *ret = false; + return nullptr; + } *out = new OrtTensorRef(); // TODO(leca): release (*out)->shape_len = initializer->dims_size(); (*out)->shape = new int64_t [initializer->dims_size()]; @@ -126,7 +144,8 @@ ORT_API(bool, OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* (*out)->data = reinterpret_cast(initializer->float_data().data()); break; } - return true; + *ret = true; + return nullptr; } static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* type) { // onnxruntime\core\optimizer\transpose_optimization\ort_optimizer_api_impl.cc @@ -135,7 +154,7 @@ static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* return static_cast(type->tensor_type().elem_type()); } -ORT_API(bool, OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const NodeArg* node_arg = graph_viewer->GetNodeArg(name); @@ -147,10 +166,11 @@ ORT_API(bool, OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, (*out)->shape = new int64_t [(*out)->shape_len]; for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[i]) ? 
dims[i].dim_value() : -1; - return true; + *ret = true; + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), #if defined(ORT_MINIMAL_BUILD) @@ -161,10 +181,10 @@ ORT_API(size_t, OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewer* g graph_viewer->DomainToVersionMap(), std::vector(), graph_viewer->GetGraph().GetLogger()); onnx::ModelProto model_proto = model.ToProto(); GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED); - size_t ret = model_proto.ByteSizeLong(); - *data = malloc(ret); // TODO(leca): release - model_proto.SerializeToArray(*data, ret); - return ret; + *data_size = model_proto.ByteSizeLong(); + *data = malloc(*data_size); // TODO(leca): release + model_proto.SerializeToArray(*data, *data_size); + return nullptr; } struct SubGraphContext2 { @@ -532,156 +552,183 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraphViewer* g return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetName, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->Name().c_str(); + *out = n->Name().c_str(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetDescription, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetDescription, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->Description().c_str(); + *out = n->Description().c_str(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetDomain, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetDomain, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->Domain().c_str(); + *out = n->Domain().c_str(); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtNode_SinceVersion, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_SinceVersion, const OrtNode* node, _Out_ int* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->SinceVersion(); + *out = n->SinceVersion(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetExecutionProviderType, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetExecutionProviderType, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetExecutionProviderType().c_str(); + *out = n->GetExecutionProviderType().c_str(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetOpType, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetOpType, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->OpType().c_str(); + *out = n->OpType().c_str(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetImplicitInputSize, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return 
n->ImplicitInputDefs().size(); + *out = n->ImplicitInputDefs().size(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); assert(i < n->ImplicitInputDefs().size()); - return n->ImplicitInputDefs()[i]->Name().c_str(); + *out = n->ImplicitInputDefs()[i]->Name().c_str(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetInputSize, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->InputDefs().size(); + *out = n->InputDefs().size(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetIthInputName, const OrtNode* node, size_t i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); assert(i < n->InputDefs().size()); - return n->InputDefs()[i]->Name().c_str(); + *out = n->InputDefs()[i]->Name().c_str(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetOutputSize, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->OutputDefs().size(); + *out = n->OutputDefs().size(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); assert(i < n->OutputDefs().size()); - if (n->OutputDefs()[i]->Exists()) return n->OutputDefs()[i]->Name().c_str(); + if (n->OutputDefs()[i]->Exists()) { + *out = n->OutputDefs()[i]->Name().c_str(); + return nullptr; + } + *out = nullptr; return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetIndex, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->Index(); + *out = n->Index(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names, _Out_ size_t* num) { const ::onnxruntime::Node* n = reinterpret_cast(node); - size_t ret = n->GetAttributes().size(); - *names = new const char* [ret]; + *num = n->GetAttributes().size(); + *names = new const char* [*num]; int i = 0; for (const auto& [k, v] : n->GetAttributes()) { (*names)[i++] = k.c_str(); } - return ret; + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeSize, const OrtNode* node) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().size(); + *out = n->GetAttributes().size(); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeType, const OrtNode* node, const char* attribute, _Out_ int* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - 
return static_cast(n->GetAttributes().at(attribute).type()); + *out = static_cast(n->GetAttributes().at(attribute).type()); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().count(key); + *out = n->GetAttributes().count(key); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).ints_size(); + *out = n->GetAttributes().at(key).ints_size(); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).floats_size(); + *out = n->GetAttributes().at(key).floats_size(); + return nullptr; } -ORT_API(int, OrtGraphApis::OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).strings_size(); + *out = n->GetAttributes().at(key).strings_size(); + return nullptr; } -ORT_API(int64_t, OrtGraphApis::OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).ints(i); + *out = n->GetAttributes().at(key).ints(i); + return nullptr; } -ORT_API(float, OrtGraphApis::OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).floats(i); + *out = n->GetAttributes().at(key).floats(i); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).strings(i).c_str(); + *out = n->GetAttributes().at(key).strings(i).c_str(); + return nullptr; } -ORT_API(const char*, OrtGraphApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).s().c_str(); + *out = n->GetAttributes().at(key).s().c_str(); + return nullptr; } -ORT_API(int64_t, OrtGraphApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key, 
_Out_ int64_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).i(); + *out = n->GetAttributes().at(key).i(); + return nullptr; } -ORT_API(float, OrtGraphApis::OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); - return n->GetAttributes().at(key).f(); + *out = n->GetAttributes().at(key).f(); + return nullptr; } -ORT_API(size_t, OrtGraphApis::OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs, _Out_ size_t* num) { const ::onnxruntime::Node* n = reinterpret_cast(node); std::vector> subg = n->GetSubgraphs(); - size_t ret = subg.size(); - *subgraphs = new const OrtGraphViewer* [ret]; - for (size_t i = 0; i < ret; i++) { + *num = subg.size(); + *subgraphs = new const OrtGraphViewer* [*num]; + for (size_t i = 0; i < *num; i++) { const ::onnxruntime::GraphViewer* graph_viewer = new const ::onnxruntime::GraphViewer(*subg[i]); (*subgraphs)[i] = reinterpret_cast(graph_viewer); } - return ret; + return nullptr; } ORT_API_STATUS_IMPL(OrtGraphApis::OrtFreeMem, void* p) { diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index f34d1461fd1ef..d1fb60b6fbbca 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -3,103 +3,103 @@ namespace OrtGraphApis { ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); -ORT_API(const char*, OrtGraph_GetName, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_GetName, const OrtGraphViewer* graph, _Outptr_ const char** out); -ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope)ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* out); -ORT_API(size_t, OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order); +ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order, _Out_ size_t* num_nodes); -ORT_API(bool, OrtGraph_IsSubgraph, const OrtGraph* graph); +ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out); -ORT_API(const OrtGraph*, OrtGraph_GetParentGraph, const OrtGraph* graph); +ORT_API_STATUS_IMPL(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); -ORT_API(const OrtNode*, OrtGraph_GetParenNode, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); -ORT_API(const void*, OrtGraph_GetModelPath, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path); -ORT_API(const OrtGraph*, OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer); +ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); -ORT_API(size_t, OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names); 
+ORT_API_STATUS_IMPL(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); -ORT_API(const OrtNode*, OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index); +ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); -ORT_API(size_t, OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); +ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); -ORT_API(const OrtNode*, OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name); +ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node); -ORT_API(int, OrtGraph_NumberOfNodes, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_NumberOfNodes, const OrtGraphViewer* graph, _Out_ int* num_nodes); -ORT_API(int, OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* max_node_index); -ORT_API(size_t, OrtGraph_GetOutputSize, const OrtGraphViewer*) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_GetOutputSize, const OrtGraphViewer* graph, _Out_ size_t* output_len); -ORT_API(const char*, OrtGraph_GetIthOutputName, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i, _Outptr_ const char** out); -ORT_API(int32_t, OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, _Out_ int32_t* out); -ORT_API(bool, OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); +ORT_API_STATUS_IMPL(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** tensor, _Out_ bool* ret); -ORT_API(bool, OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); +ORT_API_STATUS_IMPL(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret); -ORT_API(size_t, OrtGraph_SerializeToArray, const OrtGraphViewer*, _Out_ void** data); +ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); -ORT_API(const char*, OrtNode_GetName, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out); -ORT_API(const char*, OrtNode_GetDescription, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Outptr_ const char** out); -ORT_API(const char*, OrtNode_GetDomain, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetDomain, const OrtNode* node, _Outptr_ const char** out); -ORT_API(int, OrtNode_SinceVersion, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* out); -ORT_API(const char*, OrtNode_GetExecutionProviderType, 
const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetExecutionProviderType, const OrtNode* node, _Outptr_ const char** out); -ORT_API(const char*, OrtNode_GetOpType, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetOpType, const OrtNode* node, _Outptr_ const char** out); -ORT_API(size_t, OrtNode_GetImplicitInputSize, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* out); -ORT_API(const char*, OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i); +ORT_API_STATUS_IMPL(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API(size_t, OrtNode_GetInputSize, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* out); -ORT_API(const char*, OrtNode_GetIthInputName, const OrtNode* node, size_t i); +ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API(size_t, OrtNode_GetOutputSize, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out); -ORT_API(const char*, OrtNode_GetIthOutputName, const OrtNode* node, size_t i); +ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API(size_t, OrtNode_GetIndex, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* out); -ORT_API(size_t, OrtNode_GetAttributeNames, const OrtNode* node, const char*** names); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names, _Out_ size_t* num); -ORT_API(size_t, OrtNode_GetAttributeSize, const OrtNode* node); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* out); -ORT_API(int, OrtNode_GetAttributeType, const OrtNode* node, const char* attribute) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtNode_GetAttributeType, const OrtNode* node, const char* attribute, _Out_ int* out); -ORT_API(size_t, OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* out); -ORT_API(int, OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* out); -ORT_API(int, OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* out); -ORT_API(int, OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* out); -ORT_API(int64_t, OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* out); -ORT_API(float, OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* out); -ORT_API(const char*, OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out); -ORT_API(const char*, OrtNode_GetAttributeStr, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; 
+ORT_API_STATUS_IMPL(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out); -ORT_API(int64_t, OrtNode_GetAttributeInt, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out); -ORT_API(float, OrtNode_GetAttributeFloat, const OrtNode* node, const char* key) ORT_ALL_ARGS_NONNULL; +ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out); -ORT_API(size_t, OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); +ORT_API_STATUS_IMPL(OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(OrtFreeMem, void* p); diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index adbf686866cfb..028c001cbeae6 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -11,10 +11,13 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe const OrtGraphApi* ort_graph_api = api->GetGraphApi(ORT_API_VERSION); std::vector cache; const size_t* nodes_index = nullptr; - size_t nodes_count = ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &nodes_index); + size_t nodes_count = 0; + ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 0, &nodes_index, &nodes_count); for (size_t i = 0; i < nodes_count; i++) { - const OrtNode* node = ort_graph_api->OrtGraph_GetOrtNode(graph, nodes_index[i]); - const char* node_op_type = ort_graph_api->OrtNode_GetOpType(node); + const OrtNode* node = nullptr; + ort_graph_api->OrtGraph_GetOrtNode(graph, nodes_index[i], &node); + const char* node_op_type = nullptr; + ort_graph_api->OrtNode_GetOpType(node, &node_op_type); if (!strcmp(node_op_type, "Relu")) { OrtIndexedSubGraph* subgraph = new OrtIndexedSubGraph(); subgraph->node_index_len = 1; @@ -24,17 +27,19 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe subgraph->meta_def = new OrtMetaDef(); subgraph->meta_def->name = "Relu_subgraph"; subgraph->meta_def->input_len = 0; - subgraph->meta_def->input_len = ort_graph_api->OrtNode_GetInputSize(node); + ort_graph_api->OrtNode_GetInputSize(node, &subgraph->meta_def->input_len); subgraph->meta_def->inputs = new char* [subgraph->meta_def->input_len]; for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { - const char* input_j = ort_graph_api->OrtNode_GetIthInputName(node, j); + const char* input_j = nullptr; + ort_graph_api->OrtNode_GetIthInputName(node, j, &input_j); subgraph->meta_def->inputs[j] = const_cast(input_j); } - subgraph->meta_def->output_len = ort_graph_api->OrtNode_GetOutputSize(node); + ort_graph_api->OrtNode_GetOutputSize(node, &subgraph->meta_def->output_len); subgraph->meta_def->outputs = new char* [subgraph->meta_def->output_len]; for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { - const char* output_j = ort_graph_api->OrtNode_GetIthOutputName(node, j); + const char* output_j = nullptr; + ort_graph_api->OrtNode_GetIthOutputName(node, j, &output_j); subgraph->meta_def->outputs[j] = const_cast(output_j); } diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index 66d99faa51b09..426d2484ef98d 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -9,13 +9,16 @@ namespace onnxruntime { bool GraphHasCtxNode(const 
OrtGraphViewer* graph_viewer) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); - int maxNodeIndex = graph_api->OrtGraph_MaxNodeIndex(graph_viewer); + int maxNodeIndex = 0; + graph_api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex); for (int i = 0; i < maxNodeIndex; ++i) { - const OrtNode* node = graph_api->OrtGraph_GetOrtNode(graph_viewer, i); + const OrtNode* node = nullptr; + graph_api->OrtGraph_GetOrtNode(graph_viewer, i, &node); if (node == nullptr) { continue; } - const char* opType = graph_api->OrtNode_GetOpType(node); + const char* opType = nullptr; + graph_api->OrtNode_GetOpType(node, &opType); if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { return true; } @@ -116,12 +119,16 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView if (!ValidateEPCtxNode(graph_viewer)) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node"); } - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); - const int64_t embed_mode = graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str()); + int64_t embed_mode = -1; + graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); if (embed_mode) { // Get engine from byte stream. - const std::string& context_binary(graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); + const char* context_binary_cstr = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr); + std::string context_binary(context_binary_cstr); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; @@ -130,7 +137,9 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView } } else { // Get engine from cache file. - std::string cache_path(graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str())); + const char* cache_path_cstr = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &cache_path_cstr); + std::string cache_path(cache_path_cstr); // For security purpose, in the case of running context model, TRT EP won't allow // engine cache path to be the relative path like "../file_path" or the absolute path. 
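All of these hunks apply the same conversion: a graph API that used to return its result directly now returns an OrtStatusPtr and writes the result through an out-parameter, which is why every call site gains a locally initialized variable and an extra argument. The samples in this patch series ignore the returned status; a caller that does check it would look roughly like the minimal sketch below (GetFirstNodeChecked is a hypothetical helper, not part of the patch; GetErrorMessage and ReleaseStatus are existing OrtApi members).

#include <stdio.h>
// Sketch only: fetch node 0 via the status-returning API and surface failures.
static bool GetFirstNodeChecked(const OrtGraphViewer* graph_viewer, const OrtNode** node) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION);
  *node = nullptr;
  OrtStatusPtr status = graph_api->OrtGraph_GetOrtNode(graph_viewer, 0, node);
  if (status != nullptr) {  // a non-null status signals failure
    fprintf(stderr, "OrtGraph_GetOrtNode: %s\n", api->GetErrorMessage(status));
    api->ReleaseStatus(status);
    return false;
  }
  return true;  // the implementations above return nullptr on success
}

Note that under this convention a nullptr result can still be reported as success (e.g., a removed node index in the GraphHasCtxNode loop above), so callers must check both the status and the out-parameter.
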
@@ -182,7 +191,9 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); if (weight_stripped_engine_refit_) { - const std::string onnx_model_filename(graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str())); + const char* onnx_model_filename_cstr = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str(), &onnx_model_filename_cstr); + const std::string onnx_model_filename(onnx_model_filename_cstr); std::string weight_stripped_engine_cache = engine_cache_path.string(); auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, @@ -200,15 +211,21 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView } bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) { - assert(graph_api_->OrtGraph_NumberOfNodes(graph_viewer) == 1); - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0); - const char* opType = graph_api_->OrtNode_GetOpType(node); + int node_count = 0; + graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &node_count); + assert(node_count == 1); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); + const char* opType = nullptr; + graph_api_->OrtNode_GetOpType(node, &opType); assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0); - size_t key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str()); + size_t key_count = 0; + graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count); // Show the warning if compute capability is not matched if (key_count > 0) { - const char* model_compute_capability = graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str()); + const char* model_compute_capability = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str(), &model_compute_capability); // Verify if engine was compiled with ampere+ hardware compatibility enabled if (strcmp(model_compute_capability, "80+") == 0) { // if (std::stoi(compute_capability_) < 80) { @@ -222,12 +239,13 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_vi } // "embed_mode" attr and "ep_cache_context" attr should be present - key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str()); + graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count); assert(key_count > 0); - key_count = graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str()); + graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count); assert(key_count > 0); - const int64_t embed_mode = graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str()); + int64_t embed_mode = -1; + graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); if (embed_mode == 1) { // engine binary data // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 2bed0827611f7..9a7ca3739581a 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -974,7 +974,8 @@ OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, // Detect and remove cycles from supported node list bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* 
graph, const HashValue& model_hash, bool remove_cycles) const { const size_t* nodes_index = nullptr; - size_t node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); + size_t node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index, &node_count); bool trt_cycle = true, cycle_detected = false; while (trt_cycle) { trt_cycle = false; @@ -1018,29 +1019,37 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Add non TensorRT nodes to the maps for (const auto& index : non_trt_node_index) { - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, index); - const char* node_name_char = graph_api_->OrtNode_GetName(node); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, index, &node); + const char* node_name_char = nullptr; + graph_api_->OrtNode_GetName(node, &node_name_char); const std::string node_name(node_name_char); if (node_to_index_map.find(node_name) == node_to_index_map.end()) { index_to_node_map[id] = node_name; node_to_index_map[node_name] = id++; } - size_t input_count = graph_api_->OrtNode_GetInputSize(node); + size_t input_count = 0; + graph_api_->OrtNode_GetInputSize(node, &input_count); for (size_t i = 0; i < input_count; ++i) { - const char* input_name_char = graph_api_->OrtNode_GetIthInputName(node, i); + const char* input_name_char = nullptr; + graph_api_->OrtNode_GetIthInputName(node, i, &input_name_char); input_to_nodes_map[std::string(input_name_char)].insert(node_name); } - size_t implicit_input_count = graph_api_->OrtNode_GetImplicitInputSize(node); + size_t implicit_input_count = 0; + graph_api_->OrtNode_GetImplicitInputSize(node, &implicit_input_count); for (size_t i = 0; i < implicit_input_count; ++i) { - const char* input_name_char = graph_api_->OrtNode_GetIthImplicitInputName(node, i); + const char* input_name_char = nullptr; + graph_api_->OrtNode_GetIthImplicitInputName(node, i, &input_name_char); input_to_nodes_map[std::string(input_name_char)].insert(node_name); } - size_t output_count = graph_api_->OrtNode_GetOutputSize(node); + size_t output_count = 0; + graph_api_->OrtNode_GetOutputSize(node, &output_count); for (size_t i = 0; i < output_count; ++i) { - const char* output_name_char = graph_api_->OrtNode_GetIthOutputName(node, i); + const char* output_name_char = nullptr; + graph_api_->OrtNode_GetIthOutputName(node, i, &output_name_char); node_to_outputs_map[node_name].insert(std::string(output_name_char)); } } @@ -1100,11 +1109,15 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Check the graph is the subgraph of control flow op bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const { - const OrtGraph* cur_graph = graph_api_->OrtGraph_GetOrtGraph(graph); - bool is_subgraph = graph_api_->OrtGraph_IsSubgraph(cur_graph); + const OrtGraph* cur_graph = nullptr; + graph_api_->OrtGraph_GetOrtGraph(graph, &cur_graph); + bool is_subgraph = false; + graph_api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); if (is_subgraph) { - const OrtNode* node = graph_api_->OrtGraph_GetParenNode(graph); - const char* node_op_type = graph_api_->OrtNode_GetOpType(node); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetParenNode(graph, &node); + const char* node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(node, &node_op_type); if (control_flow_op_set_.find(std::string(node_op_type)) != control_flow_op_set_.end()) { return true; } @@ -1114,14 +1127,18 @@ bool 
TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* // Check whether all the nodes of the graph are assigned to specific ep bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraphViewer* graph, const std::string& provider_type) const { - const int number_of_ort_nodes = graph_api_->OrtGraph_NumberOfNodes(graph); + int number_of_ort_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(graph, &number_of_ort_nodes); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); const size_t* nodes_index = nullptr; - size_t node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); + size_t node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index, &node_count); for (const auto& index : nodes_vector) { - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index]); - const char* node_ep_type = graph_api_->OrtNode_GetExecutionProviderType(node); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index], &node); + const char* node_ep_type = nullptr; + graph_api_->OrtNode_GetExecutionProviderType(node, &node_ep_type); if (strcmp(node_ep_type, provider_type.c_str())) { return false; } @@ -1143,7 +1160,8 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t su std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const OrtGraphViewer* graph, const HashValue& model_hash, int subgraph_index) const { const size_t* node_index = nullptr; - size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index); + size_t nodes_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index, &nodes_count); std::unordered_set node_set; node_set.reserve(graph_nodes_index.first.size()); for (const auto& index : graph_nodes_index.first) { @@ -1152,9 +1170,12 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Get parent graph output names std::unordered_set graph_output_names; - size_t graph_output_size = graph_api_->OrtGraph_GetOutputSize(graph); + size_t graph_output_size = 0; + graph_api_->OrtGraph_GetOutputSize(graph, &graph_output_size); for (size_t i = 0; i < graph_output_size; i++) { - graph_output_names.insert(graph_api_->OrtGraph_GetIthOutputName(graph, i)); + char const* output_name = nullptr; + graph_api_->OrtGraph_GetIthOutputName(graph, i, &output_name); + graph_output_names.insert(output_name); } // Find inputs and outputs of the subgraph @@ -1172,59 +1193,76 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr int i = 0; for (const auto& index : graph_nodes_index.first) { sub_graph->node_index[i++] = node_index[index]; - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, node_index[index]); - size_t input_size = graph_api_->OrtNode_GetInputSize(node); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); + size_t input_size = 0; + graph_api_->OrtNode_GetInputSize(node, &input_size); for (size_t j = 0; j < input_size; j++) { - const char* input_name = graph_api_->OrtNode_GetIthInputName(node, j); - if (graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true)) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, j, &input_name); + bool is_initializer = false; + graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); + if (is_initializer) 
{ initializers.push_back(input_name); continue; } - const OrtNode* producer = graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name); + const OrtNode* producer = nullptr; + graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { input_to_order[input_name] = input_order++; continue; } - size_t producer_index = graph_api_->OrtNode_GetIndex(producer); + size_t producer_index = -1; + graph_api_->OrtNode_GetIndex(producer, &producer_index); // If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { input_to_order[input_name] = input_order++; } } - size_t implicit_input_size = graph_api_->OrtNode_GetImplicitInputSize(node); + size_t implicit_input_size = 0; + graph_api_->OrtNode_GetImplicitInputSize(node, &implicit_input_size); for (size_t j = 0; j < implicit_input_size; j++) { - const char* input_name = graph_api_->OrtNode_GetIthImplicitInputName(node, j); - if (graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true)) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthImplicitInputName(node, j, &input_name); + bool is_initializer = false; + graph_api_->OrtGraph_IsConstantInitializer(graph, input_name, true, &is_initializer); + if (is_initializer) { initializers.push_back(input_name); continue; } - const OrtNode* producer = graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name); + const OrtNode* producer = nullptr; + graph_api_->OrtGraph_GetNodeProducingOutput(graph, input_name, &producer); // If the input is not produced by any node, it is a graph input if (producer == nullptr) { input_to_order[input_name] = input_order++; continue; } - size_t producer_index = graph_api_->OrtNode_GetIndex(producer); + size_t producer_index = -1; + graph_api_->OrtNode_GetIndex(producer, &producer_index); // If the producer node is not in the subgraph, the input is a graph input if (node_set.find(producer_index) == node_set.end()) { input_to_order[input_name] = input_order++; } } - size_t output_size = graph_api_->OrtNode_GetOutputSize(node); + size_t output_size = 0; + graph_api_->OrtNode_GetOutputSize(node, &output_size); for (size_t j = 0; j < output_size; j++) { - const char* output_name = graph_api_->OrtNode_GetIthOutputName(node, j); + const char* output_name = nullptr; + graph_api_->OrtNode_GetIthOutputName(node, j, &output_name); // If the output is the graph output, it is a subgraph output if (graph_output_names.find(output_name) != graph_output_names.end()) { output_to_order[output_name] = output_order++; continue; } const OrtNode** consumers = nullptr; - size_t consumer_count = graph_api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumers); + size_t consumer_count = 0; + graph_api_->OrtGraph_GetNodesConsumingInput(graph, output_name, &consumers, &consumer_count); for (size_t k = 0; k < consumer_count; k++) { - size_t consumer_index = graph_api_->OrtNode_GetIndex(consumers[k]); + size_t consumer_index = -1; + graph_api_->OrtNode_GetIndex(consumers[k], &consumer_index); // If the consumer node is not in the subgraph, the output is a subgraph output if (node_set.find(consumer_index) == node_set.end()) { output_to_order[output_name] = output_order++; @@ -1245,10 +1283,13 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Generate unique kernel name for TRT subgraph std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); 
- const OrtGraph* cur_graph = graph_api_->OrtGraph_GetOrtGraph(graph); - bool is_subgraph = graph_api_->OrtGraph_IsSubgraph(cur_graph); + const OrtGraph* cur_graph = nullptr; + graph_api_->OrtGraph_GetOrtGraph(graph, &cur_graph); + bool is_subgraph = false; + graph_api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); const std::string graph_type = is_subgraph ? "subgraph" : "graph"; - const char* graph_name = graph_api_->OrtGraph_GetName(graph); + const char* graph_name = nullptr; + graph_api_->OrtGraph_GetName(graph, &graph_name); std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id; sub_graph->meta_def->name = new char [meta_def_name.length() + 1]; strcpy(sub_graph->meta_def->name, meta_def_name.c_str()); @@ -1288,7 +1329,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { const TensorrtExecutionProvider* p = static_cast(this_); // Get ModelPath - const std::filesystem::path* model_path = static_cast(graph_api_->OrtGraph_GetModelPath(graph)); + const std::filesystem::path* model_path = nullptr; + graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast(&model_path)); const auto& path_string = model_path->string(); #ifdef _WIN32 std::strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); @@ -1297,7 +1339,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const #endif p->model_path_[sizeof(p->model_path_) - 1] = '\0'; - if (graph_api_->OrtGraph_NumberOfNodes(graph) == 1 && GraphHasCtxNode(graph)) { + int node_count = 0; + graph_api_->OrtGraph_NumberOfNodes(graph, &node_count); + if (node_count == 1 && GraphHasCtxNode(graph)) { SubGraph_t supported_node_vector = {{0}, true}; std::unique_ptr sub_graph = p->GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); *cnt = 1; @@ -1310,16 +1354,20 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const HashValue model_hash = TRTGenerateId(graph); // Get supported node list from TensorRT parser - const int number_of_ort_nodes = graph_api_->OrtGraph_NumberOfNodes(graph); + int number_of_ort_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(graph, &number_of_ort_nodes); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); std::vector filtered_nodes_vector; const size_t* nodes_index = nullptr; - size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index); + size_t nodes_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &nodes_index, &nodes_count); for (const auto& index : nodes_vector) { - const OrtNode* node = graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index]); - const char* node_op_type = graph_api_->OrtNode_GetOpType(node); + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph, nodes_index[index], &node); + const char* node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(node, &node_op_type); /* If current node is control flow op, we take different approach based on following four cases: * @@ -1332,12 +1380,15 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const */ if (p->control_flow_op_set_.find(std::string(node_op_type)) != p->control_flow_op_set_.end()) { const OrtGraphViewer** subgraphs = nullptr; - size_t subgraph_count = 
graph_api_->OrtNode_GetSubgraphs(node, &subgraphs); + size_t subgraph_count = 0; + graph_api_->OrtNode_GetSubgraphs(node, &subgraphs, &subgraph_count); if (subgraph_count != 0) { bool all_subgraphs_are_supported = true; for (size_t i = 0; i < subgraph_count; i++) { // TRT EP should consider the empty subgraph is fully supported by TRT. - if (graph_api_->OrtGraph_NumberOfNodes(subgraphs[i]) == 0) { + int number_of_ort_subgraph_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); + if (number_of_ort_subgraph_nodes == 0) { continue; } if (!p->AllNodesAssignedToSpecificEP(subgraphs[i], tensorrtEp)) { @@ -1398,20 +1449,26 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. - const OrtNode* parent_node = graph_api_->OrtGraph_GetParenNode(graph); - const char* parent_node_op_type = graph_api_->OrtNode_GetOpType(parent_node); + const OrtNode* parent_node = nullptr; + graph_api_->OrtGraph_GetParenNode(graph, &parent_node); + const char* parent_node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(parent_node, &parent_node_op_type); if (strcmp(parent_node_op_type, "If") == 0) { all_subgraphs_are_supported = false; SubGraphCollection_t subgraph_supported_nodes_vector; const OrtGraphViewer** subgraphs = nullptr; - size_t subgraph_count = graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs); - const OrtGraph* origin_graph = graph_api_->OrtGraph_GetOrtGraph(graph); + size_t subgraph_count = 0; + graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); + const OrtGraph* origin_graph = nullptr; + graph_api_->OrtGraph_GetOrtGraph(graph, &origin_graph); for (size_t i = 0; i < subgraph_count; i++) { - const OrtGraph* subgraph = graph_api_->OrtGraph_GetOrtGraph(subgraphs[i]); + const OrtGraph* subgraph = nullptr; + graph_api_->OrtGraph_GetOrtGraph(subgraphs[i], &subgraph); if (subgraph == origin_graph) { continue; } - const int number_of_ort_subgraph_nodes = graph_api_->OrtGraph_NumberOfNodes(subgraphs[i]); + int number_of_ort_subgraph_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(subgraphs[i], &number_of_ort_subgraph_nodes); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; @@ -1497,15 +1554,19 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const this_->extra_param_for_compute_func = p; for (size_t j = 0; j < cnt; j++) { std::unordered_map input_map, output_map; - size_t input_size = graph_api_->OrtNode_GetInputSize(node[j]); + size_t input_size = 0; + graph_api_->OrtNode_GetInputSize(node[j], &input_size); for (size_t i = 0; i < input_size; i++) { - const char* ith_input_name = graph_api_->OrtNode_GetIthInputName(node[j], i); + const char* ith_input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node[j], i, &ith_input_name); input_map[ith_input_name] = i; } - size_t output_size = graph_api_->OrtNode_GetOutputSize(node[j]); + size_t output_size = 0; + graph_api_->OrtNode_GetOutputSize(node[j], &output_size); for (size_t i = 0; i < output_size; i++) { - const char* ith_output_name = graph_api_->OrtNode_GetIthOutputName(node[j], i); + const char* ith_output_name = nullptr; + 
graph_api_->OrtNode_GetIthOutputName(node[j], i, &ith_output_name); if (ith_output_name != nullptr) { output_map[ith_output_name] = i; } @@ -2125,7 +2186,8 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); void* buf_data = nullptr; - size_t buf_size = graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data); + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(graph_body_viewer, &buf_data, &buf_size); trt_parser->parse(buf_data, buf_size, model_path_); trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); graph_api_->OrtFreeMem(buf_data); @@ -2294,7 +2356,8 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } } - const char* node_name = graph_api_->OrtNode_GetName(fused_node); + const char* node_name = nullptr; + graph_api_->OrtNode_GetName(fused_node, &node_name); // Load INT8 calibration table std::unordered_map dynamic_range_map; @@ -2668,7 +2731,9 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort if (iter != output_map.end()) { output_indexes[output_name] = iter->second; } - output_types[output_name] = graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); + int32_t output_type = 0; + graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); + output_types[output_name] = output_type; } // Save TRT engine, other TRT objects and input/output info to map @@ -3329,7 +3394,8 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } - const char* fused_node_name = graph_api_->OrtNode_GetName(fused_node); + const char* fused_node_name = nullptr; + graph_api_->OrtNode_GetName(fused_node, &fused_node_name); if (!trt_context) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT EP could not build execution context for fused node: " + std::string(fused_node_name)).c_str()); @@ -3353,9 +3419,14 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi } // Create output to type map - size_t graph_output_size = graph_api_->OrtGraph_GetOutputSize(graph_body_viewer); + size_t graph_output_size = 0; + graph_api_->OrtGraph_GetOutputSize(graph_body_viewer, &graph_output_size); for (size_t i = 0; i < graph_output_size; i++) { - output_types[graph_api_->OrtGraph_GetIthOutputName(graph_body_viewer, i)] = graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i); + char const* output_name = nullptr; + graph_api_->OrtGraph_GetIthOutputName(graph_body_viewer, i, &output_name); + int32_t output_type = 0; + graph_api_->OrtGraph_GetIthOutputElemType(graph_body_viewer, i, &output_type); + output_types[output_name] = output_type; } // Save TRT engine, TRT context and input/output info to map @@ -3606,7 +3677,8 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect iterations++; const size_t* node_index = nullptr; - size_t nodes_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index); + size_t nodes_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph, 1, &node_index, &nodes_count); for (const auto& group : nodes_vector_input) { // Construct subgraph if (!group.first.empty()) { @@ -3618,7 +3690,8 @@ SubGraphCollection_t 
TensorrtExecutionProvider::GetSupportedList(SubGraphCollect graph_api_->OrtGraph_GetSubGraph(graph, group.first.size(), group.first.data(), &sub_graph_viewer); void* buf_data = nullptr; - size_t buf_size = graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data); + size_t buf_size = 0; + graph_api_->OrtGraph_SerializeToArray(sub_graph_viewer, &buf_data, &buf_size); // Get supported node list recursively SubGraphCollection_t parser_nodes_list; @@ -3644,7 +3717,8 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect SubGraphCollection_t next_nodes_list; const size_t* subgraph_node_index = nullptr; - size_t subgraph_node_count = graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index); + size_t subgraph_node_count = 0; + graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(sub_graph_viewer, 1, &subgraph_node_index, &subgraph_node_count); next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph_viewer, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index 124d85657e222..bc08040686592 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -272,11 +272,14 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { HashValue model_hash = 0; const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); - const OrtGraph* cur_graph = graph_api->OrtGraph_GetOrtGraph(graph_viewer); - bool is_subgraph = graph_api->OrtGraph_IsSubgraph(cur_graph); + const OrtGraph* cur_graph = nullptr; + graph_api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); + bool is_subgraph = false; + graph_api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); while (is_subgraph) { - cur_graph = graph_api->OrtGraph_GetParentGraph(cur_graph); - is_subgraph = graph_api->OrtGraph_IsSubgraph(cur_graph); + graph_api->OrtGraph_GetParentGraph(cur_graph, &cur_graph); + is_subgraph = false; + graph_api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); } const OrtGraph* main_graph = cur_graph; @@ -286,7 +289,8 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); }; - const std::filesystem::path* model_path = static_cast(graph_api->OrtGraph_GetModelPath(graph_viewer)); + const std::filesystem::path* model_path = nullptr; + graph_api->OrtGraph_GetModelPath(graph_viewer, reinterpret_cast(&model_path)); // Use the model's file name instead of the entire path to avoid cache regeneration if path changes if (model_path->has_filename()) { @@ -308,22 +312,28 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { // fingerprint current graph by hashing graph inputs // const std::vector& input_names = nullptr; const char** input_names = nullptr; - size_t input_count = graph_api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_names); + size_t input_count = 0; + graph_api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_names, &input_count); for (size_t i = 0; i < input_count; ++i) { hash_str(input_names[i]); } // hashing output of each node - const int number_of_ort_nodes = graph_api->OrtGraph_NumberOfNodes(graph_viewer); + int 
number_of_ort_nodes = 0; + graph_api->OrtGraph_NumberOfNodes(graph_viewer, &number_of_ort_nodes); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); const size_t* nodes_index = nullptr; - size_t nodes_count = graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_index); + size_t nodes_count = 0; + graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes_index, &nodes_count); for (const auto& index : nodes_vector) { - const OrtNode* node = graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index]); - size_t output_size = graph_api->OrtNode_GetOutputSize(node); + const OrtNode* node = nullptr; + graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); + size_t output_size = 0; + graph_api->OrtNode_GetOutputSize(node, &output_size); for (size_t i = 0; i < output_size; ++i) { - const char* output_name = graph_api->OrtNode_GetIthOutputName(node, i); + const char* output_name = nullptr; + graph_api->OrtNode_GetIthOutputName(node, i, &output_name); if (output_name != nullptr) { hash_str(output_name); } From da5b6eb333ef9aff1e33a2b449c610f3ef9c43ed Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Fri, 18 Oct 2024 01:17:18 +0000 Subject: [PATCH 51/81] resolve comments --- .../core/framework/execution_provider.h | 4 +-- .../onnxruntime/core/session/environment.h | 6 ++-- .../core/session/onnxruntime_c_api.h | 12 +++---- .../core/session/onnxruntime_c_api_ep.h | 4 +-- onnxruntime/core/framework/provider_adapter.h | 2 +- .../core/framework/provider_factory_adapter.h | 2 -- onnxruntime/core/framework/session_state.cc | 2 +- onnxruntime/core/session/environment.cc | 10 +++--- onnxruntime/core/session/onnxruntime_c_api.cc | 36 +++++++++---------- .../core/session/onnxruntime_c_api_ep.cc | 8 ++--- onnxruntime/core/session/ort_apis.h | 12 +++---- onnxruntime/core/session/ort_apis_ep.h | 4 +-- onnxruntime/core/session/ort_env.cc | 8 ++--- onnxruntime/core/session/ort_env.h | 4 +-- .../python/onnxruntime_pybind_state.cc | 4 +-- .../test/python/onnxruntime_test_plugin_ep.py | 6 ++-- samples/c_test/test.cpp | 16 ++++----- samples/outTreeEp/out_tree_ep.cc | 4 +-- .../tensorRTEp/tensorrt_execution_provider.cc | 16 ++++----- .../tensorrt_execution_provider_utils.h | 2 +- 20 files changed, 80 insertions(+), 82 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index da0c665cb2c1b..2ddfb1609c37c 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -75,7 +75,7 @@ class IExecutionProvider { */ const OrtDevice default_device_; - bool intree_ep = true; + bool builtin_ep_ = true; public: virtual ~IExecutionProvider() = default; @@ -327,7 +327,7 @@ class IExecutionProvider { return InlinedVector(); } - bool IsIntreeEp() const { return intree_ep; } + bool IsBuiltInEp() const { return builtin_ep_; } private: const std::string type_; diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index c05d9768d9b0b..2b05bc08ac376 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -89,9 +89,9 @@ class Environment { */ Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); - void InsertCustomEp(const char* 
ep_name, OrtExecutionProviderFactory* ep_factory); + void InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFactory* ep_factory); - OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const std::string& ep_name); + OrtExecutionProviderFactory* GetPluginExecutionProviderFactory(const std::string& ep_name); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); @@ -104,6 +104,6 @@ class Environment { std::unique_ptr inter_op_thread_pool_; bool create_global_thread_pools_{false}; std::vector shared_allocators_; - std::unordered_map> custom_ep_factories_; + std::unordered_map> plugin_ep_factories_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 87e5867793b61..9c0f55dc7c0a9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4674,27 +4674,27 @@ struct OrtApi { ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); - ORT_API2_STATUS(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); + ORT_API2_STATUS(DeviceGetType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); - ORT_API2_STATUS(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out); + ORT_API2_STATUS(DeviceGetId, _In_ const OrtDevice* device, _Out_ int16_t* out); ORT_CLASS_RELEASE(Device); - ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); + ORT_API2_STATUS(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); - ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, + ORT_API2_STATUS(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); - ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_CLASS_RELEASE(TypeConstraints); + ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); + const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION; }; // struct OrtApi diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index f2ae20332cf7c..448af6d95b0b5 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -147,11 +147,11 @@ ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); 
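(For orientation, a minimal sketch of how a plugin EP drives these node getters, using the renamed OrtNode_GetNumInputs/OrtNode_GetNumOutputs introduced just below and following the call pattern of the samples later in this series; it assumes a valid const OrtNode* node already fetched from the graph viewer, and treats the returned name pointers as graph-owned because the samples never free them:)

    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION);
    size_t num_inputs = 0;
    graph_api->OrtNode_GetNumInputs(node, &num_inputs);
    for (size_t i = 0; i < num_inputs; i++) {
      const char* input_name = nullptr;
      graph_api->OrtNode_GetIthInputName(node, i, &input_name);
      // The samples treat the returned name as graph-owned and never free it.
    }
    size_t num_outputs = 0;
    graph_api->OrtNode_GetNumOutputs(node, &num_outputs);
    for (size_t i = 0; i < num_outputs; i++) {
      const char* output_name = nullptr;
      graph_api->OrtNode_GetIthOutputName(node, i, &output_name);
      if (output_name == nullptr) continue;  // missing optional outputs come back null
    }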
-ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* out); +ORT_API2_STATUS(OrtNode_GetNumInputs, const OrtNode* node, _Out_ size_t* out); ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out); +ORT_API2_STATUS(OrtNode_GetNumOutputs, const OrtNode* node, _Out_ size_t* out); ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out); diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index 1a8ebd15ac27a..8e557fe3c690d 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -37,7 +37,7 @@ class DataTransferAdapter : public IDataTransfer { class ExecutionProviderAdapter : public IExecutionProvider { public: ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type, ep->default_device ? *(ep->default_device) : OrtDevice()), ep_impl_(ep) { - intree_ep = false; + builtin_ep_ = false; if (ep_impl_->RegisterKernels) { kernel_registry_ = std::make_shared(); ep_impl_->RegisterKernels(reinterpret_cast(kernel_registry_.get())); diff --git a/onnxruntime/core/framework/provider_factory_adapter.h b/onnxruntime/core/framework/provider_factory_adapter.h index 9cfa68ecaa864..89695e88220f8 100644 --- a/onnxruntime/core/framework/provider_factory_adapter.h +++ b/onnxruntime/core/framework/provider_factory_adapter.h @@ -25,8 +25,6 @@ std::unique_ptr CreateProvider() override { return std::make_unique(ep_factory_->CreateExecutionProvider(ep_factory_, keys_.data(), values_.data(), provider_option_length_)); } OrtExecutionProviderFactory* ep_factory_; -//const char* const* provider_option_keys_; -//const char* const* provider_option_values_; std::vector provider_option_keys_, provider_option_values_; std::vector keys_, values_; size_t provider_option_length_; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 437c04d758931..d76c06079c44b 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -64,7 +64,7 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry { #ifdef ORT_ENABLE_STREAM static std::string ShouldPostPoneRegisterResourceFor(IExecutionProvider* ep, const ExecutionProviders& all_ep) { - if (ep->IsIntreeEp()) return ""; // TODO(leca): Or use dynamic_cast to check is it ExecutionProviderAdapter instance? Need to disable onnxruntime_DISABLE_RTTI + if (ep->IsBuiltInEp()) return ""; // TODO(leca): Or use dynamic_cast to check is it ExecutionProviderAdapter instance? 
Need to disable onnxruntime_DISABLE_RTTI for (auto& any_ep : all_ep) { if (any_ep->Type() != ep->Type() && any_ep->GetOrtDeviceByMemType(OrtMemTypeDefault) == ep->GetOrtDeviceByMemType(OrtMemTypeDefault)) return any_ep->Type(); } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 8083a473211d7..9e6e0d3ba003f 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -348,14 +348,14 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ return Status{ONNXRUNTIME, common::INVALID_ARGUMENT, provider_type + " is not implemented in CreateAndRegisterAllocatorV2()"}; } -void Environment::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { +void Environment::InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { std::unique_ptr p(ep_factory); - custom_ep_factories_.insert({ep_name, std::move(p)}); // TODO(leca): review + plugin_ep_factories_.insert({ep_name, std::move(p)}); // TODO(leca): review } -OrtExecutionProviderFactory* Environment::GetOrtExecutionProviderFactory(const std::string& ep_name) { - std::unordered_map>::const_iterator it = custom_ep_factories_.find(ep_name); - if (it == custom_ep_factories_.end()) return nullptr; +OrtExecutionProviderFactory* Environment::GetPluginExecutionProviderFactory(const std::string& ep_name) { + std::unordered_map>::const_iterator it = plugin_ep_factories_.find(ep_name); + if (it == plugin_ep_factories_.end()) return nullptr; return it->second.get(); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index e18b61618ffeb..e0f54de387784 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -99,8 +99,6 @@ using onnxruntime::common::Status; using namespace onnxruntime; -typedef std::unordered_map ModelMetaData; - #ifndef ORT_STATUS_PTR #ifdef _WIN32 #define ORT_STATUS_PTR _Check_return_ _Ret_maybenull_ OrtStatusPtr @@ -2371,7 +2369,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateDevice, _In_ enum OrtMemoryInfoDeviceType dev return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out) { +ORT_API_STATUS_IMPL(OrtApis::DeviceGetType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out) { *out = static_cast(device->Type()); return nullptr; } @@ -2381,7 +2379,7 @@ ORT_API_STATUS_IMPL(OrtApis::DeviceGetMemoryType, _In_ const OrtDevice* device, return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out) { +ORT_API_STATUS_IMPL(OrtApis::DeviceGetId, _In_ const OrtDevice* device, _Out_ int16_t* out) { *out = device->Id(); return nullptr; } @@ -2390,23 +2388,23 @@ ORT_API(void, OrtApis::ReleaseDevice, OrtDevice* device) { delete device; } -ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) { +ORT_API_STATUS_IMPL(OrtApis::RegisterPluginExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) { API_IMPL_BEGIN void* handle = nullptr; ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(ToPathString(lib_path), false, &handle)); if (handle) { OrtExecutionProviderFactory* (*symbol)(); ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol)); - env->InsertCustomEp(ep_name, symbol()); 
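(Aside: for the GetSymbolFromLibrary call above to resolve, a plugin library must export an unmangled RegisterCustomEp symbol returning its factory. A hedged sketch of such an entry point follows; the callback parameter types are an assumption based on how the factory adapter invokes CreateExecutionProvider earlier in this series, and a real factory builds and returns its OrtExecutionProvider instead of nullptr:)

    extern "C" OrtExecutionProviderFactory* RegisterCustomEp() {
      static OrtExecutionProviderFactory factory;
      factory.CreateExecutionProvider = [](OrtExecutionProviderFactory* this_,
                                           const char* const* ep_option_keys,
                                           const char* const* ep_option_values,
                                           size_t option_count) -> OrtExecutionProvider* {
        // Illustrative stub: a real factory parses the provider options here and
        // constructs the EP; the sample EPs define richer factories.
        return nullptr;
      };
      return &factory;
    }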
+ env->InsertPluginEpFactory(ep_name, symbol()); return nullptr; } return CreateStatus(ORT_RUNTIME_EXCEPTION, "cannot load the shared library for out-tree EP"); API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { - OrtExecutionProviderFactory* ep_factory = env->GetOrtExecutionProviderFactory(ep_name); + OrtExecutionProviderFactory* ep_factory = env->GetPluginExecutionProviderFactory(ep_name); if (ep_factory) { std::shared_ptr factory = std::make_shared(ep_factory, provider_options_keys, provider_options_values, num_keys); options->provider_factories.push_back(std::move(factory)); @@ -2414,12 +2412,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendOrtExecutionProvider, _In_ OrtS return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints) { - KernelRegistry* kr = reinterpret_cast(kernel_registry); - KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op, type_constraints); - return ToOrtStatus(kr->Register(std::move(kci))); -} - ORT_API_STATUS_IMPL(OrtApis::CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints) { std::unique_ptr otc = std::make_unique(); *type_constraints = otc.release(); @@ -2435,6 +2427,12 @@ ORT_API(void, OrtApis::ReleaseTypeConstraints, OrtTypeConstraints* type_constrai delete type_constraints; } +ORT_API_STATUS_IMPL(OrtApis::OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints) { + KernelRegistry* kr = reinterpret_cast(kernel_registry); + KernelCreateInfo kci = CreateKernelCreateInfo2("", custom_op, type_constraints); + return ToOrtStatus(kr->Register(std::move(kci))); +} + ORT_API(const OrtGraphApi*, OrtApis::GetGraphApi, uint32_t version) { //if (version >= xx && version <= ORT_API_VERSION) return OrtGraphApis::GetGraphApi(version); @@ -2819,17 +2817,17 @@ static constexpr OrtApi ort_api_1_to_19 = { // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::CreateDevice, - &OrtApis::DeviceGetDeviceType, + &OrtApis::DeviceGetType, &OrtApis::DeviceGetMemoryType, - &OrtApis::DeviceGetDeviceId, + &OrtApis::DeviceGetId, &OrtApis::ReleaseDevice, - &OrtApis::RegisterOrtExecutionProviderLibrary, - &OrtApis::SessionOptionsAppendOrtExecutionProvider, + &OrtApis::RegisterPluginExecutionProviderLibrary, + &OrtApis::SessionOptionsAppendPluginExecutionProvider, - &OrtApis::OrtKernelRegistry_RegisterKernel, &OrtApis::CreateOrtTypeConstraints, &OrtApis::AddTypeConstraint, &OrtApis::ReleaseTypeConstraints, + &OrtApis::OrtKernelRegistry_RegisterKernel, &OrtApis::GetGraphApi, }; diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 4fbc4b38fffd1..d2afefc22d2a7 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -601,7 +601,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIthImplicitInputName, const OrtNode return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetInputSize, const 
OrtNode* node, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetNumInputs, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); *out = n->InputDefs().size(); return nullptr; @@ -614,7 +614,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetIthInputName, const OrtNode* node, return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetNumOutputs, const OrtNode* node, _Out_ size_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); *out = n->OutputDefs().size(); return nullptr; @@ -769,9 +769,9 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtNode_GetOpType, &OrtGraphApis::OrtNode_GetImplicitInputSize, &OrtGraphApis::OrtNode_GetIthImplicitInputName, - &OrtGraphApis::OrtNode_GetInputSize, + &OrtGraphApis::OrtNode_GetNumInputs, &OrtGraphApis::OrtNode_GetIthInputName, - &OrtGraphApis::OrtNode_GetOutputSize, + &OrtGraphApis::OrtNode_GetNumOutputs, &OrtGraphApis::OrtNode_GetIthOutputName, &OrtGraphApis::OrtNode_GetIndex, &OrtGraphApis::OrtNode_GetAttributeNames, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 5ce145daf3fe9..6530bfb6205c1 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -526,26 +526,26 @@ ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ ORT_API_STATUS_IMPL(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); -ORT_API_STATUS_IMPL(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); +ORT_API_STATUS_IMPL(DeviceGetType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); ORT_API_STATUS_IMPL(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); -ORT_API_STATUS_IMPL(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out); +ORT_API_STATUS_IMPL(DeviceGetId, _In_ const OrtDevice* device, _Out_ int16_t* out); ORT_API(void, ReleaseDevice, _Frees_ptr_opt_ OrtDevice*); -ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); +ORT_API_STATUS_IMPL(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); -ORT_API_STATUS_IMPL(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, +ORT_API_STATUS_IMPL(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); -ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); - ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_API(void, ReleaseTypeConstraints, _In_ OrtTypeConstraints* type_constraints); +ORT_API_STATUS_IMPL(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); + ORT_API(const 
OrtGraphApi*, GetGraphApi, uint32_t version); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index d1fb60b6fbbca..a36efeb1ee225 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -63,11 +63,11 @@ ORT_API_STATUS_IMPL(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ siz ORT_API_STATUS_IMPL(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API_STATUS_IMPL(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* out); +ORT_API_STATUS_IMPL(OrtNode_GetNumInputs, const OrtNode* node, _Out_ size_t* out); ORT_API_STATUS_IMPL(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); -ORT_API_STATUS_IMPL(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* out); +ORT_API_STATUS_IMPL(OrtNode_GetNumOutputs, const OrtNode* node, _Out_ size_t* out); ORT_API_STATUS_IMPL(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out); diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 4b15cdcb88351..fecd7a85ce9b1 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -112,10 +112,10 @@ onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocatorV2(const std::stri return value_->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg); } -void OrtEnv::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { - value_->InsertCustomEp(ep_name, ep_factory); +void OrtEnv::InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFactory* ep_factory) { + value_->InsertPluginEpFactory(ep_name, ep_factory); } -OrtExecutionProviderFactory* OrtEnv::GetOrtExecutionProviderFactory(const char* ep_name) { - return value_->GetOrtExecutionProviderFactory(ep_name); +OrtExecutionProviderFactory* OrtEnv::GetPluginExecutionProviderFactory(const char* ep_name) { + return value_->GetPluginExecutionProviderFactory(ep_name); } diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index dd0a87c44515d..621f6d8096953 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -66,9 +66,9 @@ struct OrtEnv { ~OrtEnv(); onnxruntime::common::Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); - void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); + void InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFactory* ep_factory); - OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const char* ep_name); + OrtExecutionProviderFactory* GetPluginExecutionProviderFactory(const char* ep_name); private: static std::unique_ptr p_instance_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 92164728dd6c0..7e0cd174a42a7 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1189,7 +1189,7 @@ std::unique_ptr CreateExecutionProviderInstance( ->CreateProvider(); #endif } else { - OrtExecutionProviderFactory* plugin_ep_factory = GetEnv()->GetOrtExecutionProviderFactory(type); + OrtExecutionProviderFactory* plugin_ep_factory = GetEnv()->GetPluginExecutionProviderFactory(type); if (plugin_ep_factory != nullptr) { std::vector keys, values; const auto it = 
provider_options_map.find(type); @@ -1490,7 +1490,7 @@ void addGlobalMethods(py::module& m) { OrtExecutionProviderFactory* (*symbol)(); OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol)); auto env = GetEnv(); - env->InsertCustomEp(provider_type, symbol()); + env->InsertPluginEpFactory(provider_type, symbol()); plugin_execution_providers.insert(std::string(provider_type)); } }); diff --git a/onnxruntime/test/python/onnxruntime_test_plugin_ep.py b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py index 3ab4e67953c84..fcbccbe2931c4 100644 --- a/onnxruntime/test/python/onnxruntime_test_plugin_ep.py +++ b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py @@ -1,13 +1,15 @@ import onnxruntime as ort import numpy -ort.register_plugin_execution_provider_library("outTreeEp", "/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so") +#ort.register_plugin_execution_provider_library("outTreeEp", "/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so") +ort.register_plugin_execution_provider_library("tensorrtEp", "/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so") sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL #session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=[("CPUExecutionProvider")]) #session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) -session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["outTreeEp", "CPUExecutionProvider"], provider_options=[{"int_property":"3", "str_property":"strvalue"}, {}]) +#session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["outTreeEp", "CPUExecutionProvider"], provider_options=[{"int_property":"3", "str_property":"strvalue"}, {}]) +session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["tensorrtEp", "CPUExecutionProvider"], provider_options=[{"device_id":"0", "str_property":"strvalue"}, {}]) y = session.run(None, {'x': numpy.array([-3.0, 5.0, -2.0, 4.0]).astype(numpy.float32)}) print(y) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 34f4fce0b0102..35ee9348b92d1 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -12,27 +12,27 @@ inline void THROW_ON_ERROR(OrtStatus* status) { } void TestCompileBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { - THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", env, "outTreeEp")); + THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", env, "outTreeEp")); std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "outTreeEp", env, keys.data(), values.data(), keys.size())); + THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "outTreeEp", env, keys.data(), values.data(), keys.size())); } void TestKernelBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { - THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp_kernel/build/libkernelEp.so", env, 
"kernelEp")); + THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp_kernel/build/libkernelEp.so", env, "kernelEp")); std::vector keys{"int_property", "str_property"}, values{"3", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "kernelEp", env, keys.data(), values.data(), keys.size())); + THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "kernelEp", env, keys.data(), values.data(), keys.size())); } void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { - THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); + THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); } void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { - THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; - THROW_ON_ERROR(g_ort->SessionOptionsAppendOrtExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); + THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); OrtCUDAProviderOptionsV2* cuda_options = nullptr; THROW_ON_ERROR(g_ort->CreateCUDAProviderOptions(&cuda_options)); diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc index 028c001cbeae6..c77891bab9c72 100644 --- a/samples/outTreeEp/out_tree_ep.cc +++ b/samples/outTreeEp/out_tree_ep.cc @@ -27,7 +27,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe subgraph->meta_def = new OrtMetaDef(); subgraph->meta_def->name = "Relu_subgraph"; subgraph->meta_def->input_len = 0; - ort_graph_api->OrtNode_GetInputSize(node, &subgraph->meta_def->input_len); + ort_graph_api->OrtNode_GetNumInputs(node, &subgraph->meta_def->input_len); subgraph->meta_def->inputs = new char* [subgraph->meta_def->input_len]; for (size_t j = 0; j < subgraph->meta_def->input_len; j++) { const char* input_j = nullptr; @@ -35,7 +35,7 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe subgraph->meta_def->inputs[j] = const_cast(input_j); } - ort_graph_api->OrtNode_GetOutputSize(node, &subgraph->meta_def->output_len); + ort_graph_api->OrtNode_GetNumOutputs(node, &subgraph->meta_def->output_len); subgraph->meta_def->outputs = new char* [subgraph->meta_def->output_len]; for (size_t j = 0; j < subgraph->meta_def->output_len; j++) { const char* output_j = nullptr; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 9a7ca3739581a..691d863b8c531 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ 
b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1030,7 +1030,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } size_t input_count = 0; - graph_api_->OrtNode_GetInputSize(node, &input_count); + graph_api_->OrtNode_GetNumInputs(node, &input_count); for (size_t i = 0; i < input_count; ++i) { const char* input_name_char = nullptr; graph_api_->OrtNode_GetIthInputName(node, i, &input_name_char); @@ -1046,7 +1046,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } size_t output_count = 0; - graph_api_->OrtNode_GetOutputSize(node, &output_count); + graph_api_->OrtNode_GetNumOutputs(node, &output_count); for (size_t i = 0; i < output_count; ++i) { const char* output_name_char = nullptr; graph_api_->OrtNode_GetIthOutputName(node, i, &output_name_char); @@ -1196,7 +1196,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr const OrtNode* node = nullptr; graph_api_->OrtGraph_GetOrtNode(graph, node_index[index], &node); size_t input_size = 0; - graph_api_->OrtNode_GetInputSize(node, &input_size); + graph_api_->OrtNode_GetNumInputs(node, &input_size); for (size_t j = 0; j < input_size; j++) { const char* input_name = nullptr; graph_api_->OrtNode_GetIthInputName(node, j, &input_name); @@ -1248,7 +1248,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr } size_t output_size = 0; - graph_api_->OrtNode_GetOutputSize(node, &output_size); + graph_api_->OrtNode_GetNumOutputs(node, &output_size); for (size_t j = 0; j < output_size; j++) { const char* output_name = nullptr; graph_api_->OrtNode_GetIthOutputName(node, j, &output_name); @@ -1555,7 +1555,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const for (size_t j = 0; j < cnt; j++) { std::unordered_map input_map, output_map; size_t input_size = 0; - graph_api_->OrtNode_GetInputSize(node[j], &input_size); + graph_api_->OrtNode_GetNumInputs(node[j], &input_size); for (size_t i = 0; i < input_size; i++) { const char* ith_input_name = nullptr; graph_api_->OrtNode_GetIthInputName(node[j], i, &ith_input_name); @@ -1563,7 +1563,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } size_t output_size = 0; - graph_api_->OrtNode_GetOutputSize(node[j], &output_size); + graph_api_->OrtNode_GetNumOutputs(node[j], &output_size); for (size_t i = 0; i < output_size; i++) { const char* ith_output_name = nullptr; graph_api_->OrtNode_GetIthOutputName(node[j], i, &ith_output_name); @@ -1586,8 +1586,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const OrtExecutionProvider::CanCopy = [](const OrtDevice* source, const OrtDevice* target) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); OrtMemoryInfoDeviceType source_device_type, target_device_type; - api_->DeviceGetDeviceType(source, &source_device_type); - api_->DeviceGetDeviceType(target, &target_device_type); + api_->DeviceGetType(source, &source_device_type); + api_->DeviceGetType(target, &target_device_type); OrtMemoryType source_mem_type, target_mem_type; api_->DeviceGetMemoryType(source, &source_mem_type); api_->DeviceGetMemoryType(target, &target_mem_type); diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index bc08040686592..4ead7f8c087d5 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -330,7 +330,7 @@ HashValue TRTGenerateId(const OrtGraphViewer* 
graph_viewer) { const OrtNode* node = nullptr; graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes_index[index], &node); size_t output_size = 0; - graph_api->OrtNode_GetOutputSize(node, &output_size); + graph_api->OrtNode_GetNumOutputs(node, &output_size); for (size_t i = 0; i < output_size; ++i) { const char* output_name = nullptr; graph_api->OrtNode_GetIthOutputName(node, i, &output_name); From d280e594da23195c3f69d8e96cd898584a6e17e0 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:59:24 +0800 Subject: [PATCH 52/81] add documents for all functions in c_api_ep (#22502) --- .../core/session/onnxruntime_c_api_ep.h | 342 ++++++++++++++++++ 1 file changed, 342 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 448af6d95b0b5..5cb07dd9e5014 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -87,104 +87,446 @@ typedef struct OrtExecutionProviderFactory { } OrtExecutionProviderFactory; struct OrtGraphApi { +/** \brief Get the graph name + * + * \param[in] graph The graph to query + * \param[out] out The name of the graph + * + */ ORT_API2_STATUS(OrtGraph_GetName, const OrtGraphViewer* graph, _Outptr_ const char** out); +/** \brief Check if the name is a constant initializer of the graph + * + * \param[in] graph The graph to query + * \param[in] name The name to check + * \param[in] check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. + * \param[out] out True if the name is a constant initializer of the graph + * + */ ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* out); +/** \brief Get the NodeIndex values of the graph nodes sorted in topological order + * + * \param[in] graph The graph to query + * \param[in] execution_order The execution order can be 0, 1 or 2 + * 0 means the nodes are sorted in topological order. + * 1 means the nodes are sorted in topological order with priority. + * 2 means the nodes are sorted in memory efficient topological order. + * \param[out] nodes_index_in_topological_order The NodeIndex values of the graph nodes sorted in topological order + * \param[out] num_nodes The number of nodes + * + */ ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order, _Out_ size_t* num_nodes); +/** \brief Check if the graph is a subgraph + * + * \param[in] graph The graph to query + * \param[out] out True if the graph is a subgraph + * + */ ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out); +/** \brief Get the parent graph of the graph + * + * \param[in] graph The graph to query + * \param[out] parent_graph The parent graph of the graph + * + */ ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); +/** \brief Get the parent node of the graph + * + * \param[in] graph The graph to query + * \param[out] parent_node The node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. 
+ * + */ ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); +/** \brief Gets the path of the owning model if any + * + * \param[in] graph The graph to query + * \param[out] model_path The path of the owning model if any + * + */ ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path); +/** \brief Get the internal graph in the graph viewer + * + * \param[in] graph_viewer The graph viewer to query + * \param[out] graph The internal graph in the graph viewer + * + */ ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); +/** \brief Gets the Graph inputs including initializers, in the same order as defined in the GraphProto. + * + * \param[in] graph The graph to query + * \param[out] input_names The input names + * \param[out] input_len The number of inputs + * + */ ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); +/** \brief Get const Node given specific node index. May return nullptr if the node has been freed. + * + * \param[in] graph The graph to query + * \param[in] node_index The index of the node + * \param[out] node The node + * + */ ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); +/** \brief Get the consumer nodes of a node arg with the given name + * + * \param[in] graph The graph to query + * \param[in] input_name The name of the node arg + * \param[out] consumers The consumer nodes of the node arg + * \param[out] num_consumers The number of consumer nodes + * + */ ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); // TODO(leca): ValueConsumers::comprehensive ? +/** \brief Get the producer node of a node arg with the given name + * + * \param[in] graph The graph to query + * \param[in] output_name The name of the node arg + * \param[out] node The node producing the node arg + * + */ ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node); +/** \brief Gets the number of valid Nodes in the Graph. + * + * \param[in] graph The graph to query + * \param[out] num_nodes The number of valid nodes in the graph + * + */ ORT_API2_STATUS(OrtGraph_NumberOfNodes, const OrtGraphViewer* graph, _Out_ int* num_nodes); +/** \brief Gets the maximum NodeIndex value used in the Graph. + * + * \param[in] graph The graph to query + * \param[out] max_node_index The maximum NodeIndex value used by Nodes in the Graph + * + */ ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* max_node_index); +/** \brief Gets the number of outputs of the Graph. + * + * \param[in] graph The graph to query + * \param[out] output_len The number of outputs of the graph + * + */ ORT_API2_STATUS(OrtGraph_GetOutputSize, const OrtGraphViewer* graph, _Out_ size_t* output_len); +/** \brief Gets the name of the i-th output of the Graph. + * + * \param[in] graph The graph to query + * \param[in] i The index of the output + * \param[out] out The name of the i-th output of the graph + * + */ ORT_API2_STATUS(OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i, _Outptr_ const char** out); +/** \brief Gets the element type of the i-th output of the Graph.
+ * + * \param[in] graph The graph to query + * \param[in] i The index of the output + * \param[out] out The element type of the i-th output of the graph + * + */ ORT_API2_STATUS(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, _Out_ int32_t* out); +/** \brief Gets the initializer tensor of the Graph. + * + * \param[in] graph The graph to query + * \param[in] initializer_name The name of the initializer tensor + * \param[out] out The initializer tensor + * \param[out] ret True if the initializer tensor is found + * + */ ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**, _Out_ bool* ret); +/** \brief Gets the value info of the node arg with the given name. + * + * \param[in] graph The graph to query + * \param[in] name The name of the node arg + * \param[out] out The value info + * \param[out] ret True if the value info is found + * + */ ORT_API2_STATUS(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret); +/** \brief Serialize the Graph to a byte array. + * + * \param[in] graph The graph to serialize + * \param[out] data The byte array + * \param[out] data_size The size of the byte array + * + * \remarks The caller is responsible for freeing the byte array using OrtFreeMem. + * + */ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); // TODO(leca): review and discuss +/** \brief Construct a subgraph from the Graph with the given node indices. + * + * \param[in] graph The graph to query + * \param[in] node_num The number of node indices + * \param[in] node_indices The indices of the nodes to include in the subgraph + * \param[out] subgraph The constructed subgraph + * + * \remarks The caller is responsible for releasing the subgraph using OrtGraph_ReleaseGraph. + * + */ ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss +/** \brief Release the graph + * + * \param[in] graph The graph to release + * + */ ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); +/** \brief Gets the name of the node + * + * \param[in] node The node to query + * \param[out] out The name of the node + * + */ ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out); +/** \brief Gets the description of the node + * + * \param[in] node The node to query + * \param[out] out The description of the node + * + */ ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Outptr_ const char** out); +/** \brief Gets the domain of the node + * + * \param[in] node The node to query + * \param[out] out The domain of the node + * + */ ORT_API2_STATUS(OrtNode_GetDomain, const OrtNode* node, _Outptr_ const char** out); +/** \brief Gets the opset version that the Node's operator was first defined in. + * + * \param[in] node The node to query + * \param[out] out The since version of the node + * + */ ORT_API2_STATUS(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* out); +/** \brief Gets the execution ProviderType that this node will be executed by. + * + * \param[in] node The node to query + * \param[out] out The execution ProviderType of the node + * + */ ORT_API2_STATUS(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** out); +/** \brief Gets the Node's operator type. 
+ * + * \param[in] node The node to query + * \param[out] out The operator type of the node + * + */ ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Outptr_ const char** out); +/** \brief Gets the number of implicit inputs of the node. + * + * \param[in] node The node to query + * \param[out] out The number of implicit inputs of the node + * + */ ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* out); +/** \brief Gets the i-th implicit input name of the node. + * + * \param[in] node The node to query + * \param[in] i The index of the implicit input + * \param[out] out The i-th implicit input name of the node + * + */ ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); +/** \brief Gets the number of inputs of the node. + * + * \param[in] node The node to query + * \param[out] out The number of inputs of the node + * + */ ORT_API2_STATUS(OrtNode_GetNumInputs, const OrtNode* node, _Out_ size_t* out); +/** \brief Gets the i-th input name of the node. + * + * \param[in] node The node to query + * \param[in] i The index of the input + * \param[out] out The i-th input name of the node + * + */ ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Outptr_ const char** out); +/** \brief Gets the number of outputs of the node. + * + * \param[in] node The node to query + * \param[out] out The number of outputs of the node + * + */ ORT_API2_STATUS(OrtNode_GetNumOutputs, const OrtNode* node, _Out_ size_t* out); +/** \brief Gets the i-th output name of the node. + * + * \param[in] node The node to query + * \param[in] i The index of the output + * \param[out] out The i-th output name of the node + * + */ ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr_ const char** out); +/** \brief Gets the Node's NodeIndex. + * + * \param[in] node The node to query + * \param[out] out The Node's NodeIndex + * + */ ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* out); +/** \brief Gets attribute names of the node. + * + * \param[in] node The node to query + * \param[out] names The attribute names of the node + * \param[out] num The number of attribute names + * + */ ORT_API2_STATUS(OrtNode_GetAttributeNames, const OrtNode* node, _Out_ const char*** names, _Out_ size_t* num); +/** \brief Gets the attribute size of the node. + * + * \param[in] node The node to query + * \param[out] out The attribute size of the node + * + */ ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* out); +/** \brief Gets the attribute type of the node. + * + * \param[in] node The node to query + * \param[in] attribute The attribute name + * \param[out] out The attribute type of the node + * + */ ORT_API2_STATUS(OrtNode_GetAttributeType, const OrtNode* node, const char* attribute, _Out_ int* out); // AttributeProto_AttributeType +/** \brief Check if the attribute key exists in the node. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out 1 if the attribute key exists in the node, 0 otherwise + * + */ ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* out); +/** \brief Gets how many ints are in the attribute with the given key. 
+ * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The number of ints in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* out); +/** \brief Gets how many floats are in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The number of floats in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* out); +/** \brief Gets how many strings are in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The number of strings in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* out); +/** \brief Gets the i-th int in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[in] i The index of the int + * \param[out] out The i-th int in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* out); +/** \brief Gets the i-th float in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[in] i The index of the float + * \param[out] out The i-th float in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* out); +/** \brief Gets the i-th string in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[in] i The index of the string + * \param[out] out The i-th string in the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out); +/** \brief Gets the string value of the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The string value of the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out); +/** \brief Gets the int value of the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The int value of the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out); +/** \brief Gets the float value of the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The float value of the attribute + * + */ ORT_API2_STATUS(OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out); +/** \brief Gets the subgraphs of the given node. 
+ * + * \param[in] node The node to query + * \param[out] subgraphs The subgraphs of the node + * \param[out] num_subgraphs The number of subgraphs + * + */ ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs, _Out_ size_t* num_subgraphs); +/** \brief Free the memory + * + * \param[in] p The memory to free + * + */ ORT_API2_STATUS(OrtFreeMem, void* p); }; typedef struct OrtGraphApi OrtGraphApi; From cbe98e7971e9bcb2b237e3c5637d5c89952480dd Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Sat, 19 Oct 2024 00:43:23 +0000 Subject: [PATCH 53/81] fix comments --- .../core/session/onnxruntime_c_api.h | 99 +++++++++++++++++-- .../core/session/onnxruntime_c_api_ep.h | 28 +++++- .../core/session/onnxruntime_c_api_ep.cc | 36 +++++-- onnxruntime/core/session/ort_apis_ep.h | 6 +- samples/tensorRTEp/CMakeLists.txt | 7 +- .../tensorrt_execution_provider_utils.h | 2 +- 6 files changed, 156 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9c0f55dc7c0a9..d317ca02ac407 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4672,29 +4672,114 @@ struct OrtApi { _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + /** \brief Create OrtDevice object. + * + * \param[in] device_type + * \param[in] memory_type + * \param[in] device_id + * \param[out] out OrtDevice object + * + * \since Version 1.xx. + */ ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); + /** \brief Get OrtMemoryInfoDeviceType property from OrtDevice object. + * + * \param[in] device OrtDevice object + * \param[out] out OrtMemoryInfoDeviceType property + * + * \since Version 1.xx. + */ ORT_API2_STATUS(DeviceGetType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); + /** \brief Get OrtMemoryType property from OrtDevice object. + * + * \param[in] device OrtDevice object + * \param[out] out OrtMemoryType property + * + * \since Version 1.xx. + */ ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); + /** \brief Get device id property from OrtDevice object. + * + * \param[in] device OrtDevice object + * \param[out] out device id property + * + * \since Version 1.xx. + */ ORT_API2_STATUS(DeviceGetId, _In_ const OrtDevice* device, _Out_ int16_t* out); + /** \brief Release OrtDevice object. + * + * \since Version 1.xx. + */ ORT_CLASS_RELEASE(Device); + /** \brief Register the plugin ExecutionProvider library + * + * The plugin ExecutionProvider library will be loaded and EP factory object will be created and saved in OrtEnv object + * + * \param[in] lib_path the path of the plugin ExecutionProvider library + * \param[in] env OrtEnv object + * \param[in] ep_name the plugin ExecutionProvider name + * + * \since Version 1.xx. 
+ */
 ORT_API2_STATUS(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);

+ /** \brief Append the plugin ExecutionProvider factory to the session options, together with its provider options
+ *
+ * \param[in] options OrtSessionOptions object
+ * \param[in] ep_name the plugin ExecutionProvider name
+ * \param[in] env OrtEnv object
+ * \param[in] provider_options_keys provider options' keys
+ * \param[in] provider_options_values provider options' values
+ * \param[in] num_keys the number of the provider options' key-value pairs
+ *
+ * \since Version 1.xx.
+ */
 ORT_API2_STATUS(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env,
                 _In_reads_(num_keys) const char* const* provider_options_keys,
                 _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

+ /** \brief Create OrtTypeConstraints object
+ *
+ * \param[out] type_constraints The created OrtTypeConstraints object
+ *
+ * \since Version 1.xx.
+ */
 ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints);

+ /** \brief Add a specific type constraint into an OrtTypeConstraints object
+ *
+ * \param[in] type_constraints OrtTypeConstraints object
+ * \param[in] type_symbol symbol string to represent a specific type
+ * \param[in] type a specific type
+ *
+ * \since Version 1.xx.
+ */
 ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);

+ /** \brief Release OrtTypeConstraints object.
+ *
+ * \since Version 1.xx.
+ */
 ORT_CLASS_RELEASE(TypeConstraints);

- ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
+ /** \brief Create KernelCreateInfo with custom op and type constraints, and register it
+ *
+ * \param[in] kernel_registry Opaque pointer to the KernelRegistry object
+ * \param[in] custom_op Custom Op where the kernel compute function is defined
+ * \param[in] type_constraints The type constraints of the kernel to register
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, _In_ OrtKernelRegistry* kernel_registry, _In_ OrtCustomOp* custom_op, _In_ OrtTypeConstraints* type_constraints);

+ /** \brief Get Graph API
+ *
+ * \since Version 1.xx.
+ */
 const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION;
};  // struct OrtApi

@@ -4725,11 +4810,13 @@ typedef enum OrtCustomOpInputOutputCharacteristic {
 */
struct OrtCustomOp {
#ifdef __cplusplus
-  // TODO(leca): initialize all member function pointers to nullptr?
-  OrtCustomOp() : CreateKernel{nullptr}, KernelCompute{nullptr}, KernelDestroy{nullptr}, GetInputCharacteristic{nullptr},
-                  GetOutputCharacteristic{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicOutputMinArity{nullptr},
-                  GetStartVersion{nullptr}, GetEndVersion{nullptr}, GetMayInplace{nullptr}, ReleaseMayInplace{nullptr},
-                  GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {}
+  OrtCustomOp() : CreateKernel{nullptr}, GetName{nullptr}, GetExecutionProviderType{nullptr}, GetInputType{nullptr},
+                  GetInputTypeCount{nullptr}, GetOutputType{nullptr}, GetOutputTypeCount{nullptr}, KernelCompute{nullptr},
+                  KernelDestroy{nullptr}, GetInputCharacteristic{nullptr}, GetOutputCharacteristic{nullptr},
+                  GetInputMemoryType{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicInputHomogeneity{nullptr},
+                  GetVariadicOutputMinArity{nullptr}, GetVariadicOutputHomogeneity{nullptr}, CreateKernelV2{nullptr},
+                  KernelComputeV2{nullptr}, InferOutputShapeFn{nullptr}, GetStartVersion{nullptr}, GetEndVersion{nullptr},
+                  GetMayInplace{nullptr}, ReleaseMayInplace{nullptr}, GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {}
#endif
 uint32_t version;  // Must be initialized to ORT_API_VERSION

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
index 5cb07dd9e5014..b0662ce435b46 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
@@ -158,14 +158,32 @@ ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ con
 */
ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph);

-/** \brief Gets the Graph inputs including initializers, in the same order as defined in the GraphProto.
+/** \brief Gets the Graph inputs with no matching initializers, in the same order as defined in the GraphProto.
 *
 * \param[in] graph The graph to query
 * \param[out] input_names The input names
 * \param[out] input_len The number of inputs
 *
 */
-ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len);
+ORT_API2_STATUS(OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len);
+
+/** \brief Gets the Graph inputs with matching initializers, in the same order as defined in the GraphProto.
+ *
+ * \param[in] graph The graph to query
+ * \param[out] input_names The input names
+ * \param[out] input_len The number of inputs
+ *
+ */
+ORT_API2_STATUS(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len);
+
+/** \brief Gets all the Graph initializers' names
+ *
+ * \param[in] graph The graph to query
+ * \param[out] initializer_names The initializer names
+ * \param[out] initializer_len The number of initializers
+ *
+ */
+ORT_API2_STATUS(OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** initializer_names, _Out_ size_t* initializer_len);

 /** \brief Get const Node given specific node index. May return nullptr if the node has been freed.
 *
 * \param[in] graph The graph to query
@@ -280,7 +298,11 @@ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ vo
 */
ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph);  // TODO(yang): review and discuss

-/** \brief Release the graph
+/** \brief Release the graph.
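+ *
+ * A minimal sketch of the expected pairing (assuming `graph_api` is the
+ * OrtGraphApi pointer obtained via GetGraphApi; error handling omitted):
+ * \code
+ *   const OrtGraphViewer* subgraph = nullptr;
+ *   graph_api->OrtGraph_GetSubGraph(graph, node_num, node_indices, &subgraph);
+ *   // ... use subgraph ...
+ *   graph_api->OrtGraph_ReleaseGraph(subgraph);
+ * \endcode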
+ *
+ * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. Since OrtGraph_GetSubGraph allocates a Model
+ * rather than a Graph, this API explicitly releases the graph's owning_model, which in turn releases the graph
+ * (because the graph is hosted in a unique_ptr in the Model class)
 *
 * \param[in] graph The graph to release
 *
diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc
index d2afefc22d2a7..067bb9a9703f9 100644
--- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc
@@ -61,15 +61,34 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* gr
 return nullptr;
}

-ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len) {
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len) {
+ const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+ const auto& inputs = graph_viewer->GetInputs();
+ *input_len = inputs.size();
+ *input_names = new const char*[*input_len];  // TODO(leca): release
+ for (size_t i = 0; i < *input_len; i++) (*input_names)[i] = inputs[i]->Name().c_str();
+ return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len) {
 const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
 const auto& inputs = graph_viewer->GetInputsIncludingInitializers();
 *input_len = inputs.size();
- *input_names = new const char*[*input_len];
+ *input_names = new const char*[*input_len];  // TODO(leca): release
 for (size_t i = 0; i < *input_len; i++) (*input_names)[i] = inputs[i]->Name().c_str();
 return nullptr;
}

+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** initializer_names, _Out_ size_t* initializer_len) {
+ const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+ const auto& initializers = graph_viewer->GetAllInitializedTensors();
+ *initializer_len = initializers.size();
+ *initializer_names = new const char*[*initializer_len];  // TODO(leca): release
+ int i = 0;
+ for (const auto& [key, value] : initializers) (*initializer_names)[i++] = key.c_str();
+ return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node) {
 const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
 *node = reinterpret_cast<const OrtNode*>(graph_viewer->GetNode(node_index));
@@ -129,7 +148,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphV
 *ret = false;
 return nullptr;
 }
- *out = new OrtTensorRef();  // TODO(leca): release
+ *out = new OrtTensorRef();  // TODO(leca): 1. release, 2. other datatypes in the following switch
 (*out)->shape_len = initializer->dims_size();
 (*out)->shape = new int64_t [initializer->dims_size()];
 for (size_t i = 0; i < (*out)->shape_len; i++) {
@@ -182,7 +201,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewe
 onnx::ModelProto model_proto = model.ToProto();
 GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED);
 *data_size = model_proto.ByteSizeLong();
- *data = malloc(*data_size);  // TODO(leca): release
+ *data = malloc(*data_size);
 model_proto.SerializeToArray(*data, *data_size);
 return nullptr;
}
@@ -432,7 +451,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr
 for (const auto* output_arg : graph_viewer->GetOutputs()) {
 graph_output_names.insert(output_arg->Name());
 }
- // TODO(leca): cannot use unique_ptr here, otherwise when this function exits, sub_graph_viewer->graph_->graph_proto_, which is from model_build->model_proto_, will be nullptr.
+ // NOTE!!: cannot use unique_ptr here, otherwise when this function exits, sub_graph_viewer->graph_->graph_proto_, which is from model_build->model_proto_, will be nullptr.
 // Pay special attention when the Graph object is being released. We need to release model_build separately then.
 Model* model_build = new Model (graph_viewer->Name(), true, ModelMetaData(), PathString(),
#if !defined(ORT_MINIMAL_BUILD)
@@ -511,7 +530,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr
 // TODO(yang):
 // Only if the newly built graph has a control flow op and also has a parent node does it need to
 // handle outer scope values before calling graph.Resolve().
- // TODO(leca): Is local variable enough? Do we need to make it EP class variable?
std::unordered_map> subgraph_context_map; if (has_control_flow_op && graph_viewer->ParentNode()) { // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); @@ -547,7 +565,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraphViewer* graph) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - delete graph_viewer; + delete &(graph_viewer->GetGraph()).GetModel(); } return nullptr; } @@ -747,7 +765,9 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetParenNode, &OrtGraphApis::OrtGraph_GetModelPath, &OrtGraphApis::OrtGraph_GetOrtGraph, - &OrtGraphApis::OrtGraph_GetInputsIncludingInitializers, + &OrtGraphApis::OrtGraph_GetRequiredInputs, + &OrtGraphApis::OrtGraph_GetAllInputs, + &OrtGraphApis::OrtGraph_GetAllInitializers, &OrtGraphApis::OrtGraph_GetOrtNode, &OrtGraphApis::OrtGraph_GetNodesConsumingInput, &OrtGraphApis::OrtGraph_GetNodeProducingOutput, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index a36efeb1ee225..4034c4fe96ae9 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -19,7 +19,11 @@ ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); -ORT_API_STATUS_IMPL(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); +ORT_API_STATUS_IMPL(OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); + +ORT_API_STATUS_IMPL(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); + +ORT_API_STATUS_IMPL(OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index f9711bd77537f..15b5f65acf2a7 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -1,6 +1,6 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/leca/TensorRT-10.4.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/leca/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") # cmake --build ./ cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) @@ -24,7 +24,7 @@ target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "../../build/tensorrt/Debug/_deps/protobuf-src/src") ## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol -target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" +target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" 
${TENSORRT_HOME}/lib/libnvinfer.so
                                      ${TENSORRT_HOME}/lib/libnvinfer_plugin.so
                                      ${TENSORRT_HOME}/lib/libnvonnxparser.so
@@ -33,4 +33,5 @@ target_link_libraries(TensorRTEp PUBLIC #"/home/leca/code/onnxruntime/build/tens
                                      "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx.a"
                                      "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx_proto.a"
                                      "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotobufd.a"
-                                      "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a")
+                                      "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a"
+                                      )
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h
index 4ead7f8c087d5..ace9d73dd5e36 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h
+++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h
@@ -313,7 +313,7 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) {
 // const std::vector& input_names = nullptr;
 const char** input_names = nullptr;
 size_t input_count = 0;
- graph_api->OrtGraph_GetInputsIncludingInitializers(graph_viewer, &input_names, &input_count);
+ graph_api->OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_count);
 for (size_t i = 0; i < input_count; ++i) {
 hash_str(input_names[i]);
 }

From 1529059972bb37de2d91d9431f5ea6a7861c7e50 Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Mon, 21 Oct 2024 21:25:48 +0800
Subject: [PATCH 54/81] fix memory leak (#22522)

---
 .../core/session/onnxruntime_c_api_ep.h       | 37 ++++++++++++++++
 .../core/session/onnxruntime_c_api_ep.cc      | 43 ++++++++++++++++---
 onnxruntime/core/session/ort_apis_ep.h        |  6 +++
 3 files changed, 81 insertions(+), 5 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
index b0662ce435b46..4c3a9a03611ae 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
@@ -159,6 +159,8 @@ ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ con
ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph);

 /** \brief Gets the Graph inputs with no matching initializers, in the same order as defined in the GraphProto.
+ *
+ * NOTE!!: The caller is responsible for releasing the char array using ReleaseCharArray.
 *
 * \param[in] graph The graph to query
 * \param[out] input_names The input names
@@ -168,6 +170,8 @@ ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outpt
ORT_API2_STATUS(OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len);

 /** \brief Gets the Graph inputs with matching initializers, in the same order as defined in the GraphProto.
+ *
+ * NOTE!!: The caller is responsible for releasing the char array using ReleaseCharArray.
 *
 * \param[in] graph The graph to query
 * \param[out] input_names The input names
@@ -177,6 +181,8 @@ ORT_API2_STATUS(OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr
ORT_API2_STATUS(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len);

 /** \brief Gets all the Graph initializers' names
+ *
+ * NOTE!!: The caller is responsible for releasing the char array using ReleaseCharArray.
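+ *
+ * A sketch of the get/release pattern this NOTE describes (assuming `graph_api`
+ * is the OrtGraphApi pointer; error handling omitted):
+ * \code
+ *   const char** names = nullptr;
+ *   size_t len = 0;
+ *   graph_api->OrtGraph_GetAllInitializers(graph, &names, &len);
+ *   for (size_t i = 0; i < len; i++) {
+ *     // use names[i]
+ *   }
+ *   graph_api->ReleaseCharArray(names);
+ * \endcode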
 *
 * \param[in] graph The graph to query
 * \param[out] initializer_names The initializer names
@@ -185,6 +191,15 @@ ORT_API2_STATUS(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ con
 */
ORT_API2_STATUS(OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** initializer_names, _Out_ size_t* initializer_len);

+/** \brief Release the char array
+ *
+ * NOTE!!: Invoke this function after the use of OrtGraph_GetRequiredInputs, OrtGraph_GetAllInputs, OrtGraph_GetAllInitializers.
+ *
+ * \param[in] char_array The char array to release
+ *
+ */
+ORT_API2_STATUS(ReleaseCharArray, const char** char_array);
+
 /** \brief Get const Node given specific node index. May return nullptr if the node has been freed.
 *
 * \param[in] graph The graph to query
@@ -256,6 +271,8 @@ ORT_API2_STATUS(OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size_t i
ORT_API2_STATUS(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, _Out_ int32_t* out);

 /** \brief Gets the initializer tensor of the Graph.
+ *
+ * NOTE!!: The caller is responsible for releasing the initializer tensor using OrtGraph_ReleaseInitializerTensor.
 *
 * \param[in] graph The graph to query
 * \param[in] initializer_name The name of the initializer tensor
@@ -265,7 +282,18 @@ ORT_API2_STATUS(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i,
 */
ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**, _Out_ bool* ret);

+/** \brief Release the initializer tensor.
+ *
+ * NOTE!!: Invoke this function after the use of OrtGraph_GetInitializerTensor.
+ *
+ * \param[in] tensor The initializer tensor to release
+ *
+ */
+ORT_API2_STATUS(OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor);
+
 /** \brief Gets the value info of the node arg with the given name.
+ *
+ * NOTE!!: The caller is responsible for releasing the value info using OrtGraph_ReleaseValueInfo.
 *
 * \param[in] graph The graph to query
 * \param[in] name The name of the node arg
@@ -275,6 +303,15 @@ ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, cons
 */
ORT_API2_STATUS(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret);

+/** \brief Release the value info.
+ *
+ * NOTE!!: Invoke this function after the use of OrtGraph_GetValueInfo.
+ *
+ * \param[in] value_info The value info to release
+ *
+ */
+ORT_API2_STATUS(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info);
+
 /** \brief Serialize the Graph to a byte array.
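 *
 * A usage sketch (the implementation below allocates the buffer with malloc, so
 * the caller presumably frees it with free(); assuming `graph_api` is the
 * OrtGraphApi pointer):
 * \code
 *   void* data = nullptr;
 *   size_t data_size = 0;
 *   graph_api->OrtGraph_SerializeToArray(graph, &data, &data_size);
 *   // ... consume the serialized ModelProto bytes ...
 *   free(data);
 * \endcode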
* * \param[in] graph The graph to serialize diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 067bb9a9703f9..2d2fbc1d2c266 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -65,7 +65,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetRequiredInputs, const OrtGraphView const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const auto& inputs = graph_viewer->GetInputs(); *input_len = inputs.size(); - *input_names = new const char*[*input_len]; // TODO(leca): release + *input_names = new const char*[*input_len]; // Should be released by the caller using OrtGraphApis::ReleaseCharArray for (size_t i = 0; i < *input_len; i++) (*input_names)[i] = inputs[i]->Name().c_str(); return nullptr; } @@ -74,7 +74,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetAllInputs, const OrtGraphViewer* g const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const auto& inputs = graph_viewer->GetInputsIncludingInitializers(); *input_len = inputs.size(); - *input_names = new const char*[*input_len]; // TODO(leca): release + *input_names = new const char*[*input_len]; // Should be released by the caller using OrtGraphApis::ReleaseCharArray for (size_t i = 0; i < *input_len; i++) (*input_names)[i] = inputs[i]->Name().c_str(); return nullptr; } @@ -83,12 +83,20 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetAllInitializers, const OrtGraphVie const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const auto& initializers = graph_viewer->GetAllInitializedTensors(); *initializer_len = initializers.size(); - *initializer_names = new const char*[*initializer_len]; // TODO(leca): release + *initializer_names = new const char*[*initializer_len]; // Should be released by the caller using OrtGraphApis::ReleaseCharArray int i = 0; for (const auto& [key, value] : initializers) (*initializer_names)[i++] = key.c_str(); return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::ReleaseCharArray, const char** char_array) { + if (!char_array) { + return nullptr; + } + delete[] char_array; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); *node = reinterpret_cast(graph_viewer->GetNode(node_index)); @@ -148,7 +156,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphV *ret = false; return nullptr; } - *out = new OrtTensorRef(); // TODO(leca): 1. release, 2. 
other datatypes in the following switch + *out = new OrtTensorRef(); // TODO(leca): other datatypes in the following switch (*out)->shape_len = initializer->dims_size(); (*out)->shape = new int64_t [initializer->dims_size()]; for (size_t i = 0; i < (*out)->shape_len; i++) { @@ -167,6 +175,17 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphV return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor) { + if (!tensor) { + return nullptr; + } + if (tensor->shape) { + delete[] tensor->shape; + } + delete tensor; + return nullptr; +} + static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* type) { // onnxruntime\core\optimizer\transpose_optimization\ort_optimizer_api_impl.cc if (!type || !utils::HasTensorType(*type) || !utils::HasElementType(*type)) return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; @@ -177,7 +196,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* g const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const NodeArg* node_arg = graph_viewer->GetNodeArg(name); - *out = new OrtValueInfoRef(); // TODO(leca): release + *out = new OrtValueInfoRef(); const onnx::TypeProto* type = node_arg->TypeAsProto(); (*out)->data_type = GetDataTypeFromTypeProto(type); const auto& dims = utils::TryGetShape(*type)->dim(); @@ -189,6 +208,17 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* g return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info) { + if (!value_info) { + return nullptr; + } + if (value_info->shape) { + delete[] value_info->shape; + } + delete value_info; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); Model model(graph_viewer->Name(), true, ModelMetaData(), PathString(), @@ -768,6 +798,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetRequiredInputs, &OrtGraphApis::OrtGraph_GetAllInputs, &OrtGraphApis::OrtGraph_GetAllInitializers, + &OrtGraphApis::ReleaseCharArray, &OrtGraphApis::OrtGraph_GetOrtNode, &OrtGraphApis::OrtGraph_GetNodesConsumingInput, &OrtGraphApis::OrtGraph_GetNodeProducingOutput, @@ -777,7 +808,9 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetIthOutputName, &OrtGraphApis::OrtGraph_GetIthOutputElemType, &OrtGraphApis::OrtGraph_GetInitializerTensor, + &OrtGraphApis::OrtGraph_ReleaseInitializerTensor, &OrtGraphApis::OrtGraph_GetValueInfo, + &OrtGraphApis::OrtGraph_ReleaseValueInfo, &OrtGraphApis::OrtGraph_SerializeToArray, &OrtGraphApis::OrtGraph_GetSubGraph, &OrtGraphApis::OrtGraph_ReleaseGraph, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index 4034c4fe96ae9..196cab41264ca 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -25,6 +25,8 @@ ORT_API_STATUS_IMPL(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ ORT_API_STATUS_IMPL(OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); +ORT_API_STATUS_IMPL(ReleaseCharArray, const char** char_array); + ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); 
ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); @@ -43,8 +45,12 @@ ORT_API_STATUS_IMPL(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t ORT_API_STATUS_IMPL(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** tensor, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor); + ORT_API_STATUS_IMPL(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); + ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); From fa549f897baed7de5221bcc4ab8c5bb521d69d0c Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Thu, 24 Oct 2024 20:05:18 +0800 Subject: [PATCH 55/81] add mutex to plugin trt ep (#22581) --- samples/tensorRTEp/CMakeLists.txt | 4 +++- .../tensorRTEp/tensorrt_execution_provider.cc | 19 ++++++++++++------- .../tensorRTEp/tensorrt_execution_provider.h | 14 +++++++++++--- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index 15b5f65acf2a7..8fa7aefd14f47 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -21,7 +21,8 @@ target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "../../build/tensorrt/Debug/_deps/gsl-src/include" "../../build/tensorrt/Debug/_deps/onnx-src" "../../build/tensorrt/Debug/_deps/onnx-build" - "../../build/tensorrt/Debug/_deps/protobuf-src/src") + "../../build/tensorrt/Debug/_deps/protobuf-src/src" + "../../build/tensorrt/Debug/_deps/google_nsync-src/public") ## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" @@ -34,4 +35,5 @@ target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tenso "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx_proto.a" "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotobufd.a" "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a" + "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/google_nsync-build/libnsync_cpp.a" ) diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 691d863b8c531..953512234cf2c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -366,6 +366,11 @@ TensorrtLogger& GetTensorrtLogger(bool verbose_log) { return trt_logger; } +std::unique_lock TensorrtExecutionProvider::GetApiLock() const { + static OrtMutex singleton; + return std::unique_lock(singleton); +} + template void GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, void* shape_values, @@ -2088,7 +2093,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const } { - // auto lock = GetApiLock(); // TODO(leca) + auto lock = GetApiLock(); runtime_ = 
std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
 }
}
@@ -2105,7 +2110,7 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() {
nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const {
 if (!builder_) {
 {
- // auto lock = GetApiLock(); // TODO(leca)
+ auto lock = GetApiLock();
 builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
 }
 }
@@ -2525,7 +2530,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
 }
 {
 // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race conditions when inferencing with multiple threads.
- // auto lock = GetApiLock(); // TODO(leca)
+ auto lock = GetApiLock();

 // If explicit profile flag is on and engine cache enable flag is on,
 // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine.
@@ -2786,7 +2791,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
 *p = {context->AllocateFunc, context->DestroyFunc, context->allocator_handle, context->node_name, this_->builder_.get(),
 &(this_->parsers_[context->node_name]), &(this_->engines_[context->node_name]), &(this_->contexts_[context->node_name]),
 &(this_->networks_[context->node_name]), this_->input_info_[context->node_name], this_->output_info_[context->node_name],
- this_->input_shape_ranges_[context->node_name], /*&tensorrt_mu_,*/ this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_,
+ this_->input_shape_ranges_[context->node_name], &this_->tensorrt_mu_, this_->fp16_enable_, this_->int8_enable_, this_->int8_calibration_cache_available_,
 this_->dla_enable_, this_->dla_core_, &(this_->max_workspace_size_), this_->trt_node_name_with_precision_[context->node_name],
 this_->engine_cache_enable_, this_->cache_path_, this_->runtime_.get(), this_->profiles_[context->node_name],
 this_->context_memory_sharing_enable_, &(this_->max_ctx_mem_size_), this_->dynamic_range_map_[context->node_name], this_->engine_decryption_enable_,
@@ -2811,7 +2816,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort
 // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
 // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
// More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - //std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); // TODO(leca) + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; @@ -3068,7 +3073,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // Build engine std::unique_ptr serialized_engine; { - //auto lock = GetApiLock(); // TODO(leca) + auto lock = this_->GetApiLock(); std::chrono::steady_clock::time_point engine_build_start; if (this_->detailed_build_log_) { engine_build_start = std::chrono::steady_clock::now(); @@ -3467,7 +3472,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // The whole compute_function should be considered the critical section. // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading -//TODO(leca): std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 255a8d411e014..4da00f4724116 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -4,6 +4,7 @@ #include #include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/provider_options.h" +#include "core/platform/ort_mutex.h" #include "tensorrt_execution_provider_info.h" #include "nv_includes.h" @@ -161,7 +162,7 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; -// OrtMutex* tensorrt_mu_ptr = nullptr; + OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; @@ -207,7 +208,7 @@ struct TensorrtShortFuncState { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; -// OrtMutex* tensorrt_mu_ptr = nullptr; + OrtMutex* tensorrt_mu_ptr = nullptr; }; using DDSOutputAllocatorMap = std::unordered_map>; @@ -231,6 +232,13 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const OrtGraphViewer* graph, const HashValue& model_hash, bool remove_cycles = true) const; + /** + Get a unique_lock object to control the concurrency behavior. + Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) + should be protected by a lock when invoked by multiple threads concurrently. 
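+
+  A minimal sketch of the intended use (mirrors the call sites in this patch;
+  the lock is released when it leaves scope):
+  \code
+    {
+      auto lock = GetApiLock();
+      // call a non-thread-safe TensorRT API here, e.g. nvinfer1::createInferBuilder(...)
+    }
+  \endcode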
+ */ + std::unique_lock GetApiLock() const; + /**Check the graph is the subgraph of control flow op*/ bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; @@ -272,7 +280,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::string tactic_sources_; std::string global_cache_path_, cache_path_, engine_decryption_lib_path_; std::unique_ptr runtime_ = nullptr; -// OrtMutex tensorrt_mu_; + OrtMutex tensorrt_mu_; int device_id_; std::string compute_capability_; bool context_memory_sharing_enable_ = false; From a28ad3803e9ef75cc23ebda5264824005851a989 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 24 Oct 2024 17:50:42 +0000 Subject: [PATCH 56/81] use std::mutex instead of OrtMutex and fix build error in Windows --- cmake/onnxruntime_unittests.cmake | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- onnxruntime/core/session/onnxruntime_c_api_ep.cc | 10 +++++----- samples/tensorRTEp/CMakeLists.txt | 3 +-- samples/tensorRTEp/tensorrt_execution_provider.cc | 10 +++++----- samples/tensorRTEp/tensorrt_execution_provider.h | 10 +++++----- 6 files changed, 18 insertions(+), 19 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0159c35d1941b..d3acd1718dd87 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1227,7 +1227,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal) - endif() + endif() target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) if(WIN32) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index e0f54de387784..c1a0b726c5d7e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2388,7 +2388,7 @@ ORT_API(void, OrtApis::ReleaseDevice, OrtDevice* device) { delete device; } -ORT_API_STATUS_IMPL(OrtApis::RegisterPluginExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) { +ORT_API_STATUS_IMPL(OrtApis::RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) { API_IMPL_BEGIN void* handle = nullptr; ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(ToPathString(lib_path), false, &handle)); diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 2d2fbc1d2c266..c83d5883aac61 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -160,7 +160,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphV (*out)->shape_len = initializer->dims_size(); (*out)->shape = new int64_t [initializer->dims_size()]; for (size_t i = 0; i < (*out)->shape_len; i++) { - ((*out)->shape)[i] = initializer->dims(i); + ((*out)->shape)[i] = initializer->dims(static_cast(i)); } (*out)->data_type = static_cast(initializer->data_type()); @@ -202,7 +202,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* g 
const auto& dims = utils::TryGetShape(*type)->dim(); (*out)->shape_len = dims.size(); (*out)->shape = new int64_t [(*out)->shape_len]; - for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[i]) ? dims[i].dim_value() : -1; + for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[static_cast(i)]) ? dims[static_cast(i)].dim_value() : -1; *ret = true; return nullptr; @@ -232,7 +232,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_SerializeToArray, const OrtGraphViewe GraphViewerToProto(*graph_viewer, *model_proto.mutable_graph(), true, true, ExecutionOrder::PRIORITY_BASED); *data_size = model_proto.ByteSizeLong(); *data = malloc(*data_size); - model_proto.SerializeToArray(*data, *data_size); + model_proto.SerializeToArray(*data, static_cast(*data_size)); return nullptr; } @@ -544,8 +544,8 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr NodeAttributes node_attributes; node_attributes.reserve(num_attributes); - for (int i = 0; i < num_attributes; ++i) { - auto& attr = node_proto->attribute(i); + for (int ii = 0; ii < num_attributes; ++ii) { + auto& attr = node_proto->attribute(ii); node_attributes.emplace(attr.name(), attr); } diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index 8fa7aefd14f47..641d05d1b1ad2 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -22,7 +22,7 @@ target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "../../build/tensorrt/Debug/_deps/onnx-src" "../../build/tensorrt/Debug/_deps/onnx-build" "../../build/tensorrt/Debug/_deps/protobuf-src/src" - "../../build/tensorrt/Debug/_deps/google_nsync-src/public") +) ## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" @@ -35,5 +35,4 @@ target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tenso "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx_proto.a" "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotobufd.a" "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a" - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/google_nsync-build/libnsync_cpp.a" ) diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 953512234cf2c..2e61d1cb1b92e 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -366,9 +366,9 @@ TensorrtLogger& GetTensorrtLogger(bool verbose_log) { return trt_logger; } -std::unique_lock TensorrtExecutionProvider::GetApiLock() const { - static OrtMutex singleton; - return std::unique_lock(singleton); +std::unique_lock TensorrtExecutionProvider::GetApiLock() const { + static std::mutex singleton; + return std::unique_lock(singleton); } template @@ -2816,7 +2816,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. 
// More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; @@ -3472,7 +3472,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi // The whole compute_function should be considered the critical section. // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 4da00f4724116..9ba05e951615b 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -2,9 +2,9 @@ #include #include #include +#include #include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/provider_options.h" -#include "core/platform/ort_mutex.h" #include "tensorrt_execution_provider_info.h" #include "nv_includes.h" @@ -162,7 +162,7 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; - OrtMutex* tensorrt_mu_ptr = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; @@ -208,7 +208,7 @@ struct TensorrtShortFuncState { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; - OrtMutex* tensorrt_mu_ptr = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; }; using DDSOutputAllocatorMap = std::unordered_map>; @@ -237,7 +237,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) should be protected by a lock when invoked by multiple threads concurrently. 
*/ - std::unique_lock GetApiLock() const; + std::unique_lock GetApiLock() const; /**Check the graph is the subgraph of control flow op*/ bool IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const; @@ -280,7 +280,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::string tactic_sources_; std::string global_cache_path_, cache_path_, engine_decryption_lib_path_; std::unique_ptr runtime_ = nullptr; - OrtMutex tensorrt_mu_; + std::mutex tensorrt_mu_; int device_id_; std::string compute_capability_; bool context_memory_sharing_enable_ = false; From aa498055a9f84541f577dbbec37a76f7c041570e Mon Sep 17 00:00:00 2001 From: jslhcl Date: Fri, 25 Oct 2024 17:07:30 -0700 Subject: [PATCH 57/81] openvino --- .../core/session/onnxruntime_c_api_ep.h | 10 +- .../core/providers/openvino/ov_interface.cc | 2 +- .../core/session/onnxruntime_c_api_ep.cc | 9 +- onnxruntime/core/session/ort_apis_ep.h | 4 +- samples/openvino/CMakeLists.txt | 19 + samples/openvino/backend_manager.cc | 476 ++++++++++++++++++ samples/openvino/backend_manager.h | 61 +++ samples/openvino/backend_utils.cc | 272 ++++++++++ samples/openvino/backend_utils.h | 75 +++ samples/openvino/contexts.h | 55 ++ samples/openvino/ibackend.h | 29 ++ samples/openvino/onnx_ctx_model_helper.cc | 138 +++++ samples/openvino/onnx_ctx_model_helper.h | 37 ++ .../openvino/openvino_execution_provider.cc | 87 ++++ .../openvino/openvino_execution_provider.h | 175 +++++++ samples/openvino/openvino_utils.cc | 25 + samples/openvino/openvino_utils.h | 12 + samples/openvino/ov_interface.cc | 253 ++++++++++ samples/openvino/ov_interface.h | 99 ++++ samples/openvino/ov_versions/capability.cc | 226 +++++++++ samples/openvino/ov_versions/capability.h | 32 ++ .../tensorrt_execution_provider_utils.h | 2 +- 22 files changed, 2084 insertions(+), 14 deletions(-) create mode 100644 samples/openvino/CMakeLists.txt create mode 100644 samples/openvino/backend_manager.cc create mode 100644 samples/openvino/backend_manager.h create mode 100644 samples/openvino/backend_utils.cc create mode 100644 samples/openvino/backend_utils.h create mode 100644 samples/openvino/contexts.h create mode 100644 samples/openvino/ibackend.h create mode 100644 samples/openvino/onnx_ctx_model_helper.cc create mode 100644 samples/openvino/onnx_ctx_model_helper.h create mode 100644 samples/openvino/openvino_execution_provider.cc create mode 100644 samples/openvino/openvino_execution_provider.h create mode 100644 samples/openvino/openvino_utils.cc create mode 100644 samples/openvino/openvino_utils.h create mode 100644 samples/openvino/ov_interface.cc create mode 100644 samples/openvino/ov_interface.h create mode 100644 samples/openvino/ov_versions/capability.cc create mode 100644 samples/openvino/ov_versions/capability.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 4c3a9a03611ae..e6525a2c512c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -191,6 +191,7 @@ ORT_API2_STATUS(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ con */ ORT_API2_STATUS(OrtGraph_GetAllInitializers, const OrtGraphViewer* graph, _Outptr_ const char*** initializer_names, _Out_ size_t* initializer_len); +// TODO(leca): maybe OrtGraph_ReleaseCharArray? /** \brief Release the char array * * NOTE!!: Invoke this function after the use of OrtGraph_GetRequiredInputs, OrtGraph_GetAllInputs, OrtGraph_GetAllInitializers. 
@@ -277,10 +278,9 @@ ORT_API2_STATUS(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, * \param[in] graph The graph to query * \param[in] initializer_name The name of the initializer tensor * \param[out] out The initializer tensor - * \param[out] ret True if the initializer tensor is found * */ -ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**, _Out_ bool* ret); +ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); /** \brief Release the initializer tensor. * @@ -291,6 +291,9 @@ ORT_API2_STATUS(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, cons */ ORT_API2_STATUS(OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor); +// TODO(leca): Do we need to define and expose OrtValueInfoRef? +// We can also encapsulate it, provide input/output index or name, return the properties of OrtValueInfoRef(shape, data_type) +// Just like OrtGraph_GetIthOutputElemType /** \brief Gets the value info of the node arg with the given name. * * NOTE!!: The caller is responsible for releasing the value info using OrtGraph_ReleaseValueInfo. @@ -298,10 +301,9 @@ ORT_API2_STATUS(OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor); * \param[in] graph The graph to query * \param[in] name The name of the node arg * \param[out] out The value info - * \param[out] ret True if the value info is found * */ -ORT_API2_STATUS(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret); +ORT_API2_STATUS(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out); /** \brief Release the value info. * diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 8dd00857b7dd0..84129ea7ada5c 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -63,7 +63,7 @@ std::shared_ptr OVCore::ReadModel(const std::string& model, const std return FE->convert(inputModel); } else { ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); - return NULL; + //return NULL; } } catch (const Exception& e) { ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index c83d5883aac61..fd5781bc1df49 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -149,12 +149,11 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetIthOutputElemType, const OrtGraphV return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out, _Out_ bool* ret) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const onnx::TensorProto* initializer = nullptr; if (!graph_viewer->GetInitializedTensor(initializer_name, initializer)) { - *ret = false; - return nullptr; + return nullptr; // TODO(leca): not return nullptr for this case? 
} *out = new OrtTensorRef(); // TODO(leca): other datatypes in the following switch (*out)->shape_len = initializer->dims_size(); @@ -171,7 +170,6 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetInitializerTensor, const OrtGraphV (*out)->data = reinterpret_cast(initializer->float_data().data()); break; } - *ret = true; return nullptr; } @@ -192,7 +190,7 @@ static ONNXTensorElementDataType GetDataTypeFromTypeProto(const onnx::TypeProto* return static_cast(type->tensor_type().elem_type()); } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const NodeArg* node_arg = graph_viewer->GetNodeArg(name); @@ -204,7 +202,6 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetValueInfo, const OrtGraphViewer* g (*out)->shape = new int64_t [(*out)->shape_len]; for (size_t i = 0; i < (*out)->shape_len; i++) ((*out)->shape)[i] = utils::HasDimValue(dims[static_cast(i)]) ? dims[static_cast(i)].dim_value() : -1; - *ret = true; return nullptr; } diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index 196cab41264ca..24337a6bf652a 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -43,11 +43,11 @@ ORT_API_STATUS_IMPL(OrtGraph_GetIthOutputName, const OrtGraphViewer* graph, size ORT_API_STATUS_IMPL(OrtGraph_GetIthOutputElemType, const OrtGraphViewer*, size_t i, _Out_ int32_t* out); -ORT_API_STATUS_IMPL(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** tensor, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_GetInitializerTensor, const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef** tensor); ORT_API_STATUS_IMPL(OrtGraph_ReleaseInitializerTensor, OrtTensorRef* tensor); -ORT_API_STATUS_IMPL(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out, _Out_ bool* ret); +ORT_API_STATUS_IMPL(OrtGraph_GetValueInfo, const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef** out); ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt new file mode 100644 index 0000000000000..1a1deba629b35 --- /dev/null +++ b/samples/openvino/CMakeLists.txt @@ -0,0 +1,19 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DOPENVINO_HOME="C:/Program Files (x86)/Intel/openvino_2024.4.0/runtime" +# cmake --build . 
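+#
+# Note: find_package(OpenVINO) locates the toolkit through CMake's package
+# search, e.g. via CMAKE_PREFIX_PATH or the OpenVINO_DIR cache variable;
+# running OpenVINO's setupvars script before configuring, or pointing
+# CMAKE_PREFIX_PATH at the OpenVINO runtime's cmake directory, is one way to
+# make it discoverable (an assumption based on standard OpenVINO packaging,
+# not something this sample sets up itself).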
+cmake_minimum_required(VERSION 3.26) +project(TensorRTEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) + +find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) +list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime) + +file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc") +add_library(OpenVINOEp SHARED ${openvino_src}) +target_include_directories(OpenVINOEp PUBLIC "../../include/onnxruntime" + ${OPENVINO_HOME}/include +) +target_link_libraries(OpenVINOEp PUBLIC "C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug/onnxruntime.lib" + ${OPENVINO_LIB_LIST} +) diff --git a/samples/openvino/backend_manager.cc b/samples/openvino/backend_manager.cc new file mode 100644 index 0000000000000..896aea8830b54 --- /dev/null +++ b/samples/openvino/backend_manager.cc @@ -0,0 +1,476 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "contexts.h" +#include "backend_manager.h" +#include "ibackend.h" +#include "backend_utils.h" +//#include "core/providers/openvino/qdq_transformations/qdq_stripping.h" + +namespace onnxruntime { +namespace openvino_ep { + +//GlobalContext& BackendManager::GetGlobalContext() { +// return global_context_; +//} +// +//BackendManager::BackendManager(const GlobalContext& global_context, +// const onnxruntime::Node& fused_node, +// const onnxruntime::GraphViewer& subgraph, +// const logging::Logger& logger, +// EPCtxHandler& ctx_handle) { +// global_context_ = global_context; +// ep_ctx_handle_ = ctx_handle; +// +// openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." + +// std::to_string(global_context_.OpenVINO_Version.at(1)); +// if (ep_ctx_handle_.CheckForOVEPCtxNode(subgraph, openvino_sdk_version_)) { +// if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph) != Status::OK()) +// ORT_THROW("Import blob from model failed"); +// } +// +// // Save the indexes of graph inputs among fused_node's inputDefs +// // (which also contains initializers). +// auto node_input_defs = fused_node.InputDefs(); +// int i = 0; +// for (auto idef : node_input_defs) { +// subgraph_context_.input_names.insert({idef->Name(), i}); +// i++; +// } +// +// const std::vector& graph_inputs = subgraph.GetInputs(); +// for (auto input : graph_inputs) { +// auto it = subgraph_context_.input_names.find(input->Name()); +// if (it == subgraph_context_.input_names.end()) { +// ORT_THROW("Input not found in the input defs list"); +// } +// int index = it->second; +// subgraph_context_.input_indexes.push_back(index); +// } +// +// auto graph_outputs_defs = fused_node.OutputDefs(); +// i = 0; +// for (auto output_def : graph_outputs_defs) { +// subgraph_context_.output_names.insert({output_def->Name(), i}); +// i++; +// } +// subgraph_context_.subgraph_name = fused_node.Name(); +// model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger); +// std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; +// +// if (ModelHasSymbolicInputDims(subgraph)) { +// subgraph_context_.has_dynamic_input_shape = true; +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; +// ORT_ENFORCE(!global_context_.enable_qdq_optimizer, +// "QDQ stripping should not be enabled for models with dynamic input shapes. 
" +// "Set enable_qdq_optimizer to False"); +// if (GetGlobalContext().device_type.find("CPU") != std::string::npos || +// GetGlobalContext().device_type.find("GPU") != std::string::npos) { +// if (!GetGlobalContext().disable_dynamic_shapes) { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " +// << "Creating backend Dynamic Shapes"; +// try { +// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, +// GetGlobalContext(), +// subgraph_context_, +// ep_ctx_handle_); +// } catch (std::string const& msg) { +// ORT_THROW(msg); +// } +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " +// << "Backend created for graph " << subgraph_context_.subgraph_name; +// } +// } +// } else { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " +// << "Initializing backend for graph " +// << subgraph_context_.subgraph_name; +// +// subgraph_context_.has_dynamic_input_shape = false; +// +// // OV NPU plugin is supported with fallback to OV CPU upon compilation failures. +// try { +// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, +// GetGlobalContext(), +// subgraph_context_, +// ep_ctx_handle_); +// } catch (const OnnxRuntimeException& ex) { +//#if defined(OPENVINO_DISABLE_NPU_FALLBACK) +// ORT_THROW(ex.what()); +//#else +// if (device_type.find("NPU") != std::string::npos && +// !GetGlobalContext().disable_cpu_fallback) { +// LOGS_DEFAULT(WARNING) << ex.what(); +// LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." +// << "Falling back to OV CPU for execution"; +// GetGlobalContext().device_type = "CPU"; +// GetGlobalContext().precision_str = "FP32"; +// try { +// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, +// GetGlobalContext(), +// subgraph_context_, +// ep_ctx_handle_); +// } catch (std::string const& msg) { +// ORT_THROW(msg); +// } +// } else { +// ORT_THROW(ex.what()); +// } +//#endif +// } +// } +//} +// +//// Call EPContext model exporter here if the provider option for exporting +//// precompiled blob is set. If that's the case: +//// By default, create model in embed mode where the blob stream is exported as data within +//// the EPContext node. +//Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer, +// const logging::Logger& logger) { +// if (GetGlobalContext().disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape) { +// std::string exception_str = +// "Exporting dynamically compiled models at runtime is not supported. " +// "Cannot export blobs of dynamic models that request static shape inference. 
" +// "To export this model, set disable_dynamic_shapes to False"; +// ORT_THROW(exception_str); +// } +// +// std::string model_blob_str; +// auto compiled_model = concrete_backend_->GetOVCompiledModel(); +// auto graph_name = global_context_.onnx_model_path_name; +// // Remove extension so we can append suffix to form the complete name of output graph +// graph_name = [&]() { +// size_t dot = graph_name.find_last_of("."); +// if (dot == std::string::npos) return graph_name; +// return graph_name.substr(0, dot); +// }(); +// // If embed_mode, then pass on the serialized blob +// // If not embed_mode, dump the blob here and only pass on the path to the blob +// if (global_context_.ep_context_embed_mode) { +// std::ostringstream model_blob_stream; +// compiled_model.export_model(model_blob_stream); +// model_blob_str = model_blob_stream.str(); +// ORT_ENFORCE(model_blob_str.size() != 0); +// } else { +// std::ofstream f(graph_name + ".blob", std::ios::out | std::ios::trunc | std::ios::binary); +// compiled_model.export_model(f); +// model_blob_str = graph_name + ".blob"; +// } +// +// ORT_RETURN_IF_ERROR(ep_ctx_handle_.ExportEPCtxModel(graph_body_viewer, +// graph_name, +// logger, +// global_context_.ep_context_embed_mode, +// model_blob_str, +// openvino_sdk_version_, +// GetGlobalContext().device_type)); +// +// return Status::OK(); +//} +// +//bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { +// bool has_batched_inputs = true; +// +// for (int i = 0; i < static_cast(subgraph_context_.input_indexes.size()); i++) { +// auto& input = model_proto.graph().input(subgraph_context_.input_indexes[i]); +// +// // Batch-process only raw image inputs (NCHW or NHWC layouts) +// auto& shape = input.type().tensor_type().shape(); +// if (shape.dim_size() != 4) { +// has_batched_inputs = false; +// break; +// } +// +// if (shape.dim(0).value_case() == shape.dim(0).kDimValue) { +// has_batched_inputs = false; +// break; +// } +// +// for (int index = 1; index < 4; index++) { +// if (shape.dim(index).value_case() != shape.dim(0).kDimValue) { +// has_batched_inputs = false; +// break; +// } +// } +// if (!has_batched_inputs) { +// break; +// } +// } +// return has_batched_inputs; +//} +// +//bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { +// bool has_sym_dims = false; +// auto graph_inputs = subgraph.GetInputs(); +// for (auto input : graph_inputs) { +// if (input->Shape() == nullptr) { +// has_sym_dims = true; +// break; +// } +// for (auto& dim : input->Shape()->dim()) { +// if (dim.value_case() != dim.kDimValue) { +// has_sym_dims = true; +// break; +// } +// } +// if (has_sym_dims) { +// break; +// } +// } +// return has_sym_dims; +//} +// +//// Check to see if the graph is QDQ +//static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { +// std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; +// const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); +// +// for (size_t i = 0; i < node_indices.size(); i++) { +// gsl::not_null node(graph_viewer.GetNode(node_indices[i])); +// if (qdq_ops.find(node->OpType()) != qdq_ops.end()) { +// return true; +// } +// } +// return false; +//} +// +//static void DumpOpenVINOEPModel(std::string onnx_model_path_name, +// ONNX_NAMESPACE::ModelProto* model_proto, +// const onnxruntime::Node& fused_node) { +// if (openvino_ep::backend_utils::IsDebugEnabled()) { +// auto model_name = onnx_model_path_name.empty() ? 
"unknown.onnx" : onnx_model_path_name; +//#ifdef _WIN32 +// size_t slash = model_name.find_last_of("\\"); +//#else +// size_t slash = model_name.find_last_of("/"); +//#endif +// model_name = model_name.substr(slash + 1, std::string::npos); +// size_t dot = model_name.find_last_of("."); +// model_name = model_name.substr(0, dot); +// +// std::string subgraph_name = fused_node.Name(); +// size_t dash = subgraph_name.find_last_of("-"); +// subgraph_name = subgraph_name.substr(dash, std::string::npos); +// +// const std::string name = model_name + subgraph_name + ".onnx"; +// +// std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); +// model_proto->SerializeToOstream(dump); +// } +//} +// +//std::unique_ptr +//BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, +// const onnxruntime::GraphViewer& subgraph, +// const logging::Logger& logger) const { +// std::chrono::time_point model_proto_create_start_, model_proto_create_end_; +// if (openvino_ep::backend_utils::IsDebugEnabled()) { +// model_proto_create_start_ = std::chrono::high_resolution_clock::now(); +// } +// +// auto print_model_proto_duration = [&]() { +// if (openvino_ep::backend_utils::IsDebugEnabled()) { +// model_proto_create_end_ = std::chrono::high_resolution_clock::now(); +// auto model_proto_create_duration = +// std::chrono::duration_cast( +// model_proto_create_end_ - model_proto_create_start_) +// .count(); +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model Proto creation took: " << model_proto_create_duration << " ms."; +// } +// }; +// +// // QDQ stripping enabled only for the NPU +// if (global_context_.device_type.find("NPU") != std::string::npos && +// global_context_.enable_qdq_optimizer && +// IsQDQGraph(subgraph)) { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 1"; +// std::unique_ptr model; +// Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, model); +// auto model_proto = model->ToProto(); +// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); +// print_model_proto_duration(); +// DumpOpenVINOEPModel(global_context_.onnx_model_path_name, model_proto.get(), fused_node); +// ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); +// return model_proto; +// } else { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 0"; +// auto model = subgraph.CreateModel(logger); +// auto model_proto = model->ToProto(); +// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); +// subgraph.ToProto(*model_proto->mutable_graph(), true, true); +// print_model_proto_duration(); +// DumpOpenVINOEPModel(global_context_.onnx_model_path_name, model_proto.get(), fused_node); +// return model_proto; +// } +//} + +std::vector> GetInputTensorShapes(const Ort::KernelContext& context) { + const auto input_count = context.GetInputCount(); + std::vector> input_shapes; + input_shapes.reserve(input_count); + for (size_t i = 0; i < input_count; i++) { + auto input_tensor = context.GetInput(i); + auto tensor_shape = input_tensor.GetTensorTypeAndShapeInfo().GetShape(); + input_shapes.push_back(std::move(tensor_shape)); + } + return input_shapes; +} + +std::string MakeMapKeyString(const std::vector>& shapes, + const std::string& device_type) { + std::string key; + key += device_type; + key += "|"; // separator + for (auto shape : shapes) { + for (auto dim : shape) { + std::ostringstream o; + o << dim; + key += o.str(); + key += ","; + } + key += "|"; + } + return key; +} + +//std::shared_ptr 
+//BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_proto, +// const std::vector>& input_shapes) { +// auto model_copy = std::shared_ptr(ONNX_NAMESPACE::ModelProto::Create()); +// std::string proto_str; +// model_proto.SerializeToString(proto_str); +// model_copy->ParseFromString(proto_str); +// auto graph_proto = model_copy->mutable_graph(); +// +// for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) { +// auto g_in_shape = graph_proto->mutable_input(static_cast(i)) +// ->mutable_type() +// ->mutable_tensor_type() +// ->mutable_shape(); +// g_in_shape->clear_dim(); +// const auto& shape = input_shapes[i]; +// for (size_t dim = 0, end = shape.size(); dim < end; dim++) { +// g_in_shape->add_dim()->set_dim_value(shape[dim]); +// } +// } +// return model_copy; +//} +// +//std::shared_ptr +//BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto) { +// auto model_copy = std::shared_ptr(ONNX_NAMESPACE::ModelProto::Create()); +// std::string proto_str; +// model_proto.SerializeToString(proto_str); +// model_copy->ParseFromString(proto_str); +// auto graph_proto = model_copy->mutable_graph(); +// +// for (int i = 0; i < graph_proto->input_size(); i++) { +// ONNX_NAMESPACE::TensorShapeProto* g_in_shape = +// graph_proto->mutable_input(static_cast(i)) +// ->mutable_type() +// ->mutable_tensor_type() +// ->mutable_shape(); +// g_in_shape->mutable_dim(0)->clear_dim_value(); +// g_in_shape->mutable_dim(0)->set_dim_value(1); +// } +// return model_copy; +//} +// +//void BackendManager::Compute(OrtKernelContext* context) { +// Ort::KernelContext ctx(context); +// std::chrono::high_resolution_clock::time_point start_compute, end_compute; +//#ifdef OPENVINO_FIL_ENABLED +// static bool fil_enabled = true; +// if (fil_enabled) { +// start_compute = std::chrono::high_resolution_clock::now(); +// LOGS_DEFAULT(INFO) << "Start Compute"; +// } +//#endif +// // OV NPU doesn't support dynamic shaped model inference. +// // if disable_dynamic_shapes is set to true then execution of dynamic model is done +// // by rewriting the model to static shaped model at runtime based on input shape. +// // disable_dynamic_shapes is always set to true for OV NPU plugin. +// bool use_dynamic_backend = true; +// if (subgraph_context_.has_dynamic_input_shape && +// !GetGlobalContext().disable_dynamic_shapes && +// (GetGlobalContext().device_type.find("CPU") != std::string::npos || +// GetGlobalContext().device_type.find("GPU") != std::string::npos)) { +// concrete_backend_->Infer(context); +// use_dynamic_backend = false; +// } else if (use_dynamic_backend && subgraph_context_.has_dynamic_input_shape) { +// std::vector> tensor_shapes = GetInputTensorShapes(ctx); +// auto key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type); +// std::shared_ptr dynamic_backend; +// auto search = backend_map_.find(key); +// if (search == backend_map_.end()) { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " +// << "Creating dynamic backend for key: " << key; +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " +// << "Backend created for graph " << subgraph_context_.subgraph_name; +// auto modelproto_with_concrete_shapes = ReWriteInputShapeInfo(*model_proto_, tensor_shapes); +// try { +// dynamic_backend = BackendFactory::MakeBackend(*modelproto_with_concrete_shapes, +// GetGlobalContext(), +// subgraph_context_, +// ep_ctx_handle_); +// } catch (const OnnxRuntimeException& ex) { +// // Build option disables fallback to CPU on compilation failures with NPU. 
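+//            // (Without that flag, a compilation failure on NPU falls back to
+//            //  CPU/FP32: the same concrete-shape model proto is recompiled and
+//            //  cached in backend_map_ under a key produced by MakeMapKeyString,
+//            //  e.g. "CPU|1,3,224,224,|" for a single 1x3x224x224 input.)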
+//#if defined(OPENVINO_DISABLE_NPU_FALLBACK) +// LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."; +// ORT_THROW(ex.what()); +//#else +// if (GetGlobalContext().device_type.find("NPU") != std::string::npos && +// !GetGlobalContext().disable_cpu_fallback) { +// LOGS_DEFAULT(WARNING) << ex.what(); +// LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." +// << "Falling back to OV CPU for execution"; +// GetGlobalContext().device_type = "CPU"; +// GetGlobalContext().precision_str = "FP32"; +// key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type); +// try { +// dynamic_backend = BackendFactory::MakeBackend(*modelproto_with_concrete_shapes, +// GetGlobalContext(), +// subgraph_context_, +// ep_ctx_handle_); +// } catch (std::string const& msg) { +// ORT_THROW(msg); +// } +// } else { +// ORT_THROW(ex.what()); +// } +//#endif +// } +// backend_map_.insert({key, dynamic_backend}); +// } else { +// dynamic_backend = search->second; +// } +// +// dynamic_backend->Infer(context); +// } else { +// concrete_backend_->Infer(context); +// } +//#ifdef OPENVINO_FIL_ENABLED +// if (fil_enabled) { +// end_compute = std::chrono::high_resolution_clock::now(); +// LOGS_DEFAULT(INFO) << "End Compute"; +// std::chrono::duration compute_time = end_compute - start_compute; +// std::cout << "Compute Time: " << compute_time.count() << " s" << std::endl; +// fil_enabled = false; // calculating compute time for first run only +// } +//#endif +//} +// +//void BackendManager::ShutdownBackendManager() { +//} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/backend_manager.h b/samples/openvino/backend_manager.h new file mode 100644 index 0000000000000..a057c10377a01 --- /dev/null +++ b/samples/openvino/backend_manager.h @@ -0,0 +1,61 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include + +#include "ov_interface.h" +#include "contexts.h" +#include "onnx_ctx_model_helper.h" +#include "ibackend.h" + +//namespace onnxruntime { +//namespace openvino_ep { +// +//// Singleton class that manages all the backends +//class BackendManager { +// public: +// BackendManager(const GlobalContext& global_context, +// const onnxruntime::Node& fused_node, +// const onnxruntime::GraphViewer& subgraph, +// const logging::Logger& logger, +// EPCtxHandler& ctx_handle); +// void Compute(OrtKernelContext* context); +// void ShutdownBackendManager(); +// void SetGlobalCotext(const GlobalContext& global_context); +// GlobalContext& GetGlobalContext(); +// Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, +// const logging::Logger& logger); +// +// private: +// std::unique_ptr GetModelProtoFromFusedNode( +// const onnxruntime::Node& fused_node, +// const onnxruntime::GraphViewer& subgraph, +// const logging::Logger& logger) const; +// +// bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; +// bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; +// +// std::shared_ptr +// ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto); +// +// std::shared_ptr +// ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_proto, +// const std::vector>& input_shapes); +// +// std::unique_ptr model_proto_; +// std::shared_ptr concrete_backend_; +// std::map> backend_map_; +// SubGraphContext subgraph_context_; +// GlobalContext global_context_; +// EPCtxHandler ep_ctx_handle_{}; +// std::string 
openvino_sdk_version_{}; +//}; +// +//} // namespace openvino_ep +//} // namespace onnxruntime +// diff --git a/samples/openvino/backend_utils.cc b/samples/openvino/backend_utils.cc new file mode 100644 index 0000000000000..62386e9fe4b7a --- /dev/null +++ b/samples/openvino/backend_utils.cc @@ -0,0 +1,272 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include +#include +#include +#include + +#include "openvino/pass/convert_fp32_to_fp16.hpp" +#include "openvino/pass/constant_folding.hpp" +#include "backend_utils.h" +#include "ov_interface.h" +#include "openvino_utils.h" + +using Exception = ov::Exception; + +namespace onnxruntime { +namespace openvino_ep { +namespace backend_utils { + +bool IsDebugEnabled() { + const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_DEBUG"); + if (!env_name.empty()) { + return true; + } + return false; +} + +bool IsCILogEnabled() { + const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG"); + if (!env_name.empty()) { + return true; + } + return false; +} + +struct static_cast_int64 { + template // T1 models type statically convertible to T + int64_t operator()(const T1& x) const { return static_cast(x); } +}; + +//std::shared_ptr +//CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, +// std::map>& const_outputs_map) { +// if (IsCILogEnabled()) { +// std::cout << "CreateNgraphFunc" << std::endl; +// } +// const std::string model = model_proto.SerializeAsString(); +// try { +// auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); +// +// // Check for Constant Folding +// if (!global_context.is_wholly_supported_graph) { +// ov::pass::ConstantFolding pass_const_obj; +// pass_const_obj.run_on_model(cnn_network); +// auto& results = const_cast(cnn_network.get()->get_results()); +// size_t index = results.size() - 1; +// +// for (auto it = results.rbegin(); it != results.rend(); ++it) { +// if (auto const_node = +// std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { +// const_outputs_map[(*it)->get_friendly_name()] = const_node; +// results.erase(results.begin() + index); +// } +// --index; +// } +// } +//#ifndef NDEBUG +// if (IsDebugEnabled()) { +// std::string name = cnn_network->get_friendly_name(); +// ov::pass::Serialize serializer(name + ".xml", name + ".bin"); +// serializer.run_on_model(cnn_network); +// } +//#endif +// return cnn_network; +// } catch (std::string const& msg) { +// throw std::runtime_error(msg); +// } +//} + +Ort::UnownedValue +GetOutputTensor(Ort::KernelContext& context, size_t batch_size, + OVInferRequestPtr infer_request, + std::string output_name, + std::unordered_map output_names) { + auto graph_output_blob = infer_request->GetTensor(output_name); + + auto graph_output_dims = graph_output_blob->get_shape(); + + if (batch_size > 1) { + // Add the batch size as dim 0. 
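+    // (OpenVINO reports the compiled model's per-slice output shape; the ORT
+    //  output tensor must cover all batch_size slices, so the batch dimension
+    //  is prepended before the shape is handed to context.GetOutput below.)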
+ graph_output_dims.insert(graph_output_dims.begin(), batch_size); + } + size_t num_dims = graph_output_dims.size(); + std::unique_ptr output_shape(new int64_t[num_dims]); + for (size_t j = 0; j < num_dims; j++) { + output_shape[j] = static_cast(graph_output_dims[j]); + } + auto it = output_names.find(output_name); + if (it == output_names.end()) { + throw std::runtime_error(log_tag + "Output names mismatch between OpenVINO and ONNX"); + } + int index = it->second; + return context.GetOutput(index, output_shape.get(), num_dims); +} + +Ort::UnownedValue +GetOutputTensor(Ort::KernelContext& context, + std::string output_name, + std::unordered_map output_names, + std::shared_ptr node) { + // Find position of '/' in the output_name + int pos = output_name.find("/"); + // Copy the substring from start to pos + output_name = output_name.substr(0, pos); + + auto it = output_names.find(output_name); + if (it == output_names.end()) { + throw std::runtime_error(log_tag + "Output names mismatch between OpenVINO and ONNX"); + } + int index = it->second; + auto shape = node->get_shape(); + + size_t num_dims = shape.size(); + std::unique_ptr output_shape(new int64_t[num_dims]); + for (size_t j = 0; j < num_dims; j++) { + output_shape[j] = static_cast(shape[j]); + } + return context.GetOutput(index, output_shape.get(), num_dims); +} + +int GetFirstAvailableDevice(GlobalContext& global_context) { + int i = 0; + // Get the first available VAD-M device and set the device to busy + while (i < 8) { + bool device = global_context.deviceAvailableList[i]; + if (device) { + global_context.deviceAvailableList[i] = false; + break; + } + i++; + } + // If all of the devices are busy, assign the first device and + // make all remaining devices free + if (i == 8) { + i = 0; + global_context.deviceAvailableList[i] = false; + for (int j = 1; j < 8; j++) { + global_context.deviceAvailableList[j] = true; + } + } + return i; +} + +void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor) { + switch (node->get_element_type()) { + case ov::element::Type_t::f32: { + FillOutputHelper(out_tensor, node); + break; + } + case ov::element::Type_t::boolean: { + FillOutputHelper(out_tensor, node); + break; + } + case ov::element::Type_t::i32: { + FillOutputHelper(out_tensor, node); + break; + } + case ov::element::Type_t::i64: { + FillOutputHelper(out_tensor, node); + break; + } + case ov::element::Type_t::f16: { + FillOutputHelper(out_tensor, node); + break; + } + default: + throw std::runtime_error(log_tag + "Unsupported output data type"); + } +} + +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif + +template +void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node) { + auto const_node = std::dynamic_pointer_cast(node); + auto res = const_node->cast_vector(); + T* tensor_data = out_tensor.GetTensorMutableData(); + std::copy(res.begin(), res.end(), tensor_data); +} + +#if defined(_MSC_VER) +#pragma warning(default : 4127) +#endif + +void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, + std::string input_name, Ort::KernelContext& context, + const SubGraphContext& subgraph_context) { + size_t input_data_size = inputBlob->get_byte_size(); + auto input_data = inputBlob->data(); + auto tensor = context.GetInput(subgraph_context.input_names.at(input_name)); + auto mem_info = tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + throw std::runtime_error(log_tag + "IO Buffering is not enabled, Please enable Input on CPU"); + } + // Copy 
input data into OpenVINO's input buffer + const char* tensor_data = tensor.GetTensorData(); + const char* batch_memory_offset = tensor_data + input_data_size * batch_slice_idx; + std::memcpy(input_data, batch_memory_offset, input_data_size); +} + +void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, + size_t batch_slice_idx) { + auto output_data = outputBlob->data(); + size_t output_data_size = outputBlob->get_byte_size(); + char* tensor_data = output_tensor.GetTensorMutableData(); + char* batch_memory_offset = tensor_data + output_data_size * batch_slice_idx; + std::memcpy(batch_memory_offset, output_data, output_data_size); +} + +void printPerformanceCounts(const std::vector& performanceMap, + std::ostream& stream, std::string deviceName) { + int64_t totalTime = 0; + // Print performance counts + stream << std::endl + << "performance counts:" << std::endl + << std::endl; + + for (const auto& it : performanceMap) { + std::string toPrint(it.node_name); + const int maxLayerName = 30; + + if (it.node_name.length() >= maxLayerName) { + toPrint = it.node_name.substr(0, maxLayerName - 4); + toPrint += "..."; + } + stream << std::setw(maxLayerName) << std::left << toPrint; + switch (it.status) { + case OVProfilingInfo::Status::EXECUTED: + stream << std::setw(15) << std::left << "EXECUTED"; + break; + case OVProfilingInfo::Status::NOT_RUN: + stream << std::setw(15) << std::left << "NOT_RUN"; + break; + case OVProfilingInfo::Status::OPTIMIZED_OUT: + stream << std::setw(15) << std::left << "OPTIMIZED_OUT"; + break; + } + stream << std::setw(30) << std::left << "layerType: " + std::string(it.node_type) + " "; + stream << std::setw(20) << std::left << "realTime: " + std::to_string(it.real_time.count()); + stream << std::setw(20) << std::left << "cpu: " + std::to_string(it.cpu_time.count()); + stream << " execType: " << it.exec_type << std::endl; + if (it.real_time.count() > 0) { + totalTime += it.real_time.count(); + } + } + stream << std::setw(20) << "Total time: " + std::to_string(totalTime) << " microseconds" << std::endl; + std::cout << std::endl; + std::cout << "Full device name: " << deviceName << std::endl; + std::cout << std::endl; +} + +void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName) { + auto performanceMap = request->GetNewObj().get_profiling_info(); + printPerformanceCounts(performanceMap, stream, std::move(deviceName)); +} + +} // namespace backend_utils +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/backend_utils.h b/samples/openvino/backend_utils.h new file mode 100644 index 0000000000000..c700f86f9c0f7 --- /dev/null +++ b/samples/openvino/backend_utils.h @@ -0,0 +1,75 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#define ORT_API_MANUAL_INIT +#include +#include +#include +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "contexts.h" +#include "ov_interface.h" +#ifdef _WIN32 +#include +#define GetCurrentDir _getcwd +#else +#include +#define GetCurrentDir getcwd +#endif + +#include + +namespace onnxruntime { +namespace openvino_ep { +namespace backend_utils { +const std::string log_tag = "[OpenVINO-EP] "; + +bool IsDebugEnabled(); + +// Internal diagnostic function. 
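+// (Backed by the ORT_OPENVINO_ENABLE_CI_LOG environment variable; any
+//  non-empty value enables the extra progress logging used by CI runs.)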
+bool IsCILogEnabled(); + +int GetFirstAvailableDevice(GlobalContext& global_context); + +void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); + +template +void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); + +Ort::UnownedValue +GetOutputTensor(Ort::KernelContext& context, + std::string output_name, + std::unordered_map output_names, + std::shared_ptr node); + +Ort::UnownedValue +GetOutputTensor(Ort::KernelContext& context, size_t batch_size, + OVInferRequestPtr infer_request, + std::string output_name, + std::unordered_map output_names); + +void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, + std::string input_name, Ort::KernelContext& context, + const SubGraphContext& subgraph_context); + +void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, + size_t batch_slice_idx); + +//std::shared_ptr +//CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, +// const GlobalContext& global_context, +// std::map>& const_outputs_map); + +void printPerformanceCounts(const std::vector& performanceMap, + std::ostream& stream, std::string deviceName); + +void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName); + +} // namespace backend_utils +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/contexts.h b/samples/openvino/contexts.h new file mode 100644 index 0000000000000..8f549d5ac1627 --- /dev/null +++ b/samples/openvino/contexts.h @@ -0,0 +1,55 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include "ov_interface.h" + +namespace onnxruntime { +namespace openvino_ep { + +// Holds context applicable to the entire EP instance. +struct GlobalContext { + OVCore ie_core; + bool is_wholly_supported_graph = false; + bool enable_npu_fast_compile = false; + bool enable_opencl_throttling = false; + bool disable_dynamic_shapes = false; + bool ep_context_embed_mode = true; + bool export_ep_ctx_blob = false; + bool enable_qdq_optimizer = false; + bool disable_cpu_fallback = false; + size_t num_of_threads; + std::string device_type; + std::string precision_str; + std::string model_precision; + std::string cache_dir; + std::string model_priority = "DEFAULT"; + int num_streams; + std::vector deviceAvailableList = {true, true, true, true, true, true, true, true}; + std::string onnx_model_name; + std::string onnx_model_path_name; +// int onnx_opset_version; + void* context = 0; + bool use_api_2; + std::vector OpenVINO_Version = {}; // Ov Major and OV minor version from OV headers +}; + +// Holds context specific to subgraph. 
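+// (input_names/output_names map value names to their positions among the
+//  fused node's input/output defs, which include initializers; input_indexes
+//  records which of those entries are true graph inputs.)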
+struct SubGraphContext { + bool has_dynamic_input_shape = false; + bool enable_batching = false; + bool set_npu_config = false; + bool is_constant = false; + void* context = 0; + std::string subgraph_name; + std::vector input_indexes; + std::unordered_map input_names; + std::unordered_map output_names; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ibackend.h b/samples/openvino/ibackend.h new file mode 100644 index 0000000000000..8e54f7d1cb5d4 --- /dev/null +++ b/samples/openvino/ibackend.h @@ -0,0 +1,29 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "onnx_ctx_model_helper.h" + +//namespace onnxruntime { +//namespace openvino_ep { +// +//class IBackend { +// public: +// virtual void Infer(OrtKernelContext* context) = 0; +// virtual ov::CompiledModel& GetOVCompiledModel() = 0; +//}; +// +//class BackendFactory { +// public: +// static std::shared_ptr +// MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, +// GlobalContext& global_context, +// const SubGraphContext& subgraph_context, +// EPCtxHandler& ctx_handle); +//}; +// +//} // namespace openvino_ep +//} // namespace onnxruntime diff --git a/samples/openvino/onnx_ctx_model_helper.cc b/samples/openvino/onnx_ctx_model_helper.cc new file mode 100644 index 0000000000000..ec72897546582 --- /dev/null +++ b/samples/openvino/onnx_ctx_model_helper.cc @@ -0,0 +1,138 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include +#include +#include + +#include "onnx_ctx_model_helper.h" +#include "openvino_utils.h" + +namespace onnxruntime { +namespace openvino_ep { + +const OrtGraphApi* EPCtxHandler::graph_api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION)->GetGraphApi(ORT_API_VERSION); + +// Utilities to handle EPContext node export and parsing of an EPContext node +// to create the compiled_model object to infer on +static const char EPCONTEXT_OP[] = "EPContext"; +static const char EMBED_MODE[] = "embed_mode"; +static const char EP_CACHE_CONTEXT[] = "ep_cache_context"; +static const char EP_SDK_VER[] = "ep_sdk_version"; +static const char SOURCE[] = "source"; + +/* Export the serialized blob string embedded onto an EPContext Node + * along with other metadata necessary to validate the graph on import + */ + +//Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, +// const std::string& graph_name, +// const logging::Logger& logger, +// const bool& ep_context_embed_mode, +// const std::string& model_blob_str, +// const std::string& openvino_sdk_version, +// const std::string& device_type) const { +// auto model_build = graph_viewer.CreateModel(logger); +// auto& graph_build = model_build->MainGraph(); +// +// // Get graph inputs and outputs +// std::vector inputs, outputs; +// for (auto input : graph_viewer.GetInputs()) { +// auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); +// inputs.push_back(&n_input); +// } +// for (auto output : graph_viewer.GetOutputs()) { +// auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); +// outputs.push_back(&n_output); +// } +// +// // Create EP context node attributes +// auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); +// auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); +// auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); +// auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); +// +// // embed mode +// 
attr_0->set_name(EMBED_MODE); +// attr_0->set_type(onnx::AttributeProto_AttributeType_INT); +// attr_0->set_i(ep_context_embed_mode); +// // ep context +// attr_1->set_name(EP_CACHE_CONTEXT); +// attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); +// attr_1->set_s(model_blob_str); +// // sdk version +// attr_2->set_name(EP_SDK_VER); +// attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); +// attr_2->set_s(openvino_sdk_version); +// // source +// attr_3->set_name(SOURCE); +// attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); +// attr_3->set_s(kOpenVINOExecutionProvider); +// +// auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); +// node_attributes->reserve(4); +// node_attributes->emplace(EMBED_MODE, *attr_0); +// node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); +// node_attributes->emplace(EP_SDK_VER, *attr_2); +// node_attributes->emplace(SOURCE, *attr_3); +// +// // Create EP context node +// graph_build.AddNode(graph_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), kMSDomain); +// ORT_ENFORCE(graph_build.Resolve().IsOK()); +// +// // Serialize modelproto to string +// auto new_graph_viewer = graph_build.CreateGraphViewer(); +// auto model = new_graph_viewer->CreateModel(logger); +// auto model_proto = model->ToProto(); +// new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); +// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); +// +// // Finally, dump the model +// std::ofstream dump(graph_name + "-ov_" + device_type + "_blob.onnx", +// std::ios::out | std::ios::trunc | std::ios::binary); +// model_proto->SerializeToOstream(dump); +// +// LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Export blob as EPContext Node"; +// +// return Status::OK(); +//} +// +//Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) { +// auto node = graph_viewer.GetNode(0); +// auto& attrs = node->GetAttributes(); +// ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0); +// +// model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s()); +// +// LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; +// +// is_valid_ep_ctx_graph_ = true; +// return Status::OK(); +//} + +bool EPCtxHandler::CheckForOVEPCtxNode(const OrtGraphViewer* graph_viewer, std::string openvino_sdk_version) const { + int max_node_index = 0; + graph_api_->OrtGraph_MaxNodeIndex(graph_viewer, &max_node_index); + for (int i = 0; i < max_node_index; ++i) { + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, i, &node); + if (node != nullptr) { + const char* node_op_type = nullptr; + graph_api_->OrtNode_GetOpType(node, &node_op_type); + if (!strcmp(node_op_type, EPCONTEXT_OP)) { + const char* source_val = nullptr, *ep_sdk_ver_val = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, SOURCE, &source_val); + if (!strcmp(source_val, OpenVINOEp.c_str())) { + graph_api_->OrtNode_GetAttributeStr(node, EP_SDK_VER, &ep_sdk_ver_val); + if (!strcmp(ep_sdk_ver_val, openvino_sdk_version.c_str())) return true; + throw std::runtime_error("[Invalid Graph] Versions of OpenVINO used to export blob (" + std::string(ep_sdk_ver_val) + + ") and current runtime (" + openvino_sdk_version + ") don't match."); + } + } + } + } + return false; +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/onnx_ctx_model_helper.h b/samples/openvino/onnx_ctx_model_helper.h new file mode 100644 index 0000000000000..0c12d31869158 --- /dev/null +++ b/samples/openvino/onnx_ctx_model_helper.h 
@@ -0,0 +1,37 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include "core/session/onnxruntime_c_api_ep.h" + +namespace onnxruntime { +namespace openvino_ep { + +class EPCtxHandler { + public: + EPCtxHandler() = default; + EPCtxHandler(const EPCtxHandler&) = default; +// Status ExportEPCtxModel(const GraphViewer& graph_viewer, +// const std::string& graph_name, +// const logging::Logger& logger, +// const bool& ep_context_embed_mode, +// const std::string& model_blob_str, +// const std::string& openvino_sdk_version, +// const std::string& device_type) const; +// Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer); + bool CheckForOVEPCtxNode(const OrtGraphViewer* graph_viewer, std::string openvino_sdk_version) const; + bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; } + [[nodiscard]] const std::shared_ptr GetModelBlobStream() const { return model_stream_; } + + private: + bool is_valid_ep_ctx_graph_{false}; + std::shared_ptr model_stream_; + static const OrtGraphApi* graph_api_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc new file mode 100644 index 0000000000000..d356362ddba8e --- /dev/null +++ b/samples/openvino/openvino_execution_provider.cc @@ -0,0 +1,87 @@ +#include +#include +#include "openvino_execution_provider.h" +#include "openvino_utils.h" +#include "ov_versions/capability.h" + +namespace onnxruntime { + +const OrtApi* OpenVINOExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +const OrtGraphApi* OpenVINOExecutionProvider::graph_api_ = OpenVINOExecutionProvider::api_->GetGraphApi(ORT_API_VERSION); + +OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const ProviderOptions& provider_options) : OrtExecutionProvider() { + OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph_viewer, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { + const OpenVINOExecutionProvider* p = static_cast(this_); + std::string openvino_sdk_version = std::to_string(p->global_context_->OpenVINO_Version.at(0)) + "." 
+ + std::to_string(p->global_context_->OpenVINO_Version.at(1)); + + // Check for valid ctx node and maintain state for validity + if (p->ep_ctx_handle_.CheckForOVEPCtxNode(graph_viewer, openvino_sdk_version)) { + int num_nodes = 0; + graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &num_nodes); + assert((num_nodes==1) && "[Invalid Graph] EPContext Model with OpenVINO compiled blob should not have more than one node"); + } + + // Enable CI Logs + if (!(GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG").empty())) { + std::cout << "In the OpenVINO EP" << std::endl; + } + const void* model_path = nullptr; + graph_api_->OrtGraph_GetModelPath(graph_viewer, &model_path); + p->global_context_->onnx_model_path_name = reinterpret_cast(model_path)->string(); + +// global_context_->onnx_opset_version = +// graph_viewer.DomainToVersionMap().at(kOnnxDomain); + + p->global_context_->model_precision = [&](const OrtGraphViewer* graph_viewer) { + // return empty if graph has no inputs or if types are not one of FP32/FP16 + // else assume the type of the first input + const char** required_inputs = nullptr; + size_t input_count = 0; + graph_api_->OrtGraph_GetRequiredInputs(graph_viewer, &required_inputs, &input_count); + if (input_count == 0) return ""; + if (p->global_context_->precision_str == "ACCURACY" && + p->global_context_->device_type.find("GPU") != std::string::npos) { + OrtValueInfoRef* valueinfo = nullptr; + graph_api_->OrtGraph_GetValueInfo(graph_viewer, required_inputs[0], &valueinfo); + ONNXTensorElementDataType data_type = valueinfo->data_type; + graph_api_->OrtGraph_ReleaseValueInfo(valueinfo); + graph_api_->ReleaseCharArray(required_inputs); + if (data_type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) return "FP32"; + if (data_type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) return "FP16"; + } + return ""; + }(graph_viewer); + + openvino_ep::GetCapability obj(graph_viewer, + p->global_context_->device_type, + p->global_context_->enable_qdq_optimizer); + *cnt = obj.Execute(indexed_sub_graph); + p->global_context_->is_wholly_supported_graph = obj.IsWhollySupportedGraph(); + }; + + OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { + return nullptr; + }; +} + +OpenVINOExecutionProviderFactory::OpenVINOExecutionProviderFactory() { + OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* { + ProviderOptions options; + for (size_t i = 0; i < option_size; i++) options[ep_option_keys[i]] = ep_option_values[i]; + std::unique_ptr ret = std::make_unique(OpenVINOEp.c_str(), std::move(options)); + return ret.release(); + }; +} +} // namespace onnxruntime + +#ifdef __cplusplus +extern "C" { +#endif +OrtExecutionProviderFactory* RegisterCustomEp() { + std::unique_ptr ret = std::make_unique(); + return ret.release(); +} +#ifdef __cplusplus +} +#endif diff --git a/samples/openvino/openvino_execution_provider.h b/samples/openvino/openvino_execution_provider.h new file mode 100644 index 0000000000000..eea702145d5a8 --- /dev/null +++ b/samples/openvino/openvino_execution_provider.h @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
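+//
+// Usage sketch (hypothetical host code, not part of this patch): the shared
+// library exports RegisterCustomEp below, so a host could do, e.g. on Linux:
+//
+//   void* lib = dlopen("./libOpenVINOEp.so", RTLD_NOW);
+//   auto register_fn = reinterpret_cast<OrtExecutionProviderFactory* (*)()>(
+//       dlsym(lib, "RegisterCustomEp"));
+//   OrtExecutionProviderFactory* factory = register_fn();
+//   const char* keys[] = {"device_type"};  // provider option (assumed key)
+//   const char* values[] = {"CPU"};
+//   OrtExecutionProvider* ep = factory->CreateExecutionProvider(
+//       factory, keys, values, 1);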
+ +#pragma once +#include +#include +#include +#include +#include +#include "core/session/onnxruntime_c_api_ep.h" +#include "core/framework/provider_options.h" +#include "backend_manager.h" + +#ifdef _WIN32 +#define EXPORT_API __declspec(dllexport) +#else +#define EXPORT_API +#endif + +namespace onnxruntime { +static void print_build_options() { + std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; + std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " + << "you want to build" + << std::endl; + std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " + << "are ['CPU','GPU','NPU']" + << std::endl; + std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << std::endl; +} + +static std::vector split(const std::string& s, char delim) { + std::vector result; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + result.push_back(item); + } + return result; +} + +static std::vector parseDevices(const std::string& device_string) { + std::string comma_separated_devices = device_string; + if (comma_separated_devices.find(":") != std::string::npos) { + comma_separated_devices = comma_separated_devices.substr(comma_separated_devices.find(":") + 1); + } + auto devices = split(comma_separated_devices, ','); + if (devices.size() < 2) { + print_build_options(); + throw std::runtime_error("Invalid device string: " + device_string); + } + std::vector dev_options = {"CPU", "GPU", "NPU"}; + for (std::string dev : devices) { + if (!std::count(dev_options.begin(), dev_options.end(), dev)) { + print_build_options(); + throw std::runtime_error("Invalid device string: " + device_string); + } + } + return devices; +} + +// Information needed to construct OpenVINO execution providers. 
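+// (For HETERO/MULTI/AUTO device strings, parseDevices above validates the
+//  list; e.g. parseDevices("HETERO:GPU,CPU") yields {"GPU", "CPU"}, and any
+//  device outside CPU/GPU/NPU, or a list of fewer than two devices, is
+//  rejected with the build-options message.)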
+struct OpenVINOExecutionProviderInfo { + std::string device_type_{""}; + std::string precision_{""}; + bool enable_npu_fast_compile_{false}; + size_t num_of_threads_{0}; + std::string cache_dir_{""}; + std::string model_priority_{""}; + int num_streams_{1}; + void* context_{NULL}; + bool enable_opencl_throttling_{false}; + bool disable_dynamic_shapes_{false}; + bool export_ep_ctx_blob_{false}; + bool enable_qdq_optimizer_{false}; + bool disable_cpu_fallback_{false}; + + OpenVINOExecutionProviderInfo() = delete; + + explicit OpenVINOExecutionProviderInfo(std::string dev_type, std::string precision, bool enable_npu_fast_compile, + size_t num_of_threads, std::string cache_dir, std::string model_priority, + int num_streams, void* context, bool enable_opencl_throttling, + bool disable_dynamic_shapes, bool export_ep_ctx_blob, + bool enable_qdq_optimizer, bool disable_cpu_fallback) + : precision_(precision), + enable_npu_fast_compile_(enable_npu_fast_compile), + num_of_threads_(num_of_threads), + cache_dir_(std::move(cache_dir)), + model_priority_(model_priority), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + disable_dynamic_shapes_(disable_dynamic_shapes), + export_ep_ctx_blob_(export_ep_ctx_blob), + enable_qdq_optimizer_(enable_qdq_optimizer), + disable_cpu_fallback_(disable_cpu_fallback) { + std::set ov_supported_device_types = {"CPU", "GPU", + "GPU.0", "GPU.1", "NPU"}; + if (dev_type == "") { +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" +// << "No runtime device selection option provided."; +#if defined OPENVINO_CONFIG_CPU + device_type_ = "CPU"; + precision_ = "FP32"; +#elif defined OPENVINO_CONFIG_GPU + device_type_ = "GPU"; + precision_ = "FP16"; +#elif defined OPENVINO_CONFIG_NPU + device_type_ = "NPU"; + precision_ = "FP16"; +#elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO +#ifdef DEVICE_NAME +#define DEVICE DEVICE_NAME +#endif + dev_type = DEVICE; + + if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { + std::vector devices = parseDevices(dev_type); + precision_ = "FP16"; + if (devices[0] == "CPU") { + precision_ = "FP32"; + } + device_type_ = std::move(dev_type); + } +#endif + } else if (ov_supported_device_types.find(dev_type) != ov_supported_device_types.end()) { + device_type_ = std::move(dev_type); + } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { + std::vector devices = parseDevices(dev_type); + device_type_ = dev_type; + } else { + throw std::runtime_error("Invalid device string: " + dev_type); + } +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" +// << "Choosing Device: " << device_type_ << " , Precision: " << precision_; + } +}; + +//struct OpenVINOEPFunctionState { +// AllocateFunc allocate_func = nullptr; +// DestroyFunc destroy_func = nullptr; +// AllocatorHandle allocator_handle = nullptr; +// std::shared_ptr backend_manager; +//}; + +// Logical device representation. 
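+// (The constructor installs GetCapability and Compile as C-style callbacks on
+//  the OrtExecutionProvider base; see openvino_execution_provider.cc.)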
+class OpenVINOExecutionProvider : public OrtExecutionProvider { + public: + OpenVINOExecutionProvider(const char* ep_type, const ProviderOptions& provider_options); + ~OpenVINOExecutionProvider() = default; + + private: + std::unique_ptr global_context_; + openvino_ep::EPCtxHandler ep_ctx_handle_{}; + static const OrtApi* api_; + static const OrtGraphApi* graph_api_; +}; + +struct OpenVINOExecutionProviderFactory : public OrtExecutionProviderFactory { + OpenVINOExecutionProviderFactory(); +}; +} + +#ifdef __cplusplus +extern "C" { +#endif + +EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp(); + +#ifdef __cplusplus +} +#endif diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc new file mode 100644 index 0000000000000..cae3c60154b82 --- /dev/null +++ b/samples/openvino/openvino_utils.cc @@ -0,0 +1,25 @@ +#include +#include "openvino_utils.h" + +namespace onnxruntime { + std::string GetEnvironmentVar(const std::string& var_name) { +// TODO(leca): #ifdef _WIN32 +//#endif + constexpr DWORD kBufferSize = 32767; + + // Create buffer to hold the result + std::string buffer(kBufferSize, '\0'); + + // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. + // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. + // Therefore, If the function succeeds, kBufferSize should be larger than char_count. + auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer.data(), kBufferSize); + + if (kBufferSize > char_count) { + buffer.resize(char_count); + return buffer; + } + + return std::string(); + } +} diff --git a/samples/openvino/openvino_utils.h b/samples/openvino/openvino_utils.h new file mode 100644 index 0000000000000..3498657e53e35 --- /dev/null +++ b/samples/openvino/openvino_utils.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
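+//
+// GetEnvironmentVar is currently Windows-only (see the TODO(leca) in
+// openvino_utils.cc). A portable variant might look like the following
+// sketch, assuming plain std::getenv (<cstdlib>) is acceptable on POSIX:
+//
+//   std::string GetEnvironmentVar(const std::string& var_name) {
+//   #ifndef _WIN32
+//     const char* val = std::getenv(var_name.c_str());
+//     return val == nullptr ? std::string() : std::string(val);
+//   #else
+//     ...  // GetEnvironmentVariableA path as implemented in the .cc
+//   #endif
+//   }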
+ +#pragma once +#include + +constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; +static const std::string OpenVINOEp = "OpenVINOEp"; + +namespace onnxruntime { + std::string GetEnvironmentVar(const std::string& var_name); +} diff --git a/samples/openvino/ov_interface.cc b/samples/openvino/ov_interface.cc new file mode 100644 index 0000000000000..af6c537c15a0b --- /dev/null +++ b/samples/openvino/ov_interface.cc @@ -0,0 +1,253 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_interface.h" +#include "backend_utils.h" + +using Exception = ov::Exception; + +namespace onnxruntime { +namespace openvino_ep { + +const std::string log_tag = "[OpenVINO-EP] "; + +#ifndef NDEBUG +void printDebugInfo(const ov::CompiledModel& obj) { + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + // output of the actual settings that the device selected + auto supported_properties = obj.get_property(ov::supported_properties); + std::cout << "Model:" << std::endl; + for (const auto& cfg : supported_properties) { + if (cfg == ov::supported_properties) + continue; + auto prop = obj.get_property(cfg); + if (cfg == ov::device::properties) { + auto devices_properties = prop.as(); + for (auto& item : devices_properties) { + std::cout << " " << item.first << ": " << std::endl; + for (auto& item2 : item.second.as()) { + OPENVINO_SUPPRESS_DEPRECATED_START + if (item2.first == ov::supported_properties || item2.first == "SUPPORTED_CONFIG_KEYS)" || + item2.first == "SUPPORTED_METRICS") + continue; + OPENVINO_SUPPRESS_DEPRECATED_END + std::cout << " " << item2.first << ": " << item2.second.as() << std::endl; + } + } + } else { + std::cout << " " << cfg << ": " << prop.as() << std::endl; + } + } + } +} +#endif + +std::shared_ptr OVCore::ReadModel(const std::string& model, const std::string& model_path) const { + try { + std::istringstream modelStringStream(model); + std::istream& modelStream = modelStringStream; + // Try to load with FrontEndManager + ov::frontend::FrontEndManager manager; + ov::frontend::FrontEnd::Ptr FE; + ov::frontend::InputModel::Ptr inputModel; + + ov::AnyVector params{&modelStream, model_path}; + + FE = manager.load_by_model(params); + if (FE) { + inputModel = FE->load(params); + return FE->convert(inputModel); + } else { + throw std::runtime_error(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); + //return NULL; + } + } catch (const Exception& e) { + throw std::runtime_error(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); + } catch (...) { + throw std::runtime_error(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); + } +} + +OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name) { + ov::CompiledModel obj; + try { + obj = oe.compile_model(ie_cnn_network, hw_target, device_config); +#ifndef NDEBUG + printDebugInfo(obj); +#endif + OVExeNetwork exe(obj); + return exe; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph: " + name + e.what()); + } catch (...) 
{ + throw std::runtime_error(log_tag + " Exception while Loading Network for graph " + name); + } +} + +OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, + std::string hw_target, + std::string precision, + std::string cache_dir, + const ov::AnyMap& device_config, + std::string name) { + ov::CompiledModel obj; + try { + if (hw_target == "AUTO:GPU,CPU") { + obj = oe.compile_model(onnx_model, ov::Tensor(), + "AUTO", + ov::device::priorities("GPU", "CPU"), + ov::device::properties("GPU", {ov::cache_dir(cache_dir), + ov::hint::inference_precision(precision)})); + } else { + obj = oe.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); + } +#ifndef NDEBUG + printDebugInfo(obj); +#endif + OVExeNetwork exe(obj); + return exe; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph: " + name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph " + name); + } +} + +OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name) { + try { + auto obj = oe.import_model(*model_stream, hw_target, device_config); +#ifndef NDEBUG + printDebugInfo(obj); +#endif + OVExeNetwork exe(obj); + return exe; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph: " + name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph " + name); + } +} + +void OVCore::SetCache(std::string cache_dir_path, std::string device_type) { + if (device_type != "AUTO:GPU,CPU") { + oe.set_property(ov::cache_dir(cache_dir_path)); + } +} + +#ifdef IO_BUFFER_ENABLED +OVExeNetwork OVCore::CompileModel(std::shared_ptr& model, + OVRemoteContextPtr context, std::string name) { + try { + auto obj = oe.compile_model(model, *context); +#ifndef NDEBUG + printDebugInfo(obj); +#endif + return OVExeNetwork(obj); + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph: " + name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph " + name); + } +} +OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, + OVRemoteContextPtr context, std::string name) { + try { + auto obj = oe.import_model(*model_stream, *context); +#ifndef NDEBUG + printDebugInfo(obj); +#endif + OVExeNetwork exe(obj); + return exe; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph: " + name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Exception while Loading Network for graph " + name); + } +} +#endif + +std::vector OVCore::GetAvailableDevices() { + auto available_devices = oe.get_available_devices(); + return available_devices; +} + +void OVCore::SetStreams(const std::string& device_type, int num_streams) { + oe.set_property(device_type, {ov::num_streams(num_streams)}); +} + +OVInferRequest OVExeNetwork::CreateInferRequest() { + try { + auto infReq = obj.create_infer_request(); + OVInferRequest inf_obj(std::move(infReq)); + return inf_obj; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + "Exception while creating InferRequest object: " + e.what()); + } catch (...) 
{ + throw std::runtime_error(log_tag + "Exception while creating InferRequest object."); + } +} + +OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { + try { + auto tobj = ovInfReq.get_tensor(input_name); + OVTensorPtr blob = std::make_shared(tobj); + return blob; + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Cannot access IE Blob for input: " + input_name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Cannot access IE Blob for input: " + input_name); + } +} + +void OVInferRequest::SetTensor(std::string name, OVTensorPtr& blob) { + try { + ovInfReq.set_tensor(name, *(blob.get())); + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Cannot set Remote Blob for output: " + name + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Cannot set Remote Blob for output: " + name); + } +} + +void OVInferRequest::StartAsync() { + try { + ovInfReq.start_async(); + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); + } +} + +void OVInferRequest::Infer() { + try { + ovInfReq.infer(); + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); + } +} + +void OVInferRequest::WaitRequest() { + try { + ovInfReq.wait(); + } catch (const Exception& e) { + throw std::runtime_error(log_tag + " Wait Model Failed: " + e.what()); + } catch (...) { + throw std::runtime_error(log_tag + " Wait Mode Failed"); + } +} + +void OVInferRequest::QueryStatus() { + std::cout << "ovInfReq.query_state()" + << " "; +} +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ov_interface.h b/samples/openvino/ov_interface.h new file mode 100644 index 0000000000000..af6f252feb2ce --- /dev/null +++ b/samples/openvino/ov_interface.h @@ -0,0 +1,99 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include + +#include "openvino/openvino.hpp" +#include "openvino/pass/convert_fp32_to_fp16.hpp" +#include "openvino/frontend/manager.hpp" + +#ifdef IO_BUFFER_ENABLED +#include +#endif + +#include + +namespace onnxruntime { +namespace openvino_ep { +class OVCore; +class OVInferRequest; +class OVExeNetwork; + +typedef ov::Tensor OVTensor; +typedef ov::ProfilingInfo OVProfilingInfo; +typedef ov::Model OVNetwork; +typedef std::shared_ptr OVInferRequestPtr; +typedef std::shared_ptr OVTensorPtr; + +#ifdef IO_BUFFER_ENABLED +typedef ov::intel_gpu::ocl::ClContext* OVRemoteContextPtr; +typedef ov::RemoteContext OVRemoteContext; +#endif + +class OVCore { + ov::Core oe; + + public: + std::shared_ptr ReadModel(const std::string& model_stream, const std::string& model_path) const; + OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name); + OVExeNetwork CompileModel(const std::string& onnx_model, + std::string hw_target, + std::string precision, + std::string cache_dir, + const ov::AnyMap& device_config, + std::string name); + OVExeNetwork ImportModel(std::shared_ptr model_stream, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name); +#ifdef IO_BUFFER_ENABLED + OVExeNetwork CompileModel(std::shared_ptr& model, + OVRemoteContextPtr 
+#ifdef IO_BUFFER_ENABLED
+  OVExeNetwork CompileModel(std::shared_ptr<const OVNetwork>& model,
+                            OVRemoteContextPtr context,
+                            std::string name);
+  OVExeNetwork ImportModel(std::shared_ptr<std::istringstream> model_stream,
+                           OVRemoteContextPtr context,
+                           std::string name);
+#endif
+  std::vector<std::string> GetAvailableDevices();
+  void SetCache(std::string cache_dir_path, std::string device_type);
+  ov::Core& Get() { return oe; }
+  void SetStreams(const std::string& device_type, int num_streams);
+};
+
+class OVExeNetwork {
+  ov::CompiledModel obj;
+
+ public:
+  explicit OVExeNetwork(ov::CompiledModel md) : obj(md) {}
+  OVExeNetwork() : obj(ov::CompiledModel()) {}
+  ov::CompiledModel& Get() { return obj; }
+  OVInferRequest CreateInferRequest();
+};
+
+class OVInferRequest {
+  ov::InferRequest ovInfReq;
+
+ public:
+  OVTensorPtr GetTensor(const std::string& name);
+  void SetTensor(std::string name, OVTensorPtr& blob);
+  void StartAsync();
+  void Infer();
+  void WaitRequest();
+  void QueryStatus();
+  explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {}
+  OVInferRequest() : ovInfReq(ov::InferRequest()) {}
+  ov::InferRequest& GetNewObj() {
+    return ovInfReq;
+  }
+};
+}  // namespace openvino_ep
+}  // namespace onnxruntime
diff --git a/samples/openvino/ov_versions/capability.cc b/samples/openvino/ov_versions/capability.cc
new file mode 100644
index 0000000000000..a9417261057e7
--- /dev/null
+++ b/samples/openvino/ov_versions/capability.cc
@@ -0,0 +1,226 @@
+// Copyright (C) 2019- Intel Corporation
+// Licensed under the MIT License
+#include
+#include
+
+#include "../backend_utils.h"
+#include "../backend_manager.h"
+#include "capability.h"
+//#include "core/providers/openvino/ov_versions/utils.h"
+#include "openvino/core/version.hpp"
+
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244 4245 5208)
+#elif __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+#if defined(_MSC_VER)
+#pragma warning(default : 4244 4245)
+#elif __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+namespace onnxruntime {
+namespace openvino_ep {
+
+const OrtGraphApi* GetCapability::graph_api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION)->GetGraphApi(ORT_API_VERSION);
+
+// Constructor
+GetCapability::GetCapability(const OrtGraphViewer* graph_viewer_param,
+                             const std::string device_type_param,
+                             const bool enable_qdq_optimizer)
+    : graph_viewer_(graph_viewer_param), device_type_(device_type_param) {
+  bool npu_qdq_optimizer_enabled = false;
+  if (device_type_.find("NPU") != std::string::npos) {
+    device_type_ = "CPU";
+    if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true;
+  }
+//#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
+//  data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, npu_qdq_optimizer_enabled);
+//#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
+//  data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, npu_qdq_optimizer_enabled);
+//#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
+//  data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, npu_qdq_optimizer_enabled);
+//#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
+//  data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled);
+//#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1
+//  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+//#else
+//  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+//#endif
+}
+
+size_t GetCapability::Execute(OrtIndexedSubGraph*** indexed_sub_graph) {
+  // Check if
it is a subgraph + bool is_subgraph = false; +// graph_api_->OrtGraph_IsSubgraph(graph_viewer_, &is_subgraph); + const char* graph_name = nullptr; + graph_api_->OrtGraph_GetName(graph_viewer_, &graph_name); + if (is_subgraph && !strcmp(graph_name, "tf2onnx")) return 0; + + // This is a list of initializers that nGraph considers as constants. Example weights, reshape shape etc. + std::unordered_set ng_required_initializers; + +// const auto unsupported_nodes = data_ops_->GetUnsupportedNodeIndices(ng_required_initializers); +//#ifndef NDEBUG +// if (openvino_ep::backend_utils::IsDebugEnabled()) { +// std::cout << "No of unsupported nodes " << unsupported_nodes.size() << std::endl; +// for (size_t i = 0; i < unsupported_nodes.size(); i++) { +// const Node* node = graph_viewer_.GetNode(unsupported_nodes[i]); +// std::cout << "Unsupported node op " << node->OpType() << std::endl; +// } +// } +//#endif +// +// // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. +// if (unsupported_nodes.empty()) { +// std::vector inputs; +// std::vector outputs; +// // Fill inputs with names +// std::for_each(graph_viewer_.GetInputs().begin(), graph_viewer_.GetInputs().end(), +// [&inputs](const NodeArg* node_arg) { inputs.push_back(node_arg->Name()); }); +// +// /* In scenarios, when there are no inputs or all inputs being initializers, +// ConstantFolding optimization in onnxruntime pre-computes the value.*/ +// if (inputs.empty()) { +// return result; +// } +// +// const std::vector& nodes = graph_viewer_.GetNodesInTopologicalOrder(); +// +// const Node* node = graph_viewer_.GetNode(nodes[0]); +// +// // Handle cases where lone, reoccuring Ops in smaller models cannot be supported in OpenVINO +// // If only a node of the same lone,unsupported type is present, then do not proceed with the subgraph +// if (nodes.size() <= 3) { +// if (data_ops_->IsOpSupportedOnlyInModel(node->OpType())) { +// return result; +// } +// } +// +// // Nodes that work well in models but not as a single node +// if (nodes.size() == 1) { +// // If reshape is not an intermediate node, shape needs to be an initializer +// if (data_ops_->SpecialConditionForClusterSizeOne(ng_required_initializers, node)) { +// return result; +// } +// } +// +// // Initializers need to be part of meta_def->inputs +// std::for_each(ng_required_initializers.begin(), ng_required_initializers.end(), +// [&inputs](const std::string& initializer) { inputs.push_back(initializer); }); +// +// // Fill outputs with names +// std::for_each(graph_viewer_.GetOutputs().begin(), graph_viewer_.GetOutputs().end(), +// [&outputs](const NodeArg* node_arg) { outputs.push_back(node_arg->Name()); }); +// +// // Create and add this graph to result. 
+// AppendClusterToSubGraph(graph_viewer_.GetNodesInTopologicalOrder(), inputs, outputs, result); +// +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is fully supported by OpenVINO"; +// // Enable CI Logs +// if (backend_utils::IsCILogEnabled()) { +// std::cout << "Model is fully supported on OpenVINO" << std::endl; +// } +// is_wholly_supported_graph_ = true; +// +// } else { // unsupported_nodes_idx.empty() +//#if defined(OPENVINO_DISABLE_GRAPH_PARTITION) // disables graph partition at build time +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set"; +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, " +// << "so making the full model fall back to default CPU Execution Provider"; +// return result; +//#endif +// +// std::vector modified_unsupported_nodes; +// for (const NodeIndex& node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { +// if (find(unsupported_nodes.begin(), unsupported_nodes.end(), node_idx) != unsupported_nodes.end()) { +// modified_unsupported_nodes.push_back(node_idx); +// } else { +// const Node* node = graph_viewer_.GetNode(node_idx); +// const std::string& optype = node->OpType(); +// if (data_ops_->InsertNode(optype)) { +// modified_unsupported_nodes.push_back(node_idx); +// } +// } +// } +// +// auto ng_clusters = GetPartitionedClusters(graph_viewer_.GetNodesInTopologicalOrder(), modified_unsupported_nodes); +// +// auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); +// +// int no_of_clusters = 0; +// +// for (auto this_cluster : connected_clusters) { +// // If subgraph has less then three, graph is considered trivial +// if (this_cluster.size() < 3) { +// continue; +// } +// +// std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; +// +// GetInputsOutputsOfCluster(graph_viewer_, +// this_cluster, +// ng_required_initializers, +// cluster_graph_inputs, +// cluster_inputs, +// cluster_outputs); +// +// bool omit_subgraph = false; +// // Omitting zero dim subgraphs +// for (auto index : this_cluster) { +// const Node* node = graph_viewer_.GetNode(index); +// if (data_ops_->DoNotOmitSubGraph(node->OpType())) { +// for (const auto& input : node->InputDefs()) { +// const auto& input_name = input->Name(); +// auto it = find(cluster_graph_inputs.begin(), cluster_graph_inputs.end(), input_name); +// if (it != cluster_graph_inputs.end()) { +// omit_subgraph = true; +// break; +// } +// } +// } +// +// if (node->OpType() == "Conv" || node->OpType() == "Identity") { +// const auto& output_name = node->OutputDefs()[0]->Name(); +// auto it = find(cluster_outputs.begin(), cluster_outputs.end(), output_name); +// if (it != cluster_outputs.end() && node->GetOutputEdgesCount() != 0) { +// omit_subgraph = true; +// break; +// } +// } +// +// std::map slice_map; +// if (node->OpType() == "Slice") { +// auto input = node->InputDefs()[0]; +// const auto& input_name = input->Name(); +// auto it = find(cluster_graph_inputs.begin(), cluster_graph_inputs.end(), input_name); +// if (it != cluster_graph_inputs.end()) { +// if (slice_map.count(input_name) == 0) { +// slice_map[input_name] = 1; +// } else { +// omit_subgraph = true; +// break; +// } +// } +// } +// } +// if (omit_subgraph) +// continue; +// +// /* In scenarios, when there are no inputs or all inputs being initializers, +// ConstantFolding optimization in onnxruntime pre-computes the value.*/ +// if (!cluster_inputs.empty()) { +// AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); +// 
no_of_clusters++; +// } +// } +// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; +// } +// + return 0; +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ov_versions/capability.h b/samples/openvino/ov_versions/capability.h new file mode 100644 index 0000000000000..210c3a3f92831 --- /dev/null +++ b/samples/openvino/ov_versions/capability.h @@ -0,0 +1,32 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once +#include +#include +#include +//#include "core/providers/openvino/ov_versions/data_ops.h" + +namespace onnxruntime { +namespace openvino_ep { + +class GetCapability { + private: + const OrtGraphViewer* graph_viewer_; + std::string device_type_; +// DataOps* data_ops_; + bool is_wholly_supported_graph_ = false; + static const OrtGraphApi* graph_api_; + + public: + GetCapability(const OrtGraphViewer* graph_viewer_param, + const std::string device_type_param, + const bool enable_qdq_optimizer); + size_t Execute(OrtIndexedSubGraph***); + bool IsWhollySupportedGraph() { + return is_wholly_supported_graph_; + } +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index ace9d73dd5e36..8e7f6f1fbd923 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -311,7 +311,7 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { // fingerprint current graph by hashing graph inputs // const std::vector& input_names = nullptr; - const char** input_names = nullptr; + const char** input_names = nullptr; // TODO(leca): release input_names size_t input_count = 0; graph_api->OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_count); for (size_t i = 0; i < input_count; ++i) { From bc6561344698100e2817759e47bc8b42f1f974aa Mon Sep 17 00:00:00 2001 From: jslhcl Date: Wed, 30 Oct 2024 18:25:38 -0700 Subject: [PATCH 58/81] openvino, GetCapability almost ready --- .../core/session/onnxruntime_c_api_ep.h | 9 + .../core/session/onnxruntime_c_api_ep.cc | 7 + onnxruntime/core/session/ort_apis_ep.h | 2 + .../openvino/openvino_execution_provider.cc | 2 + samples/openvino/openvino_utils.cc | 37 + samples/openvino/openvino_utils.h | 6 + samples/openvino/ov_versions/capability.cc | 360 ++++---- samples/openvino/ov_versions/capability.h | 4 +- samples/openvino/ov_versions/data_ops.cc | 858 ++++++++++++++++++ samples/openvino/ov_versions/data_ops.h | 98 ++ samples/openvino/ov_versions/utils.cc | 313 +++++++ samples/openvino/ov_versions/utils.h | 54 ++ 12 files changed, 1585 insertions(+), 165 deletions(-) create mode 100644 samples/openvino/ov_versions/data_ops.cc create mode 100644 samples/openvino/ov_versions/data_ops.h create mode 100644 samples/openvino/ov_versions/utils.cc create mode 100644 samples/openvino/ov_versions/utils.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index e6525a2c512c6..df6ad6ecb85f9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -134,6 +134,15 @@ ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out); */ ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); +/** \brief Check if the graph is a subgraph + * TODO(leca): 
maybe deprecate OrtGraph_IsSubgraph?
+ *
+ * \param[in] graph The graph to query
+ * \param[out] out True if the graph is a subgraph
+ *
+ */
+ORT_API2_STATUS(OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out);
+
 /** \brief Get the parent node of the graph
  *
  * \param[in] graph The graph to query
diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc
index fd5781bc1df49..30740b9773b83 100644
--- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc
@@ -43,6 +43,12 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetParentGraph, const OrtGraph* graph
   return nullptr;
 }
 
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out) {
+  const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
+  *out = graph_viewer->IsSubgraph();
+  return nullptr;
+}
+
 ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node) {
   const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
   *parent_node = reinterpret_cast<const OrtNode*>(graph_viewer->ParentNode());
@@ -789,6 +795,7 @@ static constexpr OrtGraphApi ort_graph_api = {
     &OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder,
     &OrtGraphApis::OrtGraph_IsSubgraph,
     &OrtGraphApis::OrtGraph_GetParentGraph,
+    &OrtGraphApis::OrtGraph_IsSubgraph2,
     &OrtGraphApis::OrtGraph_GetParenNode,
     &OrtGraphApis::OrtGraph_GetModelPath,
     &OrtGraphApis::OrtGraph_GetOrtGraph,
diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h
index 24337a6bf652a..7e010e8f8a2c4 100644
--- a/onnxruntime/core/session/ort_apis_ep.h
+++ b/onnxruntime/core/session/ort_apis_ep.h
@@ -13,6 +13,8 @@ ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out)
 
 ORT_API_STATUS_IMPL(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph);
 
+ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out);
+
 ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node);
 
 ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path);
diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc
index d356362ddba8e..0a6e5c16f4177 100644
--- a/samples/openvino/openvino_execution_provider.cc
+++ b/samples/openvino/openvino_execution_provider.cc
@@ -63,6 +63,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const
     OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr {
         return nullptr;
     };
+
+    //OrtExecutionProvider::ReleaseIndexedSubGraphs
 }
 
 OpenVINOExecutionProviderFactory::OpenVINOExecutionProviderFactory() {
diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc
index cae3c60154b82..9c7b16e684f0a 100644
--- a/samples/openvino/openvino_utils.cc
+++ b/samples/openvino/openvino_utils.cc
@@ -22,4 +22,41 @@ namespace onnxruntime {
 
     return std::string();
   }
+
+  OrtStatus* ForEachNodeDef(const OrtGraphApi* graph_api, const OrtGraphViewer* graph, const OrtNode* node,
+                            std::function<void(const char* name, const OrtValueInfoRef* value_info, bool is_input)> func) {
+    size_t input_count = 0;
+    graph_api->OrtNode_GetNumInputs(node, &input_count);
+    for (size_t i = 0; i < input_count; i++) {
+      const char* input_name = nullptr;
+      graph_api->OrtNode_GetIthInputName(node, i, &input_name);
+      OrtValueInfoRef* value_info = nullptr;
+      graph_api->OrtGraph_GetValueInfo(graph, input_name, &value_info);
+      func(input_name, value_info, true);
+      graph_api->OrtGraph_ReleaseValueInfo(value_info);
+    }
+
+    size_t implicit_input_count = 0;
+    graph_api->OrtNode_GetImplicitInputSize(node, &implicit_input_count);
+    for (size_t i = 0; i < implicit_input_count; i++) {
+      const char* input_name = nullptr;
+      graph_api->OrtNode_GetIthImplicitInputName(node, i, &input_name);
+      OrtValueInfoRef* value_info = nullptr;
+      graph_api->OrtGraph_GetValueInfo(graph, input_name, &value_info);
+      func(input_name, value_info, true);
+      graph_api->OrtGraph_ReleaseValueInfo(value_info);
+    }
+
+    size_t output_count = 0;
+    graph_api->OrtNode_GetNumOutputs(node, &output_count);
+    for (size_t i = 0; i < output_count; i++) {
+      const char* output_name = nullptr;
+      graph_api->OrtNode_GetIthOutputName(node, i, &output_name);
+      OrtValueInfoRef* value_info = nullptr;
+      graph_api->OrtGraph_GetValueInfo(graph, output_name, &value_info);
+      func(output_name, value_info, false);
+      graph_api->OrtGraph_ReleaseValueInfo(value_info);
+    }
+    return nullptr;
+  }
 }
diff --git a/samples/openvino/openvino_utils.h b/samples/openvino/openvino_utils.h
index 3498657e53e35..5fe3f5ff38cfa 100644
--- a/samples/openvino/openvino_utils.h
+++ b/samples/openvino/openvino_utils.h
@@ -3,10 +3,16 @@
 #pragma once
 
 #include
+#include <functional>
+#include "core/session/onnxruntime_c_api_ep.h"
 
+constexpr const char* kOnnxDomain = "";
 constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
 static const std::string OpenVINOEp = "OpenVINOEp";
 
 namespace onnxruntime {
+  using NodeIndex = size_t;
   std::string GetEnvironmentVar(const std::string& var_name);
+  // TODO(leca): add name (const char*) into OrtValueInfoRef?
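+  // Visits every explicit input, implicit input, and output def of |node|: each is looked up
+  // in |graph| by name, handed to |func(name, value_info, is_input)|, and its OrtValueInfoRef
+  // is released again afterwards (see the implementation in openvino_utils.cc above).
+  // A minimal usage sketch with a hypothetical lambda:
+  //
+  //   ForEachNodeDef(graph_api, graph, node,
+  //                  [](const char* name, const OrtValueInfoRef* vi, bool is_input) {
+  //                    /* inspect vi->data_type and vi->shape here */
+  //                  });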
+  OrtStatus* ForEachNodeDef(const OrtGraphApi* graph_api, const OrtGraphViewer* graph, const OrtNode* node,
+                            std::function<void(const char* name, const OrtValueInfoRef* value_info, bool is_input)> func);
 }
diff --git a/samples/openvino/ov_versions/capability.cc b/samples/openvino/ov_versions/capability.cc
index a9417261057e7..c1cc771653818 100644
--- a/samples/openvino/ov_versions/capability.cc
+++ b/samples/openvino/ov_versions/capability.cc
@@ -6,7 +6,7 @@
 #include "../backend_utils.h"
 #include "../backend_manager.h"
 #include "capability.h"
-//#include "core/providers/openvino/ov_versions/utils.h"
+#include "utils.h"
 #include "openvino/core/version.hpp"
 
 #if defined(_MSC_VER)
@@ -36,25 +36,25 @@ GetCapability::GetCapability(const OrtGraphViewer* graph_viewer_param,
     device_type_ = "CPU";
     if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true;
   }
-//#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
-//  data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, npu_qdq_optimizer_enabled);
-//#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
-//  data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, npu_qdq_optimizer_enabled);
-//#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
-//  data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, npu_qdq_optimizer_enabled);
-//#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
-//  data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled);
-//#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1
-//  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
-//#else
-//  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
-//#endif
+#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
+  data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
+  data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
+  data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
+  data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1
+  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+#else
+  data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+#endif
 }
 
 size_t GetCapability::Execute(OrtIndexedSubGraph*** indexed_sub_graph) {
   // Check if it is a subgraph
   bool is_subgraph = false;
-//  graph_api_->OrtGraph_IsSubgraph(graph_viewer_, &is_subgraph);
+  graph_api_->OrtGraph_IsSubgraph2(graph_viewer_, &is_subgraph);
   const char* graph_name = nullptr;
   graph_api_->OrtGraph_GetName(graph_viewer_, &graph_name);
   if (is_subgraph && !strcmp(graph_name, "tf2onnx")) return 0;
@@ -62,164 +62,198 @@ size_t GetCapability::Execute(OrtIndexedSubGraph*** indexed_sub_graph) {
   // This is a list of initializers that nGraph considers as constants. Example weights, reshape shape etc.
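+  // The set is filled in by data_ops_->GetUnsupportedNodeIndices() below; whatever ends up in
+  // it is later appended to the subgraph's input list, since these initializers need to be
+  // part of meta_def->inputs for the compiled cluster.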
   std::unordered_set<std::string> ng_required_initializers;
 
-//  const auto unsupported_nodes = data_ops_->GetUnsupportedNodeIndices(ng_required_initializers);
-//#ifndef NDEBUG
-//  if (openvino_ep::backend_utils::IsDebugEnabled()) {
-//    std::cout << "No of unsupported nodes " << unsupported_nodes.size() << std::endl;
-//    for (size_t i = 0; i < unsupported_nodes.size(); i++) {
-//      const Node* node = graph_viewer_.GetNode(unsupported_nodes[i]);
-//      std::cout << "Unsupported node op " << node->OpType() << std::endl;
-//    }
-//  }
-//#endif
-//
-//  // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting.
+  const auto unsupported_nodes = data_ops_->GetUnsupportedNodeIndices(ng_required_initializers);
+#ifndef NDEBUG
+  if (openvino_ep::backend_utils::IsDebugEnabled()) {
+    std::cout << "No of unsupported nodes " << unsupported_nodes.size() << std::endl;
+    for (size_t i = 0; i < unsupported_nodes.size(); i++) {
+      const OrtNode* node = nullptr;
+      graph_api_->OrtGraph_GetOrtNode(graph_viewer_, unsupported_nodes[i], &node);
+      const char* optype = nullptr;
+      graph_api_->OrtNode_GetOpType(node, &optype);
+      std::cout << "Unsupported node op " << optype << std::endl;
+    }
+  }
+#endif
+
+  // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting.
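+  // |cache| accumulates one heap-allocated OrtIndexedSubGraph per supported cluster; at the
+  // end of Execute() the pointers are copied into the |indexed_sub_graph| out-parameter and
+  // the cluster count is returned to ORT.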
+  std::vector<OrtIndexedSubGraph*> cache;
+  if (unsupported_nodes.empty()) {
+    std::vector<std::string> inputs;
+    std::vector<std::string> outputs;
+    // Fill inputs with names
+    const char** input_names = nullptr;
+    size_t input_count = 0;
+    graph_api_->OrtGraph_GetRequiredInputs(graph_viewer_, &input_names, &input_count);
+    for (size_t i = 0; i < input_count; i++) inputs.push_back(std::string(input_names[i]));
+    graph_api_->ReleaseCharArray(input_names);
+
+    /* In scenarios, when there are no inputs or all inputs being initializers,
+    ConstantFolding optimization in onnxruntime pre-computes the value.*/
+    if (inputs.empty()) {
+      return 0;
+    }
+
+    const size_t* nodes = nullptr;
+    size_t num_nodes;
+    graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer_, 0, &nodes, &num_nodes);
+
+    const OrtNode* node = nullptr;
+    graph_api_->OrtGraph_GetOrtNode(graph_viewer_, nodes[0], &node);
+
+    // Handle cases where lone, recurring Ops in smaller models cannot be supported in OpenVINO
+    // If only a node of the same lone, unsupported type is present, then do not proceed with the subgraph
+    if (num_nodes <= 3) {
+      const char* optype = nullptr;
+      graph_api_->OrtNode_GetOpType(node, &optype);
+      if (data_ops_->IsOpSupportedOnlyInModel(optype)) {
+        return 0;
+      }
+    }
+
+    // Nodes that work well in models but not as a single node
+    if (num_nodes == 1) {
+      // If reshape is not an intermediate node, shape needs to be an initializer
+      if (data_ops_->SpecialConditionForClusterSizeOne(ng_required_initializers, node)) {
+        return 0;
+      }
+    }
+
+    // Initializers need to be part of meta_def->inputs
+    std::for_each(ng_required_initializers.begin(), ng_required_initializers.end(),
+                  [&inputs](const std::string& initializer) { inputs.push_back(initializer); });
+
+    // Fill outputs with names
+    size_t output_count = 0;
+    graph_api_->OrtGraph_GetOutputSize(graph_viewer_, &output_count);
+    for (size_t i = 0; i < output_count; i++) {
+      const char* output_name = nullptr;
+      graph_api_->OrtGraph_GetIthOutputName(graph_viewer_, i, &output_name);
+      outputs.push_back(std::string(output_name));
+    }
+
+    // Create and add this graph to result.
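+    // AppendClusterToSubGraph comes from ov_versions/utils.h (added in this patch set); it
+    // presumably packages the node indices plus the input/output name lists into a new
+    // OrtIndexedSubGraph and appends it to |cache|.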
+    AppendClusterToSubGraph(nodes, num_nodes, inputs, outputs, cache);
+
+    LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is fully supported by OpenVINO";
-//    // Enable CI Logs
+    // Enable CI Logs
+    if (backend_utils::IsCILogEnabled()) {
+      std::cout << "Model is fully supported on OpenVINO" << std::endl;
+    }
+    is_wholly_supported_graph_ = true;
+
+  } else {  // unsupported_nodes_idx.empty()
+#if defined(OPENVINO_DISABLE_GRAPH_PARTITION)  // disables graph partition at build time
 // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set";
 // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, "
 //                    << "so making the full model fall back to default CPU Execution Provider";
+    return 0;
+#endif
+
+    std::vector<NodeIndex> modified_unsupported_nodes;
+    const size_t* topo_order = nullptr;
+    size_t num_nodes = 0;
+    graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer_, 0, &topo_order, &num_nodes);
+    for (size_t i = 0; i < num_nodes; i++) {
+      const NodeIndex node_idx = topo_order[i];
+      if (find(unsupported_nodes.begin(), unsupported_nodes.end(), node_idx) != unsupported_nodes.end()) {
+        modified_unsupported_nodes.push_back(node_idx);
+      } else {
+        const OrtNode* node = nullptr;
+        graph_api_->OrtGraph_GetOrtNode(graph_viewer_, node_idx, &node);
+        const char* optype = nullptr;
+        graph_api_->OrtNode_GetOpType(node, &optype);
+        if (data_ops_->InsertNode(optype)) {
+          modified_unsupported_nodes.push_back(node_idx);
+        }
+      }
+    }
+
+    std::vector<size_t> topo_vec(topo_order, topo_order + num_nodes);
+    auto ng_clusters = GetPartitionedClusters(topo_vec, modified_unsupported_nodes);
+
+    auto connected_clusters = GetConnectedClusters(graph_api_, graph_viewer_, ng_clusters);
+
+    int no_of_clusters = 0;
+
+    for (auto this_cluster : connected_clusters) {
+      // If the subgraph has fewer than three nodes, it is considered trivial
+      if (this_cluster.size() < 3) {
+        continue;
+      }
+
+      std::vector<std::string> cluster_graph_inputs, cluster_inputs, cluster_outputs;
+
+//      GetInputsOutputsOfCluster(graph_api_, graph_viewer_,
+//                                this_cluster,
+//                                ng_required_initializers,
+//                                cluster_graph_inputs,
+//                                cluster_inputs,
+//                                cluster_outputs);
+
+      bool omit_subgraph = false;
+      // Omitting zero dim subgraphs
+      for (auto index : this_cluster) {
+        const OrtNode* node = nullptr;
+        graph_api_->OrtGraph_GetOrtNode(graph_viewer_, index, &node);
+        const char* optype = nullptr;
+        graph_api_->OrtNode_GetOpType(node, &optype);
+        if (data_ops_->DoNotOmitSubGraph(optype)) {
+          size_t num_inputs = 0;
+          graph_api_->OrtNode_GetNumInputs(node, &num_inputs);
+          for (size_t i = 0; i < num_inputs; i++) {
+            const char* input_name = nullptr;
+            graph_api_->OrtNode_GetIthInputName(node, i, &input_name);
+            auto it = find(cluster_graph_inputs.begin(), cluster_graph_inputs.end(), std::string(input_name));
+            if (it != cluster_graph_inputs.end()) {
+              omit_subgraph = true;
+              break;
+            }
+          }
+        }
+
+        if (strcmp(optype, "Conv") == 0 || strcmp(optype, "Identity") == 0) {
+          const char* output_name = nullptr;
+          graph_api_->OrtNode_GetIthOutputName(node, 0, &output_name);
+          auto it = find(cluster_outputs.begin(), cluster_outputs.end(), std::string(output_name));
+          size_t outputs_count = 0;
+          graph_api_->OrtNode_GetNumOutputs(node, &outputs_count);  // TODO(leca): equivalent to node->GetOutputEdgesCount()?
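+          // Note: OrtNode_GetNumOutputs counts the node's output definitions, while the original
+          // check used GetOutputEdgesCount(), i.e. the number of edges to downstream consumer
+          // nodes; the two can differ (a node feeding only graph outputs has outputs but no
+          // out-edges), which is presumably what the TODO above is flagging.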
+          if (it != cluster_outputs.end() && outputs_count != 0) {
+            omit_subgraph = true;
+            break;
+          }
+        }
+
+        std::map<std::string, int> slice_map;
+        if (!strcmp(optype, "Slice")) {
+          const char* input_name = nullptr;
+          graph_api_->OrtNode_GetIthInputName(node, 0, &input_name);
+          auto it = find(cluster_graph_inputs.begin(), cluster_graph_inputs.end(), std::string(input_name));
+          if (it != cluster_graph_inputs.end()) {
+            if (slice_map.count(input_name) == 0) {
+              slice_map[input_name] = 1;
+            } else {
+              omit_subgraph = true;
+              break;
+            }
+          }
+        }
+      }
+      if (omit_subgraph)
+        continue;
+
+      /* In scenarios, when there are no inputs or all inputs being initializers,
+      ConstantFolding optimization in onnxruntime pre-computes the value.*/
+      if (!cluster_inputs.empty()) {
+        AppendClusterToSubGraph(this_cluster.data(), this_cluster.size(), cluster_inputs, cluster_outputs, cache);
+        no_of_clusters++;
+      }
+    }
 // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters;
-  }
-
+  }
+  *indexed_sub_graph = new OrtIndexedSubGraph*[cache.size()];
+  for (size_t i = 0; i < cache.size(); i++) (*indexed_sub_graph)[i] = cache[i];
+  return cache.size();
 }
 
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/samples/openvino/ov_versions/capability.h b/samples/openvino/ov_versions/capability.h
index 210c3a3f92831..797dc272c3fcc 100644
--- a/samples/openvino/ov_versions/capability.h
+++ b/samples/openvino/ov_versions/capability.h
@@ -5,7 +5,7 @@
 #include
 #include
 #include
-//#include "core/providers/openvino/ov_versions/data_ops.h"
+#include "data_ops.h"
 
 namespace onnxruntime {
 namespace openvino_ep {
@@ -14,7 +14,7 @@ class GetCapability {
  private:
   const OrtGraphViewer* graph_viewer_;
   std::string device_type_;
-//  DataOps* data_ops_;
+  DataOps* data_ops_;
   bool is_wholly_supported_graph_ = false;
   static const OrtGraphApi* graph_api_;
 
diff --git a/samples/openvino/ov_versions/data_ops.cc b/samples/openvino/ov_versions/data_ops.cc
new file mode 100644
index 0000000000000..8bc821cbb4f72
--- /dev/null
+++ b/samples/openvino/ov_versions/data_ops.cc
@@ -0,0 +1,858 @@
+// Copyright (C) Intel Corporation
+// Licensed under the MIT License
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../backend_utils.h"
+#include "../backend_manager.h"
+#include "../ov_interface.h"
+#include "data_ops.h"
+#include "capability.h"
+#include "utils.h"
+
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244 4245 5208)
+#elif __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+// #include
+// #include
+#if defined(_MSC_VER)
+#pragma warning(default : 4244 4245)
+#elif __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+namespace onnxruntime {
+namespace openvino_ep {
+
+// Ops which are supported only in models(as intermediate nodes) and not in unit tests
+std::set<std::string> ops_supported_only_in_model = {
+    "Add",
+    "Cast",
+    "Celu",
+    "Concat",
+    "ConstantOfShape",
+    "DequantizeLinear",
+    "Dropout",
+    "Einsum",
+    "Exp",
+    "Expand",
+    "EyeLike",
+    "GatherElements",
+    "GatherND",
+    "GridSample",
+    "Identity",
+    "LayerNormalization",
+    "Loop",
+    "LSTM",
+    "NonMaxSuppression",
+    "NonZero",
+    "Not",
+    "OneHot",
+    "Pad",
+    "QuantizeLinear",
+    "RandomNormalLike",
+    "Range",
+    "ReduceMin",
+    "Resize",
+    "Round",
+    "Shape",
+    "Slice",
+    "Split",
+    "Tile",
+    "TopK",
+    "Trilu"};
+
+// Ops which are supported as functions (as composite ops)
+std::set<std::string> ops_supported_as_function = {
+    "LessOrEqual",
+    "GreaterOrEqual",
+    "LayerNormalization",
+    "Celu"};
+
+std::vector<SupportedOp> supported_op_mode = {
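+    // Each entry reads {ONNX op type, first OpenVINO version that supports it, devices it is
+    // enabled on}; these are the SupportedOp fields consulted by DataOps::op_is_supported() below.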
+ {"Abs", V_2020_4, {"CPU", "GPU"}}, + {"Acos", V_2020_4, {"CPU"}}, + {"Acos", V_2022_1, {"GPU"}}, + {"Acosh", V_2020_4, {"CPU"}}, + {"Acosh", V_2022_1, {"GPU"}}, + {"Add", V_2020_4, {"CPU", "GPU"}}, + {"And", V_2020_4, {"CPU", "GPU"}}, + {"ArgMax", V_2020_4, {"CPU"}}, + {"ArgMax", V_2021_1, {"GPU"}}, + {"ArgMin", V_2020_4, {"CPU"}}, + {"ArgMin", V_2022_1, {"GPU"}}, + {"Asin", V_2020_4, {"CPU", "GPU"}}, + {"Asinh", V_2020_4, {"CPU", "GPU"}}, + {"Atan", V_2020_4, {"CPU", "GPU"}}, + {"Atanh", V_2020_4, {"CPU"}}, + {"Atanh", V_2022_1, {"GPU"}}, + {"AveragePool", V_2020_4, {"CPU", "GPU"}}, + {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, + {"BitShift", V_2022_1, {"CPU"}}, + {"Cast", V_2020_4, {"CPU", "GPU"}}, + {"CastLike", V_2023_1, {"CPU", "GPU"}}, + {"Ceil", V_2020_4, {"GPU"}}, + {"Ceil", V_2021_4, {"CPU"}}, + {"Celu", V_2022_1, {"CPU", "GPU"}}, + {"Clip", V_2020_4, {"CPU", "GPU"}}, + {"Compress", V_2023_1, {"CPU", "GPU"}}, + {"Concat", V_2020_4, {"CPU", "GPU"}}, + {"Constant", V_2020_4, {"CPU", "GPU"}}, + {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}}, + {"Conv", V_2020_4, {"CPU", "GPU"}}, + {"ConvInteger", V_2022_1, {"CPU", "GPU"}}, + {"ConvTranspose", V_2020_4, {"CPU", "GPU"}}, + {"Cos", V_2020_4, {"CPU"}}, + {"Cos", V_2022_1, {"GPU"}}, + {"Cosh", V_2020_4, {"CPU"}}, + {"Cosh", V_2022_1, {"GPU"}}, + {"CumSum", V_2022_1, {"CPU", "GPU"}}, + {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, + {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, + {"Div", V_2020_4, {"CPU", "GPU"}}, + {"Dropout", V_2020_4, {"CPU", "GPU"}}, + {"Elu", V_2020_4, {"CPU", "GPU"}}, + {"Einsum", V_2023_1, {"CPU", "GPU"}}, + {"EPContext", V_2024_0, {"CPU", "GPU", "NPU"}}, + {"Equal", V_2020_4, {"CPU", "GPU"}}, + {"Erf", V_2020_4, {"CPU", "GPU"}}, + {"Exp", V_2020_4, {"CPU", "GPU"}}, + {"Expand", V_2022_1, {"CPU", "GPU"}}, + {"EyeLike", V_2022_1, {"CPU"}}, + {"Flatten", V_2020_4, {"CPU", "GPU"}}, + {"Floor", V_2020_4, {"CPU", "GPU"}}, + {"Gather", V_2020_4, {"CPU", "GPU"}}, + {"GatherElements", V_2022_2, {"CPU", "GPU"}}, + {"GatherND", V_2021_4, {"CPU", "GPU"}}, + {"Gelu", V_2023_1, {"CPU", "GPU"}}, + {"Gemm", V_2020_4, {"CPU", "GPU"}}, + {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}}, + {"Greater", V_2020_4, {"CPU", "GPU"}}, + {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, + {"GridSample", V_2022_3, {"CPU"}}, + {"GridSample", V_2023_0, {"GPU"}}, + {"HardMax", V_2023_1, {"CPU", "GPU"}}, + {"Identity", V_2020_4, {"CPU", "GPU"}}, + {"If", V_2022_3, {"CPU", "GPU"}}, + {"ImageScaler", V_2022_1, {"CPU", "GPU"}}, + {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, + {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, + {"HardMax", V_2022_1, {"CPU", "GPU"}}, + {"LayerNormalization", V_2023_0, {"CPU", "GPU"}}, + {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, + {"Less", V_2020_4, {"CPU", "GPU"}}, + {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, + {"Log", V_2020_4, {"CPU", "GPU"}}, + {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, + {"Loop", V_2021_4, {"CPU", "GPU"}}, + {"LpNormalization", V_2023_1, {"CPU", "GPU"}}, + {"LRN", V_2020_4, {"CPU", "GPU"}}, + {"LSTM", V_2020_4, {"CPU", "GPU"}}, + {"MatMul", V_2020_4, {"CPU", "GPU"}}, + {"MatMulInteger", V_2022_1, {"CPU"}}, + {"Max", V_2020_4, {"CPU", "GPU"}}, + {"MaxPool", V_2020_4, {"CPU", "GPU"}}, + {"Mean", V_2020_4, {"CPU", "GPU"}}, + {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}}, + {"Min", V_2020_4, {"CPU", "GPU"}}, + {"Mod", V_2022_1, {"CPU", "GPU"}}, + {"Mul", V_2020_4, {"CPU", "GPU"}}, + {"Neg", 
V_2020_4, {"CPU", "GPU"}}, + {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}}, + {"NonZero", V_2021_1, {"CPU"}}, + {"NonZero", V_2023_0, {"GPU"}}, + {"Not", V_2021_1, {"CPU", "GPU"}}, + {"Not", V_2020_4, {"CPU", "GPU"}}, + {"OneHot", V_2020_4, {"CPU", "GPU"}}, + {"Or", V_2022_1, {"CPU", "GPU"}}, + {"Pad", V_2020_4, {"CPU", "GPU"}}, + {"Pow", V_2020_4, {"CPU", "GPU"}}, + {"PRelu", V_2020_4, {"CPU", "GPU"}}, + {"QLinearMatMul", V_2022_3, {"CPU"}}, + {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, + {"RNN", V_2023_1, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormal", V_2023_0, {"CPU", "GPU"}}, + {"Range", V_2022_1, {"CPU", "GPU"}}, + {"Reciprocal", V_2020_4, {"CPU", "GPU"}}, + {"ReduceL1", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL2", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSum", V_2020_4, {"CPU"}}, + {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}}, + {"ReduceMax", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMean", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMin", V_2020_4, {"CPU", "GPU"}}, + {"ReduceProd", V_2020_4, {"CPU"}}, + {"ReduceProd", V_2022_1, {"GPU"}}, + {"ReduceSum", V_2020_4, {"CPU", "GPU"}}, + {"ReduceSumSquare", V_2020_4, {"CPU"}}, + {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}}, + {"Relu", V_2020_4, {"CPU", "GPU"}}, + {"Resize", V_2020_4, {"CPU"}}, + {"Resize", V_2022_1, {"GPU"}}, + {"Reshape", V_2020_4, {"CPU", "GPU"}}, + {"ReverseSequence", V_2022_1, {"CPU", "GPU"}}, + {"RoiAlign", V_2021_1, {"CPU", "GPU"}}, + {"Round", V_2021_4, {"CPU", "GPU"}}, + {"Scatter", V_2022_1, {"CPU", "GPU"}}, + {"ScatterElements", V_2022_1, {"CPU", "GPU"}}, + {"ScatterND", V_2022_1, {"CPU", "GPU"}}, + {"Selu", V_2020_4, {"CPU", "GPU"}}, + {"Shape", V_2020_4, {"CPU", "GPU"}}, + {"Shrink", V_2022_1, {"CPU", "GPU"}}, + {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, + {"Sign", V_2020_4, {"CPU"}}, + {"Sign", V_2022_1, {"GPU"}}, + {"Sin", V_2022_1, {"CPU", "GPU"}}, + {"Sinh", V_2020_4, {"CPU"}}, + {"Size", V_2022_1, {"CPU", "GPU"}}, + {"Slice", V_2020_4, {"CPU", "GPU"}}, + {"Softmax", V_2020_4, {"CPU", "GPU"}}, + {"Softplus", V_2022_1, {"CPU", "GPU"}}, + {"Softsign", V_2022_1, {"CPU", "GPU"}}, + {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}}, + {"Split", V_2020_4, {"CPU", "GPU"}}, + {"Sqrt", V_2020_4, {"CPU", "GPU"}}, + {"Squeeze", V_2020_4, {"CPU", "GPU"}}, + {"Softsign", V_2020_4, {"CPU"}}, + {"Sub", V_2020_4, {"CPU", "GPU"}}, + {"Sum", V_2020_4, {"CPU", "GPU"}}, + {"Tan", V_2020_4, {"CPU", "GPU"}}, + {"Tanh", V_2020_4, {"CPU", "GPU"}}, + {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}}, + {"Tile", V_2021_3, {"CPU", "GPU"}}, + {"Transpose", V_2020_4, {"CPU", "GPU"}}, + {"Trilu", V_2023_0, {"CPU", "GPU"}}, + {"TopK", V_2020_4, {"CPU", "GPU"}}, + {"Upsample", V_2020_4, {"CPU", "GPU"}}, + {"Unsqueeze", V_2020_4, {"CPU", "GPU"}}, + {"Where", V_2022_1, {"CPU", "GPU"}}, + {"Xor", V_2022_1, {"CPU", "GPU"}}, +}; + +const OrtGraphApi* DataOps::graph_api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION)->GetGraphApi(ORT_API_VERSION); + +void DataOps::populate_types_supported() { + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + supported_types_initializer_.insert( + 
std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_1, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)); + + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)); + supported_types_npu_.insert( + std::make_pair(V_2021_1, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)); + supported_types_cpu_.insert( + std::make_pair(V_2022_2, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)); + supported_types_gpu_.insert( + std::make_pair(V_2021_1, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, 
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)); + supported_types_gpu_.insert( + std::make_pair(V_2022_1, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)); +} + +void DataOps::populate_op_mode_supported() { + no_dimension_supported_.push_back({"Add", V_2022_1, {"All"}}); + no_dimension_supported_.push_back({"And", V_2022_1, {"All"}}); + no_dimension_supported_.push_back({"Cast", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Ceil", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"Clip", V_2022_1, {"All"}}); + no_dimension_supported_.push_back({"Div", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); + no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); + no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); + no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Identity", V_2023_0, {"All"}}); + no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); + no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}}); + no_dimension_supported_.push_back({"Pow", V_2023_0, {"CPU", "GPU"}}); + no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"Range", V_2021_2, {"All"}}); + no_dimension_supported_.push_back({"ReduceMax", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"ReduceMin", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"ReduceProd", V_2022_1, {"CPU", "GPU"}}); + no_dimension_supported_.push_back({"Reshape", V_2022_1, {"All"}}); + no_dimension_supported_.push_back({"Shape", V_2022_1, {"GPU"}}); + no_dimension_supported_.push_back({"Shape", V_2023_0, {"CPU"}}); + no_dimension_supported_.push_back({"Sqrt", V_2023_0, {"All"}}); + no_dimension_supported_.push_back({"Squeeze", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Sub", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Unsqueeze", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"Where", V_2021_2, {"All"}}); + + subgraph_supported_.push_back({"Cast", V_2020_4, {"All"}}); + subgraph_supported_.push_back({"Concat", V_2020_4, {"All"}}); + subgraph_supported_.push_back({"Div", V_2021_1, {"CPU"}}); + subgraph_supported_.push_back({"Gather", V_2020_4, {"All"}}); + subgraph_supported_.push_back({"Identity", V_2021_1, {"CPU"}}); + subgraph_supported_.push_back({"Mul", V_2020_4, {"All"}}); + subgraph_supported_.push_back({"Sub", V_2021_1, {"CPU"}}); + subgraph_supported_.push_back({"Transpose", V_2020_4, {"All"}}); + subgraph_supported_.push_back({"Unsqueeze", V_2020_4, {"All"}}); + + // populate unsupportedmode_t + { + UnsupportedOpMode obj = {{V_2024_1}, + [this](const OrtNode* node) { + size_t num_input = 0; + graph_api_->OrtNode_GetNumInputs(node, &num_input); + // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) + for (size_t i = 0; i < num_input; i++) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, i, &input_name); + OrtValueInfoRef* value_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input_name, &value_info); + 
ONNXTensorElementDataType dtype = value_info->data_type; + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || + dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) + return true; + } + return false; + }}; + op_list_.insert({"ReduceMax", obj}); + } + { + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + [this](const OrtNode* node) { + const char* input1_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, 1, &input1_name); + OrtValueInfoRef* value_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input1_name, &value_info); + if (value_info->shape != nullptr) { + for (int i = 0; i < value_info->shape_len; i++) { + if (value_info->shape[i] == 0) { + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return true; + } + } + } + + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return false; + }}; + op_list_.insert({"Reshape", obj}); + } + { + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + [this](const OrtNode* node) { + // If the operator is unsqueeze + // If axes is an input, then we cannot produce a static graph. + // Conversion fails in convert_function_to_cnn_network. + size_t num_input = 0; + graph_api_->OrtNode_GetNumInputs(node, &num_input); + for (size_t i = 0; i < num_input; i++) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, i, &input_name); + if (!strcmp(input_name, "axes")) { + return true; + } + } + return (!this->dimension_unsupported(node)); + }}; + op_list_.insert({"Unsqueeze", obj}); + } + { + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + [this](const OrtNode* node) { + // check for attributes + size_t key_count = 0; + graph_api_->OrtNode_GetAttributeKeyCount(node, "scales", &key_count); + if (key_count > 0) { + int float_size = 0; + graph_api_->OrtNode_GetAttributeFloatSize(node, "scales", &float_size); + if (float_size > 2) { + float f0, f1; + graph_api_->OrtNode_GetAttributeIthFloat(node, "scales", 0, &f0); + graph_api_->OrtNode_GetAttributeIthFloat(node, "scales", 1, &f1); + if (f0 != 1.f || f1 != 1.f) return true; + } + } + + // check for input dimensions + const char* input0_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, 0, &input0_name); + OrtValueInfoRef* value_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input0_name, &value_info); + if (value_info->shape != nullptr) { + if (value_info->shape_len == 1 || value_info->shape_len == 4) { + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return true; + } + } + // x_arg supports only float, int8 and float16 type + ONNXTensorElementDataType dtype = value_info->data_type; + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || + dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || + dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) + return false; + return true; + }}; + op_list_.insert({"Upsample", obj}); + } +} + +bool DataOps::op_is_supported(std::string name, std::vector& op_list) { + bool auto_support = false; + bool multi_support = false; + for (size_t i = 0; i < op_list.size(); i++) { + if (op_list[i].optype == name) { + if (op_list[i].version <= version_id_) { + auto it = op_list[i].device_type.begin(); + while (it != op_list[i].device_type.end()) { + // status variable is set to True if it's 
Hetero/Multi/Auto device type + bool status = false; + + // The operator to be marked true, it should be supported by either of the devices specified with HETERO + if (device_id_.find("HETERO") == 0) { + status = true; + if (device_id_.find(*it) != std::string::npos || (*it == "All")) { + return true; + } + } + + // The operator to be marked true, it should be supported by all the devices specified with MULTI/AUTO + if (device_id_.find("MULTI") == 0) { + status = true; + if ((*it == "All") || device_id_.find(*it) != std::string::npos) { + multi_support = true; + } + } + // The operator to be marked true, it should be supported by atleast CPU device specified with AUTO + if (device_id_.find("AUTO") == 0) { + if (std::string(*it).find("CPU") == std::string::npos) { + auto_support = false; + } else if ((*it == "All") || (device_id_.find(*it) != std::string::npos)) { + auto_support = true; + } + } + // if device supported is all then we support it + if (*it == "All") { + return true; + } + // check for device supported + if (status == false) { + if (device_id_.find(*it) != std::string::npos) { + return true; + } + } + it++; + } + } + } + } + if (device_id_.find("AUTO") == 0 && auto_support == true) { + return true; + } + if (device_id_.find("MULTI") == 0 && multi_support == true) { + return true; + } + return false; +} + +bool DataOps::type_is_supported(ONNXTensorElementDataType dtype, bool is_initializer) { + if (is_initializer) { + for (auto const& var : supported_types_initializer_) { + if ((var.first <= version_id_) && + (var.second == dtype)) { + return true; + } + } + +#ifndef NDEBUG + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Initializer Data Type is not supported" << std::endl; + } +#endif + return false; + } + if (device_id_.find("HETERO") != std::string::npos || + device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { + for (auto const& var : supported_types_npu_) { + if ((var.first <= version_id_) && + (var.second == dtype)) { + return true; + } + } + +#ifndef NDEBUG + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "I/O data type is not supported" << std::endl; + } +#endif + return false; + + } else if (device_id_ == "CPU") { + for (auto const& var : supported_types_cpu_) { + if ((var.first <= version_id_) && + (var.second == dtype)) { + return true; + } + } +#ifndef NDEBUG + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "I/O data type is not supported" << std::endl; + } +#endif + return false; + + } else if (device_id_ == "GPU") { + for (auto const& var : supported_types_gpu_) { + if ((var.first <= version_id_) && + (var.second == dtype)) { + return true; + } + } +#ifndef NDEBUG + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "I/O data type is not supported" << std::endl; + } +#endif + return false; + } + return true; +} + +bool DataOps::unsupported_op_mode(const OrtNode* node) { + bool result = false; + const char* optype = nullptr; + graph_api_->OrtNode_GetOpType(node, &optype); +// const char** initializers = nullptr; +// size_t initializers_count = 0; +// graph_api_->OrtGraph_GetAllInitializers(graph_viewer_, &initializers, &initializers_count); + + auto iter = op_list_.equal_range(std::string(optype)); + for (auto it = iter.first; it != iter.second; ++it) { + auto ob = it->second; + if (std::find(ob.ver.begin(), ob.ver.end(), version_id_) != ob.ver.end()) { + return ob.func(node); + } + } + return result; +} + +bool 
+bool DataOps::dimension_unsupported(const OrtNode* node) {
+  const char* input0_name = nullptr;
+  graph_api_->OrtNode_GetIthInputName(node, 0, &input0_name);
+  OrtValueInfoRef* value_info = nullptr;
+  graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input0_name, &value_info);
+  if (value_info->shape == nullptr) {
+    graph_api_->OrtGraph_ReleaseValueInfo(value_info);
+    return true;
+  }
+  size_t input_dims = value_info->shape_len;
+  graph_api_->OrtGraph_ReleaseValueInfo(value_info);
+  const char* optype = nullptr;
+  graph_api_->OrtNode_GetOpType(node, &optype);
+  if (!strstr(optype, "Pool") && input_dims != 4 && input_dims != 5) return false;
+
+  if (!strcmp(optype, "ReduceSum")) {
+    size_t key_count = 0;
+    int axes_size = 0;
+    graph_api_->OrtNode_GetAttributeKeyCount(node, "axes", &key_count);
+    if (key_count > 0) graph_api_->OrtNode_GetAttributeIntSize(node, "axes", &axes_size);
+    if (device_id_.find("GPU") != std::string::npos && axes_size == 0) return true;
+    if (axes_size == 0) return false;
+  }
+  return true;
+}
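node_is_supported below funnels its first check through op_is_supported, whose HETERO branch is easy to misread: one matching device in the selected device string is enough. A self-contained sketch of just that rule; the function name and device strings are hypothetical and only illustrate the matching, they are not part of the EP.

#include <iostream>
#include <string>
#include <vector>

// Sketch of the HETERO rule: an op is enabled if any device in its
// supported-device list appears in the selected device string, or if
// the list contains the "All" wildcard.
static bool hetero_enables(const std::string& device_id, const std::vector<std::string>& devices) {
  if (device_id.rfind("HETERO", 0) != 0) return false;  // only HETERO:... strings
  for (const auto& d : devices) {
    if (d == "All" || device_id.find(d) != std::string::npos) return true;
  }
  return false;
}

int main() {
  std::cout << hetero_enables("HETERO:GPU,CPU", {"GPU"}) << "\n";  // 1: GPU is listed
  std::cout << hetero_enables("HETERO:GPU,CPU", {"NPU"}) << "\n";  // 0: NPU is not
}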
+bool DataOps::node_is_supported(const NodeIndex node_idx) {
+  const OrtNode* node = nullptr;
+  graph_api_->OrtGraph_GetOrtNode(graph_viewer_, node_idx, &node);
+  const char* optype = nullptr;
+  graph_api_->OrtNode_GetOpType(node, &optype);
+
+#ifndef NDEBUG
+  if (openvino_ep::backend_utils::IsDebugEnabled()) {
+    std::cout << "Node " << optype << std::endl;
+  }
+#endif
+
+  /*
+  0. Check if the node's op type is in the supported ops list
+  1. Check input and output data types are supported.
+  2. Check if there is an unsupported dimension in the input and output shapes
+  3. Check Op is supported
+   3a. Check if Op is of known unsupported modes (edge cases). If yes return false right away.
+   3b. If above is not true, check if the op is available in nGraph.
+  */
+
+  // Check 0
+  if (!op_is_supported(optype, supported_op_mode)) {
+#ifndef NDEBUG
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      std::cout << "Node is not in the supported ops list" << std::endl;
+    }
+#endif
+    return false;
+  }
+
+  // Check 1
+  bool are_types_supported = true;
+  ForEachNodeDef(graph_api_, graph_viewer_, node, [&are_types_supported, this](const char* arg_name, const OrtValueInfoRef* node_arg, bool is_input) {
+    bool is_initializer = false;
+    if (is_input) {
+      graph_api_->OrtGraph_IsConstantInitializer(graph_viewer_, arg_name, true, &is_initializer);
+    }
+
+    bool is_supported = type_is_supported(node_arg->data_type, is_initializer);
+    are_types_supported &= is_supported;
+  });
+
+  if (!are_types_supported) {
+#ifndef NDEBUG
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      std::cout << "DType is not supported" << std::endl;
+    }
+#endif
+    return false;
+  }
+
+  // Check 2
+  bool has_unsupported_dimension = false;
+  ForEachNodeDef(graph_api_, graph_viewer_, node, [&has_unsupported_dimension, this, &optype, &node](const char* arg_name, const OrtValueInfoRef* node_arg, bool is_input) {
+    if (is_input) {
+      bool is_constant_initializer = false;
+      graph_api_->OrtGraph_IsConstantInitializer(graph_viewer_, arg_name, true, &is_constant_initializer);
+      if (is_constant_initializer)
+        return;
+    }
+
+    if (node_arg->shape_len == 0) {
+      if (op_is_supported(optype, no_dimension_supported_)) {
+        return;
+      }
+      if (npu_qdq_optimizer_enabled_) {
+        // Pad ops with DQ inputs will be optimized out in the qdq optimization pass, so mark those
+        // no-dimension Pad ops as supported here
+        if (!strcmp(optype, "Pad")) {
+          size_t num_inputs = 0;
+          graph_api_->OrtNode_GetNumInputs(node, &num_inputs);
+          for (size_t i = 0; i < num_inputs; i++) {
+            const char* input_name = nullptr;
+            graph_api_->OrtNode_GetIthInputName(node, i, &input_name);
+            const OrtNode* DQ = nullptr;
+            graph_api_->OrtGraph_GetNodeProducingOutput(graph_viewer_, input_name, &DQ);
+            if (DQ == nullptr) continue;  // graph input or initializer: no producer node
+            const char* dq_optype = nullptr;
+            graph_api_->OrtNode_GetOpType(DQ, &dq_optype);
+            if (!strcmp(dq_optype, "DequantizeLinear")) return;
+          }
+        }
+      }
+      has_unsupported_dimension = true;
+      return;
+    }
+    // Zero dimension check
+    for (size_t i = 0; i < node_arg->shape_len; i++) {
+      if (node_arg->shape[i] == 0) {
+        if (((device_id_.find("CPU") != std::string::npos) || (device_id_.find("GPU") != std::string::npos)) &&
+            (strcmp(optype, "Expand") == 0 || strcmp(optype, "Equal") == 0 ||
+             strcmp(optype, "Slice") == 0 || strcmp(optype, "Concat") == 0 ||
+             strcmp(optype, "Shape") == 0)) {
+          return;
+        }
+        has_unsupported_dimension = true;
+        return;
+      }
+    }
+  });
+  if (has_unsupported_dimension) {
+#ifndef NDEBUG
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      std::cout << "Dimension check failed" << std::endl;
+    }
+#endif
+    return false;
+  }
+
+  // Check 3a
+  const char* domain = nullptr;
+  graph_api_->OrtNode_GetDomain(node, &domain);
+  if (strcmp(domain, kOnnxDomain) == 0 && unsupported_op_mode(node)) {
+    if (!strcmp(optype, "GatherElements")) {
+      return true;
+    }
+#ifndef NDEBUG
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      std::cout << "Failed in unsupported op mode" << std::endl;
+    }
+#endif
+    return false;
+  }
+
+  return true;
+}
+
+std::vector<NodeIndex> DataOps::GetUnsupportedNodeIndices(std::unordered_set<std::string>& ng_required_initializers) {
+  std::vector<NodeIndex> unsupported_nodes_idx;
+
+  size_t num_nodes = 0;
+  const size_t* nodes_topo_order = nullptr;
+  graph_api_->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer_, 0, &nodes_topo_order, &num_nodes);
+  for (size_t i = 0; i < num_nodes; i++) {
+    size_t node_idx = nodes_topo_order[i];
+    if (node_is_supported(node_idx)) {
+      // Collect inputs that are initializers
+      const OrtNode* node = nullptr;
+      graph_api_->OrtGraph_GetOrtNode(graph_viewer_, node_idx, &node);
+      ForEachNodeDef(graph_api_, graph_viewer_, node, [&ng_required_initializers, this](const char* name, const OrtValueInfoRef* value_info, bool is_input) {
+        if (is_input) {
+          const char** initializer_names = nullptr;
+          size_t initializers_count = 0;
+          graph_api_->OrtGraph_GetAllInitializers(graph_viewer_, &initializer_names, &initializers_count);
+          for (size_t j = 0; j < initializers_count; j++) {
+            if (!strcmp(initializer_names[j], name)) ng_required_initializers.insert(std::string(name));
+          }
+          graph_api_->ReleaseCharArray(initializer_names);
+        }
+      });
+    } else {
+      unsupported_nodes_idx.push_back(node_idx);
+    }
+  }
+  return unsupported_nodes_idx;
+}
+
+bool DataOps::IsOpSupportedOnlyInModel(std::string name) {
+  return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end();
+}
+
+bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set<std::string>& ng_required_initializers,
+                                                const OrtNode* node) {
+  const char* optype = nullptr;
+  graph_api_->OrtNode_GetOpType(node, &optype);
+  if (!strcmp(optype, "Reshape")) {
+    const char* input_name = nullptr;
+    graph_api_->OrtNode_GetIthInputName(node, 1, &input_name);
+    if (ng_required_initializers.find(std::string(input_name)) == ng_required_initializers.end()) {
+      return true;
+    }
+  } else if (!strcmp(optype, "Expand")) {
+    // nGraph only supports constant shape input values
+    const char* output_name = nullptr;
+    graph_api_->OrtNode_GetIthOutputName(node, 0, &output_name);
+    OrtValueInfoRef* value_info = nullptr;
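+    // Look up the Expand output's element type; the fp16 check below decides
+    // whether this single-node cluster needs the special-case handling.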
graph_api_->OrtGraph_GetValueInfo(graph_viewer_, output_name, &value_info); + if (value_info->data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return true; + } + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + } else if (!strcmp(optype, "RoiAlign")) { + const char* input0_name = nullptr, *input1_name = nullptr, *input2_name = nullptr, *output_name = nullptr; + graph_api_->OrtNode_GetIthInputName(node, 0, &input0_name); + graph_api_->OrtNode_GetIthInputName(node, 1, &input1_name); + graph_api_->OrtNode_GetIthInputName(node, 2, &input2_name); + graph_api_->OrtNode_GetIthOutputName(node, 0, &output_name); + OrtValueInfoRef* input0_info = nullptr, *input1_info = nullptr, *input2_info = nullptr, *output_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input0_name, &input0_info); + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input1_name, &input1_info); + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, input2_name, &input2_info); + graph_api_->OrtGraph_GetValueInfo(graph_viewer_, output_name, &output_info); + + if (input0_info->data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 || + input1_info->data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 || + input2_info->data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || + output_info->data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + ) { + graph_api_->OrtGraph_ReleaseValueInfo(input0_info); + graph_api_->OrtGraph_ReleaseValueInfo(input1_info); + graph_api_->OrtGraph_ReleaseValueInfo(input2_info); + graph_api_->OrtGraph_ReleaseValueInfo(output_info); + return true; + } + graph_api_->OrtGraph_ReleaseValueInfo(input0_info); + graph_api_->OrtGraph_ReleaseValueInfo(input1_info); + graph_api_->OrtGraph_ReleaseValueInfo(input2_info); + graph_api_->OrtGraph_ReleaseValueInfo(output_info); + } + return false; +} + +bool DataOps::DoNotOmitSubGraph(const std::string& name) { + return op_is_supported(name, subgraph_supported_); +} + +bool DataOps::InsertNode(const std::string& optype) { + if (optype == "TopK" || optype == "NonZero") { + return true; + } + return false; +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ov_versions/data_ops.h b/samples/openvino/ov_versions/data_ops.h new file mode 100644 index 0000000000000..853a5aa76229a --- /dev/null +++ b/samples/openvino/ov_versions/data_ops.h @@ -0,0 +1,98 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once +#include +#include +#include +#include +#include +#include +#include "../openvino_utils.h" + +namespace onnxruntime { +namespace openvino_ep { + +using VarianceFunc = std::function; + +enum versionNum { + V_2020_4, + V_2021_1, + V_2021_2, + V_2021_3, + V_2021_4, + V_2022_1, + V_2022_2, + V_2022_3, + V_2023_0, + V_2023_1, + V_2023_2, + V_2023_3, + V_2024_0, + V_2024_1 +}; + +using VersionNum = enum versionNum; + +struct supportedOp { + std::string optype; + VersionNum version; + std::vector device_type; +}; + +struct unsupportedOpMode { + std::vector ver; + VarianceFunc func; +}; + +using SupportedOp = struct supportedOp; +using UnsupportedOpMode = struct unsupportedOpMode; +using Pairs = std::pair; + +class DataOps { + private: + const OrtGraphViewer* graph_viewer_; + VersionNum version_id_; + std::string device_id_; + std::string device_precision_; + std::multimap op_list_; + std::vector 
subgraph_supported_; + std::vector no_dimension_supported_; + std::set supported_types_npu_; + std::set supported_types_cpu_; + std::set supported_types_gpu_; + std::set supported_types_initializer_; + bool npu_qdq_optimizer_enabled_; + static const OrtGraphApi* graph_api_; + +// protected: + void populate_op_mode_supported(); + void populate_types_supported(); + bool op_is_supported(std::string name, std::vector& list); + bool dimension_unsupported(const OrtNode* node); + bool unsupported_op_mode(const OrtNode* node); + bool type_is_supported(ONNXTensorElementDataType dtype, bool is_initializer); + bool node_is_supported(const NodeIndex node_idx); + + public: + DataOps(const OrtGraphViewer* graph_viewer_param, VersionNum ver, + const std::string dev_id, const bool npu_qdq_optimizer_enabled) + : graph_viewer_(graph_viewer_param), + version_id_(ver), + device_id_(dev_id), + npu_qdq_optimizer_enabled_(npu_qdq_optimizer_enabled) { + populate_op_mode_supported(); + populate_types_supported(); + } + + virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); + virtual bool IsOpSupportedOnlyInModel(std::string name); + virtual bool SpecialConditionForClusterSizeOne( + std::unordered_set& ng_required_initializers, const OrtNode* node); + virtual bool DoNotOmitSubGraph(const std::string& name); + virtual bool InsertNode(const std::string& name); + VersionNum GetVersion() const { return version_id_; } +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ov_versions/utils.cc b/samples/openvino/ov_versions/utils.cc new file mode 100644 index 0000000000000..cd397979cc301 --- /dev/null +++ b/samples/openvino/ov_versions/utils.cc @@ -0,0 +1,313 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "utils.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4245 5208) +#elif __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#if defined(_MSC_VER) +#pragma warning(default : 4244 4245) +#elif __GNUC__ +#pragma GCC diagnostic pop +#endif + +namespace onnxruntime { +namespace openvino_ep { + +// Gets the input count of given node +//int GetInputCount(const Node* node, const InitializedTensorSet& initializer_set) { +// int count = 0; +// for (const auto& input : node->InputDefs()) { +// const auto& name = input->Name(); +// auto it = initializer_set.find(name); +// if (it == initializer_set.end()) { +// count++; +// } +// } +// return count; +//} + +// Ops which are supported only in models(as intermediate nodes) and not in unit tests +//bool IsOpSupportedOnlyInModel(std::string name) { +// std::set ops_supported_only_in_model = { +// "Cast", +// "Concat", +// "ConstantOfShape", +// "Dropout", +// "Einsum", +// "Expand", +// "EyeLike", +// "Exp", +// "GatherND", +// "Identity", +// "LayerNormalization", +// "NonMaxSuppression", +// "NonZero", +// "Not", +// "OneHot", +// "Pad", +// "Range", +// "ReduceMin", +// "Resize", +// "Round", +// "Shape", +// "Split", +// "TopK"}; +// return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end(); +//} + +void AppendClusterToSubGraph(const size_t* node_index, size_t node_count, + const std::vector& inputs, + const std::vector& outputs, + std::vector& cache) { + static size_t op_counter = 0; + + OrtMetaDef* meta_def = new OrtMetaDef(); + std::string name = "OpenVINO-EP-subgraph_" + std::to_string(++op_counter); + meta_def->name = new char [name.length() + 1]; + strcpy(meta_def->name, 
name.c_str());
+  meta_def->domain = "com.intel.ai";
+  meta_def->since_version = 1;
+  // TODO(leca): meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL;
+  meta_def->input_len = inputs.size();
+  meta_def->inputs = new char* [inputs.size()];
+  for (size_t i = 0; i < inputs.size(); i++) {
+    meta_def->inputs[i] = new char[inputs[i].length() + 1];
+    strcpy(meta_def->inputs[i], inputs[i].c_str());
+  }
+  meta_def->output_len = outputs.size();
+  meta_def->outputs = new char* [outputs.size()];
+  for (size_t i = 0; i < outputs.size(); i++) {
+    meta_def->outputs[i] = new char[outputs[i].length() + 1];
+    strcpy(meta_def->outputs[i], outputs[i].c_str());
+  }
+
+  OrtIndexedSubGraph* indexed_sub_graph = new OrtIndexedSubGraph();
+  indexed_sub_graph->meta_def = meta_def;
+  indexed_sub_graph->node_index_len = node_count;
+  indexed_sub_graph->node_index = new size_t [node_count];
+  for (size_t i = 0; i < node_count; i++) {
+    indexed_sub_graph->node_index[i] = node_index[i];
+  }
+
+  cache.push_back(indexed_sub_graph);
+}
+
+//int GetOnnxOpSet(const GraphViewer& graph_viewer) {
+//  const auto& dm_to_ver = graph_viewer.DomainToVersionMap();
+//  return dm_to_ver.at(kOnnxDomain);
+//}
+
+/**
+ * Returns a vector of clusters (each cluster being a vector of node indices). For every unsupported node,
+ * the graph is split into three parts: supported_cluster + (unsupported_node + rest_of_the_graph).
+ * This function returns the vector of all clusters supported by nGraph.
+ */
+std::vector<std::vector<NodeIndex>>
+GetPartitionedClusters(const std::vector<NodeIndex>& topological_order,
+                       const std::vector<NodeIndex>& unsupported_nodes) {
+  std::vector<std::vector<NodeIndex>> ng_clusters;
+
+  auto prev = topological_order.begin();
+
+  for (const auto& unsup_node : unsupported_nodes) {
+    auto it = std::find(prev, topological_order.end(), unsup_node);
+    // Create a cluster vector [supported_node_idx, unsupported_node_idx) and append it to the return list.
+    std::vector<NodeIndex> this_cluster{prev, it};
+    if (!this_cluster.empty()) {
+      ng_clusters.push_back(std::move(this_cluster));
+    }
+    if (it != topological_order.end()) {
+      // Point prev to the node index just past this unsupported node.
+      prev = ++it;
+    }
+  }
+
+  // Tail
+  std::vector<NodeIndex> this_cluster{prev, topological_order.end()};
+  if (!this_cluster.empty()) {
+    ng_clusters.push_back(std::move(this_cluster));
+  }
+
+  return ng_clusters;
+}
+
+void IdentifyConnectedNodes(const OrtGraphApi* graph_api,
+                            const OrtGraphViewer* graph_viewer,
+                            NodeIndex curr_node_index,
+                            std::vector<NodeIndex>& cluster,
+                            std::vector<NodeIndex>& sub_cluster) {
+  if (std::find(cluster.begin(), cluster.end(), curr_node_index) == cluster.end())
+    return;
+
+  sub_cluster.emplace_back(curr_node_index);
+  cluster.erase(std::remove(cluster.begin(), cluster.end(), curr_node_index), cluster.end());
+  const OrtNode* curr_node = nullptr;
+  graph_api->OrtGraph_GetOrtNode(graph_viewer, curr_node_index, &curr_node);
+
+  // TODO(leca): equivalent to for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node)?
+  // TODO(leca): consider implicit input?
+  size_t num_inputs = 0;
+  graph_api->OrtNode_GetNumInputs(curr_node, &num_inputs);
+  for (size_t i = 0; i < num_inputs; i++) {
+    const char* input_name = nullptr;
+    graph_api->OrtNode_GetIthInputName(curr_node, i, &input_name);
+    const OrtNode* producer_node = nullptr;
+    graph_api->OrtGraph_GetNodeProducingOutput(graph_viewer, input_name, &producer_node);
+    size_t producer_index = 0;
+    graph_api->OrtNode_GetIndex(producer_node, &producer_index);
+    IdentifyConnectedNodes(graph_api, graph_viewer, producer_index, cluster, sub_cluster);
+  }
+
+  // TODO(leca): equal to for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) ?
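+  // Recurse downstream as well: for every output, visit each consumer that is still part of the cluster.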
+ size_t num_outputs = 0; + graph_api->OrtNode_GetNumOutputs(curr_node, &num_outputs); + for (int i = 0; i < num_outputs; i++) { + const char* output_name = nullptr; + graph_api->OrtNode_GetIthOutputName(curr_node, i, &output_name); + const OrtNode** consumer_nodes = nullptr; + size_t num_consumers = 0; + // TODO(leca): if there is one consumer consuming more than 1 output of curr_node, would it be visited twice? + graph_api->OrtGraph_GetNodesConsumingInput(graph_viewer, output_name, &consumer_nodes, &num_consumers); + for (int j = 0; j < num_consumers; j++) { + size_t consumer_index = 0; + graph_api->OrtNode_GetIndex(consumer_nodes[j], &consumer_index); + IdentifyConnectedNodes(graph_api, graph_viewer, consumer_index, cluster, sub_cluster); + } + // TODO(leca): release consumer_nodes + } +} + +std::vector> +GetConnectedClusters(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer, const std::vector>& clusters) { + std::vector> connected_clusters; + + for (auto this_cluster : clusters) { + while (this_cluster.size() > 0) { + std::vector sub_cluster; + IdentifyConnectedNodes(graph_api, graph_viewer, this_cluster[0], this_cluster, sub_cluster); + connected_clusters.emplace_back(sub_cluster); + } + } + return connected_clusters; +} + +//void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api, +// const GraphViewer& graph_viewer, +// const std::vector& cluster, +// const std::unordered_set& ng_required_initializers, +// /*out*/ std::vector& cluster_graph_inputs, +// /*out*/ std::vector& cluster_inputs, +// /*out*/ std::vector& cluster_outputs) { +// std::unordered_set input_args; +// std::vector ordered_input_args; +// std::unordered_set output_args; +// std::unordered_set external_output_args; +// std::vector constant_inputs; +// +// for (const auto& node_idx : cluster) { +// const OrtNode* node = nullptr; +// graph_api->OrtGraph_GetOrtNode(graph_viewer, node_idx, &node); +// // Collect all inputs and outputs +// ForEachNodeDef(graph_api, graph_viewer, node, +// [&input_args, &ordered_input_args, &output_args](const char* arg_name, const OrtValueInfoRef*, bool is_input) { +// if (strcmp(arg_name, "") != 0) { +// if (is_input) { +// if (!input_args.count(std::string(arg_name))) { +// ordered_input_args.push_back(std::string(arg_name)); +// } +// input_args.insert(std::string(arg_name)); +// } else { +// output_args.insert(std::string(arg_name)); +// } +// } +// }); +// +// // Check if output of this node is used by nodes outside this_cluster. If yes add this to cluster outputs +// // TODO(leca): equal to for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) ? +// size_t num_outputs = 0; +// graph_api->OrtNode_GetNumOutputs(node, &num_outputs); +// for (int i = 0; i < num_outputs; i++) { +// const char* output_name = nullptr; +// graph_api->OrtNode_GetIthOutputName(node, i, &output_name); +// const OrtNode** consumer_nodes = nullptr; +// size_t num_consumers = 0; +// // TODO(leca): if there is one consumer consuming more than 1 output of curr_node, would it be visited twice? +// graph_api->OrtGraph_GetNodesConsumingInput(graph_viewer, output_name, &consumer_nodes, &num_consumers); +// for (int j = 0; j < num_consumers; j++) { +// size_t consumer_index = 0; +// graph_api->OrtNode_GetIndex(consumer_nodes[j], &consumer_index); +// +// if (std::find(cluster.begin(), cluster.end(), consumer_index) == cluster.end()) { +// // Node is external to this_cluster. Search through its inputs to +// // find the output that is generated by this_cluster. 
+// std::set ext_node_inputs; +// ForEachNodeDef(graph_api, graph_viewer, consumer_nodes[j], +// [&ext_node_inputs](const char* arg_name, const OrtValueInfoRef*, bool is_input) { +// if (is_input) { +// ext_node_inputs.insert(std::string(arg_name)); +// } +// }); +// +// for (int j = 0; j < num_outputs; j++) { +// const char* out_def = nullptr; +// graph_api->OrtNode_GetIthOutputName(node, j, &out_def); +// if (ext_node_inputs.find(std::string(out_def)) != ext_node_inputs.end()) { +// external_output_args.insert(std::string(out_def)); +// } +// } +// } +// } +// // TODO(leca): release consumer_nodes +// } +// } +// +// // Extract initializers used by this_cluster. +// std::unordered_set original_graph_inputs; +// const char** input_names = nullptr; +// size_t input_len = 0; +// graph_api->OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_len); +// for (int i = 0; i < input_len; i++) { +// original_graph_inputs.insert(std::string(input_names[i])); +// } +// graph_api->ReleaseCharArray(input_names); + +// const char** initializer_names = nullptr; +// size_t initializer_len = 0; +// graph_api->OrtGraph_GetAllInitializers(graph_viewer, &initializer_names, &initializer_len); +// const auto& initializers = graph_viewer.GetAllInitializedTensors(); +// for (const auto& in_arg : ordered_input_args) { +// if ((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || +// ng_required_initializers.count(in_arg)) { +// constant_inputs.push_back(in_arg); +// } +// } +// +// for (const auto& in_arg : ordered_input_args) { +// if (!output_args.count(in_arg) && +// !((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || +// ng_required_initializers.count(in_arg))) { +// cluster_inputs.push_back(in_arg); +// } +// } +// for (const auto& input : cluster_inputs) { +// cluster_graph_inputs.push_back(input); +// } +// +// for (const auto& in_arg : constant_inputs) { +// cluster_inputs.push_back(in_arg); +// } +// +// std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(cluster_outputs)); +// for (const auto& node_arg : graph_viewer.GetOutputs()) { +// const auto& name = node_arg->Name(); +// if (output_args.count(name) && !external_output_args.count(name)) { +// cluster_outputs.push_back(name); +// } +// } +//} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ov_versions/utils.h b/samples/openvino/ov_versions/utils.h new file mode 100644 index 0000000000000..2cd27bbee23a9 --- /dev/null +++ b/samples/openvino/ov_versions/utils.h @@ -0,0 +1,54 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../openvino_utils.h" + +namespace onnxruntime { +namespace openvino_ep { + +//int GetInputCount(const Node* node, const InitializedTensorSet& initializer_set); + +//bool IsOpSupportedOnlyInModel(std::string name); + +void AppendClusterToSubGraph(const size_t* node_index, size_t node_count, + const std::vector& inputs, + const std::vector& outputs, + std::vector& cache); + +//int GetOnnxOpSet(const GraphViewer& graph_viewer); + +//std::map> GetNgSupportedOps(const int onnx_opset); + +std::vector> +GetPartitionedClusters( + const std::vector& topological_order, const std::vector& unsupported_nodes); + +void IdentifyConnectedNodes( + const OrtGraphApi* graph_api, + const OrtGraphViewer* graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster); + 
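To make the partitioning contract declared here concrete: cutting the topological order at every unsupported node yields the supported clusters. A minimal self-contained illustration follows; NodeIndex is simplified to size_t and the helper name is hypothetical, it only mirrors the splitting rule of GetPartitionedClusters.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

using NodeIndex = size_t;

// Same splitting rule as GetPartitionedClusters: cut the topological order
// at every unsupported node and keep the non-empty runs in between.
static std::vector<std::vector<NodeIndex>> partition(
    const std::vector<NodeIndex>& topo, const std::vector<NodeIndex>& unsupported) {
  std::vector<std::vector<NodeIndex>> clusters;
  auto prev = topo.begin();
  for (NodeIndex u : unsupported) {
    auto it = std::find(prev, topo.end(), u);
    if (it != prev) clusters.emplace_back(prev, it);
    if (it != topo.end()) prev = std::next(it);
  }
  if (prev != topo.end()) clusters.emplace_back(prev, topo.end());
  return clusters;
}

int main() {
  // Unsupported nodes 2 and 4 split {0..5} into {0,1}, {3}, {5}.
  for (const auto& c : partition({0, 1, 2, 3, 4, 5}, {2, 4})) {
    for (NodeIndex n : c) std::cout << n << ' ';
    std::cout << '\n';
  }
}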
+std::vector> +GetConnectedClusters(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer, const std::vector>& clusters); + +//void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api, +// const GraphViewer& graph_viewer, +// const std::vector& cluster, +// const std::unordered_set& ng_required_initializers, +// /*out*/ std::vector& cluster_graph_inputs, +// /*out*/ std::vector& cluster_inputs, +// /*out*/ std::vector& cluster_outputs); + +} // namespace openvino_ep +} // namespace onnxruntime From a1a3eead86fb40b07e497878649993b993b8af91 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Thu, 31 Oct 2024 17:24:32 -0700 Subject: [PATCH 59/81] openvino GetCapacity() is done. UnregisterPluginExecutionProviderLibrary --- .../onnxruntime/core/session/environment.h | 6 + .../core/session/onnxruntime_c_api.h | 11 + onnxruntime/__init__.py | 3 +- .../core/framework/provider_factory_adapter.h | 3 + onnxruntime/core/session/environment.cc | 14 ++ onnxruntime/core/session/onnxruntime_c_api.cc | 5 + onnxruntime/core/session/ort_apis.h | 2 + onnxruntime/core/session/ort_env.cc | 4 + onnxruntime/core/session/ort_env.h | 2 + .../python/onnxruntime_pybind_state.cc | 14 +- .../test/python/onnxruntime_test_plugin_ep.py | 15 +- samples/c_test/test.cpp | 8 + samples/openvino/backend_manager.cc | 137 +++++----- samples/openvino/backend_manager.h | 49 ++-- samples/openvino/onnx_ctx_model_helper.cc | 29 ++- samples/openvino/onnx_ctx_model_helper.h | 2 +- .../openvino/openvino_execution_provider.cc | 50 ++++ samples/openvino/ov_versions/capability.cc | 12 +- samples/openvino/ov_versions/utils.cc | 233 +++++++++--------- samples/openvino/ov_versions/utils.h | 14 +- 20 files changed, 369 insertions(+), 244 deletions(-) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 2b05bc08ac376..199c6f7cdeb67 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/common/common.h" #include "core/common/status.h" #include "core/platform/threadpool.h" @@ -89,10 +90,15 @@ class Environment { */ Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); + // TODO(leca): return Status to handle corner cases (plugin EP Factory already exists, etc.) 
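+  // Stores the factory exported by a plugin EP library under ep_name.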
void InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFactory* ep_factory); OrtExecutionProviderFactory* GetPluginExecutionProviderFactory(const std::string& ep_name); + Status DeletePluginEpFactory(const char* ep_name); + + std::unordered_set GetPluginEpFactoryNames(); + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); Status Initialize(std::unique_ptr logging_manager, diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d317ca02ac407..ee0b52347975c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4728,6 +4728,17 @@ struct OrtApi { */ ORT_API2_STATUS(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); + /** \brief Unregister the plugin ExecutionProvider library + * + * The plugin ExecutionProvider factory will be removed from OrtEnv object + * + * \param[in] env OrtEnv object + * \param[in] ep_name the plugin ExecutionProvider name + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(UnregisterPluginExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); + /** \brief Append the plugin ExecutionProvider factory into the session option with provider options * * \param[in] options OrtSessionOptions object diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index c96bf331706e0..528ad427fbf94 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -43,10 +43,11 @@ from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 + from onnxruntime.capi._pybind_state import register_plugin_execution_provider_library # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 - from onnxruntime.capi._pybind_state import register_plugin_execution_provider_library # noqa: F401 + from onnxruntime.capi._pybind_state import unregister_plugin_execution_provider # noqa: F401 import_capi_exception = None except Exception as e: diff --git a/onnxruntime/core/framework/provider_factory_adapter.h b/onnxruntime/core/framework/provider_factory_adapter.h index 89695e88220f8..9b87c11a743ee 100644 --- a/onnxruntime/core/framework/provider_factory_adapter.h +++ b/onnxruntime/core/framework/provider_factory_adapter.h @@ -25,6 +25,9 @@ std::unique_ptr CreateProvider() override { return std::make_unique(ep_factory_->CreateExecutionProvider(ep_factory_, keys_.data(), values_.data(), provider_option_length_)); } OrtExecutionProviderFactory* ep_factory_; +// Have to keep both provider_option_keys_ and keys_ to make provider options local, otherwise when CreateProvider() +// is invoked, keys_ will point to the memory which is already released +// Or TODO(leca): use std::vector> keys_ and copy over the provider options in the constructor std::vector provider_option_keys_, provider_option_values_; std::vector keys_, values_; size_t provider_option_length_; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 9e6e0d3ba003f..6271542af9135 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -359,4 +359,18 @@ 
OrtExecutionProviderFactory* Environment::GetPluginExecutionProviderFactory(cons return it->second.get(); } +Status Environment::DeletePluginEpFactory(const char* ep_name) { + size_t ret = plugin_ep_factories_.erase(ep_name); + if (ret) return Status::OK(); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "cannot delete the plugin EpFactory: " + std::string(ep_name)); +} + +std::unordered_set Environment::GetPluginEpFactoryNames() { + std::unordered_set ret; + for (const auto& [k, v] : plugin_ep_factories_) { + ret.insert(k); + } + return ret; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index c1a0b726c5d7e..e78abd82b0ad9 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2402,6 +2402,10 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterPluginExecutionProviderLibrary, _In_ const API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::UnregisterPluginExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name) { + return ToOrtStatus(env->DeletePluginEpFactory(ep_name)); +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { OrtExecutionProviderFactory* ep_factory = env->GetPluginExecutionProviderFactory(ep_name); @@ -2822,6 +2826,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::DeviceGetId, &OrtApis::ReleaseDevice, &OrtApis::RegisterPluginExecutionProviderLibrary, + &OrtApis::UnregisterPluginExecutionProviderLibrary, &OrtApis::SessionOptionsAppendPluginExecutionProvider, &OrtApis::CreateOrtTypeConstraints, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 6530bfb6205c1..33709936c9454 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -536,6 +536,8 @@ ORT_API(void, ReleaseDevice, _Frees_ptr_opt_ OrtDevice*); ORT_API_STATUS_IMPL(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); +ORT_API_STATUS_IMPL(UnregisterPluginExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); + ORT_API_STATUS_IMPL(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index fecd7a85ce9b1..7fe3064b36b61 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -119,3 +119,7 @@ void OrtEnv::InsertPluginEpFactory(const char* ep_name, OrtExecutionProviderFact OrtExecutionProviderFactory* OrtEnv::GetPluginExecutionProviderFactory(const char* ep_name) { return value_->GetPluginExecutionProviderFactory(ep_name); } + +onnxruntime::common::Status OrtEnv::DeletePluginEpFactory(const char* ep_name) { + return value_->DeletePluginEpFactory(ep_name); +} diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 621f6d8096953..6ec22ed5c6583 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -70,6 +70,8 @@ struct OrtEnv { OrtExecutionProviderFactory* 
GetPluginExecutionProviderFactory(const char* ep_name); + onnxruntime::common::Status DeletePluginEpFactory(const char* ep_name); + private: static std::unique_ptr p_instance_; static onnxruntime::OrtMutex m_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7e0cd174a42a7..43c5203c6f0fc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1330,11 +1330,6 @@ static void LogDeprecationWarning( } #endif -// TODO(leca): when will this variable be unset? It is saved in Environment thus should be cross-session, which means -// once the session ends, the plugin ep should still be left in the Environment -// Should implement Environment::RemovePluginEp() which will be invoked in ~EnvInitializer(), and also clear plugin_execution_providers there -static std::unordered_set plugin_execution_providers; - void addGlobalMethods(py::module& m) { m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance."); m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer."); @@ -1491,10 +1486,15 @@ void addGlobalMethods(py::module& m) { OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol)); auto env = GetEnv(); env->InsertPluginEpFactory(provider_type, symbol()); - plugin_execution_providers.insert(std::string(provider_type)); } }); - m.def("get_available_plugin_providers", []() -> std::unordered_set { return plugin_execution_providers; }); + m.def("get_available_plugin_providers", []() -> std::unordered_set { return GetEnv()->GetPluginEpFactoryNames(); }); + m.def("unregister_plugin_execution_provider", [](const char* provider_type) { + Status status = GetEnv()->DeletePluginEpFactory(provider_type); + if (!status.IsOK()) { + throw std::runtime_error(status.ErrorMessage()); + } + }); } void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { diff --git a/onnxruntime/test/python/onnxruntime_test_plugin_ep.py b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py index fcbccbe2931c4..db94f634b0ca5 100644 --- a/onnxruntime/test/python/onnxruntime_test_plugin_ep.py +++ b/onnxruntime/test/python/onnxruntime_test_plugin_ep.py @@ -1,15 +1,20 @@ -import onnxruntime as ort import numpy -#ort.register_plugin_execution_provider_library("outTreeEp", "/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so") -ort.register_plugin_execution_provider_library("tensorrtEp", "/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so") +import onnxruntime as ort + +ort.register_plugin_execution_provider_library("outTreeEp", "/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so") +#ort.register_plugin_execution_provider_library("tensorrtEp", "/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so") sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL #session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=[("CPUExecutionProvider")]) #session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) -#session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["outTreeEp", "CPUExecutionProvider"], 
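+# each provider_options dict below pairs positionally with the corresponding entry in providers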
provider_options=[{"int_property":"3", "str_property":"strvalue"}, {}]) -session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["tensorrtEp", "CPUExecutionProvider"], provider_options=[{"device_id":"0", "str_property":"strvalue"}, {}]) +session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["outTreeEp", "CPUExecutionProvider"], provider_options=[{"int_property":"3", "str_property":"strvalue"}, {}]) + +# runtime error for tensorrtEp for using two different map in libonnx.a, as onnxruntime_pybind_state.so is linked to static ORT libs +#session = ort.InferenceSession("/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", sess_options, providers=["tensorrtEp", "CPUExecutionProvider"], provider_options=[{"device_id":"0", "str_property":"strvalue"}, {}]) y = session.run(None, {'x': numpy.array([-3.0, 5.0, -2.0, 4.0]).astype(numpy.float32)}) print(y) + +ort.unregister_plugin_execution_provider("outTreeEp") diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 35ee9348b92d1..cbbd568a418d0 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -347,6 +347,14 @@ int main(int argc, char *argv[]) { RunControlFlow(p_env, so); } + if (!strcmp(argv[1], "c")) { + g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "outTreeEp"); + } else if (!strcmp(argv[1], "k")) { + g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "kernelEp"); + } else if (!strcmp(argv[1], "t") || !strcmp(argv[1], "tc")) { + g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "tensorrtEp"); + } + g_ort->ReleaseEnv(p_env); return 0; } diff --git a/samples/openvino/backend_manager.cc b/samples/openvino/backend_manager.cc index 896aea8830b54..67af0bd589952 100644 --- a/samples/openvino/backend_manager.cc +++ b/samples/openvino/backend_manager.cc @@ -18,52 +18,58 @@ namespace onnxruntime { namespace openvino_ep { +const OrtGraphApi* BackendManager::graph_api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION)->GetGraphApi(ORT_API_VERSION); -//GlobalContext& BackendManager::GetGlobalContext() { -// return global_context_; -//} -// -//BackendManager::BackendManager(const GlobalContext& global_context, -// const onnxruntime::Node& fused_node, -// const onnxruntime::GraphViewer& subgraph, -// const logging::Logger& logger, -// EPCtxHandler& ctx_handle) { -// global_context_ = global_context; -// ep_ctx_handle_ = ctx_handle; -// -// openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." + -// std::to_string(global_context_.OpenVINO_Version.at(1)); -// if (ep_ctx_handle_.CheckForOVEPCtxNode(subgraph, openvino_sdk_version_)) { -// if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph) != Status::OK()) -// ORT_THROW("Import blob from model failed"); -// } -// -// // Save the indexes of graph inputs among fused_node's inputDefs -// // (which also contains initializers). 
-// auto node_input_defs = fused_node.InputDefs(); -// int i = 0; -// for (auto idef : node_input_defs) { -// subgraph_context_.input_names.insert({idef->Name(), i}); -// i++; -// } -// -// const std::vector& graph_inputs = subgraph.GetInputs(); -// for (auto input : graph_inputs) { -// auto it = subgraph_context_.input_names.find(input->Name()); -// if (it == subgraph_context_.input_names.end()) { -// ORT_THROW("Input not found in the input defs list"); -// } -// int index = it->second; -// subgraph_context_.input_indexes.push_back(index); -// } -// -// auto graph_outputs_defs = fused_node.OutputDefs(); -// i = 0; -// for (auto output_def : graph_outputs_defs) { -// subgraph_context_.output_names.insert({output_def->Name(), i}); -// i++; -// } -// subgraph_context_.subgraph_name = fused_node.Name(); +GlobalContext& BackendManager::GetGlobalContext() { + return global_context_; +} + +BackendManager::BackendManager(const GlobalContext& global_context, + const OrtNode* fused_node, + const OrtGraphViewer* subgraph, + EPCtxHandler& ctx_handle) { + global_context_ = global_context; + ep_ctx_handle_ = ctx_handle; + + openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." + + std::to_string(global_context_.OpenVINO_Version.at(1)); + if (ep_ctx_handle_.CheckForOVEPCtxNode(subgraph, openvino_sdk_version_)) { + if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph) != nullptr) + throw std::runtime_error("Import blob from model failed"); + } + + // Save the indexes of graph inputs among fused_node's inputDefs + // (which also contains initializers). + size_t input_count = 0; + graph_api_->OrtNode_GetNumInputs(fused_node, &input_count); + for (int i = 0; i < input_count; i++) { + const char* input_name = nullptr; + graph_api_->OrtNode_GetIthInputName(fused_node, i, &input_name); + subgraph_context_.input_names.insert({input_name, i}); + } + + const char** graph_inputs = nullptr; + graph_api_->OrtGraph_GetRequiredInputs(subgraph, &graph_inputs, &input_count); + for (int i = 0; i < input_count; i++) { + auto it = subgraph_context_.input_names.find(std::string(graph_inputs[i])); + if (it == subgraph_context_.input_names.end()) { + throw std::runtime_error("Input not found in the input defs list"); + } + int index = it->second; + subgraph_context_.input_indexes.push_back(index); + } + graph_api_->ReleaseCharArray(graph_inputs); + + size_t output_count = 0; + graph_api_->OrtNode_GetNumOutputs(fused_node, &output_count); + for (int i = 0; i < output_count; i++) { + const char* output_name = nullptr; + graph_api_->OrtNode_GetIthOutputName(fused_node, i, &output_name); + subgraph_context_.output_names.insert({output_name, i}); + } + const char* subgraph_name = nullptr; + graph_api_->OrtNode_GetName(fused_node, &subgraph_name); + subgraph_context_.subgraph_name = std::string(subgraph_name); // model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger); // std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; // @@ -128,22 +134,21 @@ namespace openvino_ep { //#endif // } // } -//} -// -//// Call EPContext model exporter here if the provider option for exporting -//// precompiled blob is set. If that's the case: -//// By default, create model in embed mode where the blob stream is exported as data within -//// the EPContext node. 
-//Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer, -// const logging::Logger& logger) { -// if (GetGlobalContext().disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape) { -// std::string exception_str = -// "Exporting dynamically compiled models at runtime is not supported. " -// "Cannot export blobs of dynamic models that request static shape inference. " -// "To export this model, set disable_dynamic_shapes to False"; -// ORT_THROW(exception_str); -// } -// +} + +// Call EPContext model exporter here if the provider option for exporting +// precompiled blob is set. If that's the case: +// By default, create model in embed mode where the blob stream is exported as data within +// the EPContext node. +OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* graph_body_viewer) { + if (GetGlobalContext().disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape) { + std::string exception_str = + "Exporting dynamically compiled models at runtime is not supported. " + "Cannot export blobs of dynamic models that request static shape inference. " + "To export this model, set disable_dynamic_shapes to False"; + throw std::runtime_error(exception_str); + } + // std::string model_blob_str; // auto compiled_model = concrete_backend_->GetOVCompiledModel(); // auto graph_name = global_context_.onnx_model_path_name; @@ -174,8 +179,8 @@ namespace openvino_ep { // openvino_sdk_version_, // GetGlobalContext().device_type)); // -// return Status::OK(); -//} + return nullptr; +} // //bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { // bool has_batched_inputs = true; @@ -468,9 +473,9 @@ std::string MakeMapKeyString(const std::vector>& shapes, // } //#endif //} -// -//void BackendManager::ShutdownBackendManager() { -//} + +void BackendManager::ShutdownBackendManager() { +} } // namespace openvino_ep } // namespace onnxruntime diff --git a/samples/openvino/backend_manager.h b/samples/openvino/backend_manager.h index a057c10377a01..aa705db94c8cb 100644 --- a/samples/openvino/backend_manager.h +++ b/samples/openvino/backend_manager.h @@ -13,25 +13,23 @@ #include "onnx_ctx_model_helper.h" #include "ibackend.h" -//namespace onnxruntime { -//namespace openvino_ep { -// -//// Singleton class that manages all the backends -//class BackendManager { -// public: -// BackendManager(const GlobalContext& global_context, -// const onnxruntime::Node& fused_node, -// const onnxruntime::GraphViewer& subgraph, -// const logging::Logger& logger, -// EPCtxHandler& ctx_handle); +namespace onnxruntime { +namespace openvino_ep { + +// Singleton class that manages all the backends +class BackendManager { + public: + BackendManager(const GlobalContext& global_context, + const OrtNode* fused_node, + const OrtGraphViewer* subgraph, + EPCtxHandler& ctx_handle); // void Compute(OrtKernelContext* context); -// void ShutdownBackendManager(); + void ShutdownBackendManager(); // void SetGlobalCotext(const GlobalContext& global_context); -// GlobalContext& GetGlobalContext(); -// Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, -// const logging::Logger& logger); -// -// private: + GlobalContext& GetGlobalContext(); + OrtStatus* ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* subgraph); + + private: // std::unique_ptr GetModelProtoFromFusedNode( // const onnxruntime::Node& fused_node, // const onnxruntime::GraphViewer& subgraph, @@ -50,12 +48,13 @@ // 
std::unique_ptr model_proto_; // std::shared_ptr concrete_backend_; // std::map> backend_map_; -// SubGraphContext subgraph_context_; -// GlobalContext global_context_; -// EPCtxHandler ep_ctx_handle_{}; -// std::string openvino_sdk_version_{}; -//}; -// -//} // namespace openvino_ep -//} // namespace onnxruntime + SubGraphContext subgraph_context_; + GlobalContext global_context_; + EPCtxHandler ep_ctx_handle_{}; + std::string openvino_sdk_version_{}; + static const OrtGraphApi* graph_api_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime // diff --git a/samples/openvino/onnx_ctx_model_helper.cc b/samples/openvino/onnx_ctx_model_helper.cc index ec72897546582..cf8491082037c 100644 --- a/samples/openvino/onnx_ctx_model_helper.cc +++ b/samples/openvino/onnx_ctx_model_helper.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include "onnx_ctx_model_helper.h" #include "openvino_utils.h" @@ -96,19 +97,23 @@ static const char SOURCE[] = "source"; // // return Status::OK(); //} -// -//Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) { -// auto node = graph_viewer.GetNode(0); -// auto& attrs = node->GetAttributes(); -// ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0); -// -// model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s()); -// + +OrtStatus* EPCtxHandler::ImportBlobFromEPCtxModel(const OrtGraphViewer* graph_viewer) { + const OrtNode* node = nullptr; + graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); + size_t attr_count = 0; + graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT, &attr_count); + assert(attr_count > 0); + + const char* attr_str = nullptr; + graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT, &attr_str); + model_stream_ = std::make_shared(attr_str); + // LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; -// -// is_valid_ep_ctx_graph_ = true; -// return Status::OK(); -//} + + is_valid_ep_ctx_graph_ = true; + return nullptr; +} bool EPCtxHandler::CheckForOVEPCtxNode(const OrtGraphViewer* graph_viewer, std::string openvino_sdk_version) const { int max_node_index = 0; diff --git a/samples/openvino/onnx_ctx_model_helper.h b/samples/openvino/onnx_ctx_model_helper.h index 0c12d31869158..26d8b7d43e7ba 100644 --- a/samples/openvino/onnx_ctx_model_helper.h +++ b/samples/openvino/onnx_ctx_model_helper.h @@ -22,7 +22,7 @@ class EPCtxHandler { // const std::string& model_blob_str, // const std::string& openvino_sdk_version, // const std::string& device_type) const; -// Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer); + OrtStatus* ImportBlobFromEPCtxModel(const OrtGraphViewer* graph_viewer); bool CheckForOVEPCtxNode(const OrtGraphViewer* graph_viewer, std::string openvino_sdk_version) const; bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; } [[nodiscard]] const std::shared_ptr GetModelBlobStream() const { return model_stream_; } diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc index 0a6e5c16f4177..cc016e0eb1c05 100644 --- a/samples/openvino/openvino_execution_provider.cc +++ b/samples/openvino/openvino_execution_provider.cc @@ -61,6 +61,56 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const }; OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { + OpenVINOExecutionProvider* p = static_cast(this_); + for (int i = 0; i < cnt; i++) { 
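+      // Each fused node is compiled independently; this loop builds one BackendManager per subgraph.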
+ p->global_context_->use_api_2 = true; + + // During backend creation, we check if user wants to use precompiled blob onnx model or the original model + // For precompiled blob, directly load the model instead of compiling the model + // For original model, check if the user wants to export a model with pre-compiled blob + + std::shared_ptr backend_manager = + std::make_shared(*p->global_context_, + node[i], + graph[i], + p->ep_ctx_handle_); + + if (p->global_context_->export_ep_ctx_blob && !p->ep_ctx_handle_.IsValidOVEPCtxGraph()) { + backend_manager->ExportCompiledBlobAsEPCtxNode(graph[i]); + } + + node_compute_info[i].CreateFunctionStateFunc = nullptr; + node_compute_info[i].ComputeFunc = nullptr; + node_compute_info[i].DestroyFunctionStateFunc = nullptr; +// compute_info.create_state_func = +// [backend_manager](ComputeContext* context, FunctionState* state) { +// OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); +// p->allocate_func = context->allocate_func; +// p->destroy_func = context->release_func; +// p->allocator_handle = context->allocator_handle; +// p->backend_manager = backend_manager; +// *state = static_cast(p); +// return 0; +// }; +// compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { +// auto function_state = static_cast(state); +// try { +// function_state->backend_manager->Compute(context); +// } catch (const std::exception& ex) { +// return common::Status(common::ONNXRUNTIME, common::FAIL, ex.what()); +// } +// return Status::OK(); +// }; +// +// compute_info.release_state_func = +// [](FunctionState state) { +// if (state) { +// OpenVINOEPFunctionState* function_state = static_cast(state); +// delete function_state; +// } +// }; +// node_compute_funcs.push_back(compute_info); + } return nullptr; }; diff --git a/samples/openvino/ov_versions/capability.cc b/samples/openvino/ov_versions/capability.cc index c1cc771653818..b102e932b9f04 100644 --- a/samples/openvino/ov_versions/capability.cc +++ b/samples/openvino/ov_versions/capability.cc @@ -184,12 +184,12 @@ size_t GetCapability::Execute(OrtIndexedSubGraph*** indexed_sub_graph) { std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; -// GetInputsOutputsOfCluster(graph_api_, graph_viewer_, -// this_cluster, -// ng_required_initializers, -// cluster_graph_inputs, -// cluster_inputs, -// cluster_outputs); + GetInputsOutputsOfCluster(graph_api_, graph_viewer_, + this_cluster, + ng_required_initializers, + cluster_graph_inputs, + cluster_inputs, + cluster_outputs); bool omit_subgraph = false; // Omitting zero dim subgraphs diff --git a/samples/openvino/ov_versions/utils.cc b/samples/openvino/ov_versions/utils.cc index cd397979cc301..5db4458edc0eb 100644 --- a/samples/openvino/ov_versions/utils.cc +++ b/samples/openvino/ov_versions/utils.cc @@ -1,6 +1,7 @@ // Copyright (C) Intel Corporation // Licensed under the MIT License +#include #include "utils.h" #if defined(_MSC_VER) @@ -194,120 +195,124 @@ GetConnectedClusters(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_v return connected_clusters; } -//void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api, -// const GraphViewer& graph_viewer, -// const std::vector& cluster, -// const std::unordered_set& ng_required_initializers, -// /*out*/ std::vector& cluster_graph_inputs, -// /*out*/ std::vector& cluster_inputs, -// /*out*/ std::vector& cluster_outputs) { -// std::unordered_set input_args; -// std::vector ordered_input_args; -// std::unordered_set output_args; -// 
std::unordered_set external_output_args; -// std::vector constant_inputs; -// -// for (const auto& node_idx : cluster) { -// const OrtNode* node = nullptr; -// graph_api->OrtGraph_GetOrtNode(graph_viewer, node_idx, &node); -// // Collect all inputs and outputs -// ForEachNodeDef(graph_api, graph_viewer, node, -// [&input_args, &ordered_input_args, &output_args](const char* arg_name, const OrtValueInfoRef*, bool is_input) { -// if (strcmp(arg_name, "") != 0) { -// if (is_input) { -// if (!input_args.count(std::string(arg_name))) { -// ordered_input_args.push_back(std::string(arg_name)); -// } -// input_args.insert(std::string(arg_name)); -// } else { -// output_args.insert(std::string(arg_name)); -// } -// } -// }); -// -// // Check if output of this node is used by nodes outside this_cluster. If yes add this to cluster outputs -// // TODO(leca): equal to for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) ? -// size_t num_outputs = 0; -// graph_api->OrtNode_GetNumOutputs(node, &num_outputs); -// for (int i = 0; i < num_outputs; i++) { -// const char* output_name = nullptr; -// graph_api->OrtNode_GetIthOutputName(node, i, &output_name); -// const OrtNode** consumer_nodes = nullptr; -// size_t num_consumers = 0; -// // TODO(leca): if there is one consumer consuming more than 1 output of curr_node, would it be visited twice? -// graph_api->OrtGraph_GetNodesConsumingInput(graph_viewer, output_name, &consumer_nodes, &num_consumers); -// for (int j = 0; j < num_consumers; j++) { -// size_t consumer_index = 0; -// graph_api->OrtNode_GetIndex(consumer_nodes[j], &consumer_index); -// -// if (std::find(cluster.begin(), cluster.end(), consumer_index) == cluster.end()) { -// // Node is external to this_cluster. Search through its inputs to -// // find the output that is generated by this_cluster. -// std::set ext_node_inputs; -// ForEachNodeDef(graph_api, graph_viewer, consumer_nodes[j], -// [&ext_node_inputs](const char* arg_name, const OrtValueInfoRef*, bool is_input) { -// if (is_input) { -// ext_node_inputs.insert(std::string(arg_name)); -// } -// }); -// -// for (int j = 0; j < num_outputs; j++) { -// const char* out_def = nullptr; -// graph_api->OrtNode_GetIthOutputName(node, j, &out_def); -// if (ext_node_inputs.find(std::string(out_def)) != ext_node_inputs.end()) { -// external_output_args.insert(std::string(out_def)); -// } -// } -// } -// } -// // TODO(leca): release consumer_nodes -// } -// } -// -// // Extract initializers used by this_cluster. 
-//  std::unordered_set<std::string> original_graph_inputs;
-//  const char** input_names = nullptr;
-//  size_t input_len = 0;
-//  graph_api->OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_len);
-//  for (int i = 0; i < input_len; i++) {
-//    original_graph_inputs.insert(std::string(input_names[i]));
-//  }
-//  graph_api->ReleaseCharArray(input_names);
-
-//  const char** initializer_names = nullptr;
-//  size_t initializer_len = 0;
-//  graph_api->OrtGraph_GetAllInitializers(graph_viewer, &initializer_names, &initializer_len);
-//  const auto& initializers = graph_viewer.GetAllInitializedTensors();
-//  for (const auto& in_arg : ordered_input_args) {
-//    if ((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) ||
-//        ng_required_initializers.count(in_arg)) {
-//      constant_inputs.push_back(in_arg);
-//    }
-//  }
-//
-//  for (const auto& in_arg : ordered_input_args) {
-//    if (!output_args.count(in_arg) &&
-//        !((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) ||
-//          ng_required_initializers.count(in_arg))) {
-//      cluster_inputs.push_back(in_arg);
-//    }
-//  }
-//  for (const auto& input : cluster_inputs) {
-//    cluster_graph_inputs.push_back(input);
-//  }
-//
-//  for (const auto& in_arg : constant_inputs) {
-//    cluster_inputs.push_back(in_arg);
-//  }
-//
-//  std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(cluster_outputs));
-//  for (const auto& node_arg : graph_viewer.GetOutputs()) {
-//    const auto& name = node_arg->Name();
-//    if (output_args.count(name) && !external_output_args.count(name)) {
-//      cluster_outputs.push_back(name);
-//    }
-//  }
-//}
+void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api,
+                               const OrtGraphViewer* graph_viewer,
+                               const std::vector<size_t>& cluster,
+                               const std::unordered_set<std::string>& ng_required_initializers,
+                               /*out*/ std::vector<std::string>& cluster_graph_inputs,
+                               /*out*/ std::vector<std::string>& cluster_inputs,
+                               /*out*/ std::vector<std::string>& cluster_outputs) {
+  std::unordered_set<std::string> input_args;
+  std::vector<std::string> ordered_input_args;
+  std::unordered_set<std::string> output_args;
+  std::unordered_set<std::string> external_output_args;
+  std::vector<std::string> constant_inputs;
+
+  for (const auto& node_idx : cluster) {
+    const OrtNode* node = nullptr;
+    graph_api->OrtGraph_GetOrtNode(graph_viewer, node_idx, &node);
+    // Collect all inputs and outputs
+    ForEachNodeDef(graph_api, graph_viewer, node,
+                   [&input_args, &ordered_input_args, &output_args](const char* arg_name, const OrtValueInfoRef*, bool is_input) {
+                     if (strcmp(arg_name, "") != 0) {
+                       if (is_input) {
+                         if (!input_args.count(std::string(arg_name))) {
+                           ordered_input_args.push_back(std::string(arg_name));
+                         }
+                         input_args.insert(std::string(arg_name));
+                       } else {
+                         output_args.insert(std::string(arg_name));
+                       }
+                     }
+                   });
+
+    // Check if output of this node is used by nodes outside this_cluster. If yes add this to cluster outputs
+    // TODO(leca): equal to for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) ?
+    size_t num_outputs = 0;
+    graph_api->OrtNode_GetNumOutputs(node, &num_outputs);
+    for (int i = 0; i < num_outputs; i++) {
+      const char* output_name = nullptr;
+      graph_api->OrtNode_GetIthOutputName(node, i, &output_name);
+      const OrtNode** consumer_nodes = nullptr;
+      size_t num_consumers = 0;
+      // TODO(leca): if there is one consumer consuming more than 1 output of curr_node, would it be visited twice?
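+      // (Even if it is visited twice, external_output_args below is a set, so a
+      // repeated consumer cannot introduce duplicate entries into cluster_outputs;
+      // the repeat only costs a redundant scan of the consumer's inputs.)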
+      graph_api->OrtGraph_GetNodesConsumingInput(graph_viewer, output_name, &consumer_nodes, &num_consumers);
+      for (int j = 0; j < num_consumers; j++) {
+        size_t consumer_index = 0;
+        graph_api->OrtNode_GetIndex(consumer_nodes[j], &consumer_index);
+
+        if (std::find(cluster.begin(), cluster.end(), consumer_index) == cluster.end()) {
+          // Node is external to this_cluster. Search through its inputs to
+          // find the output that is generated by this_cluster.
+          std::set<std::string> ext_node_inputs;
+          ForEachNodeDef(graph_api, graph_viewer, consumer_nodes[j],
+                         [&ext_node_inputs](const char* arg_name, const OrtValueInfoRef*, bool is_input) {
+                           if (is_input) {
+                             ext_node_inputs.insert(std::string(arg_name));
+                           }
+                         });
+
+          for (int j = 0; j < num_outputs; j++) {
+            const char* out_def = nullptr;
+            graph_api->OrtNode_GetIthOutputName(node, j, &out_def);
+            if (ext_node_inputs.find(std::string(out_def)) != ext_node_inputs.end()) {
+              external_output_args.insert(std::string(out_def));
+            }
+          }
+        }
+      }
+      // TODO(leca): release consumer_nodes
+    }
+  }
+
+  // Extract initializers used by this_cluster.
+  std::unordered_set<std::string> original_graph_inputs;
+  const char** input_names = nullptr;
+  size_t input_len = 0;
+  graph_api->OrtGraph_GetAllInputs(graph_viewer, &input_names, &input_len);
+  for (int i = 0; i < input_len; i++) {
+    original_graph_inputs.insert(std::string(input_names[i]));
+  }
+  graph_api->ReleaseCharArray(input_names);
+
+  const char** initializers = nullptr;
+  size_t initializer_len = 0;
+  graph_api->OrtGraph_GetAllInitializers(graph_viewer, &initializers, &initializer_len);
+  for (const auto& in_arg : ordered_input_args) {
+    bool initializers_contain_in_arg = false;
+    for (int i = 0; i < initializer_len; i++) {
+      if (!strcmp(initializers[i], in_arg.c_str())) {
+        initializers_contain_in_arg = true;
+        break;
+      }
+    }
+
+    if ((initializers_contain_in_arg && !original_graph_inputs.count(in_arg)) ||
+        ng_required_initializers.count(in_arg)) constant_inputs.push_back(in_arg);
+    if (!output_args.count(in_arg) &&
+        !((initializers_contain_in_arg && !original_graph_inputs.count(in_arg)) ||
+          ng_required_initializers.count(in_arg))) cluster_inputs.push_back(in_arg);
+  }
+
+  for (const auto& input : cluster_inputs) {
+    cluster_graph_inputs.push_back(input);
+  }
+
+  for (const auto& in_arg : constant_inputs) {
+    cluster_inputs.push_back(in_arg);
+  }
+
+  std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(cluster_outputs));
+  size_t output_count = 0;
+  graph_api->OrtGraph_GetOutputSize(graph_viewer, &output_count);
+  for (int i = 0; i < output_count; i++) {
+    const char* name = nullptr;
+    graph_api->OrtGraph_GetIthOutputName(graph_viewer, i, &name);
+    if (output_args.count(name) && !external_output_args.count(name)) {
+      cluster_outputs.push_back(name);
+    }
+  }
+}
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/samples/openvino/ov_versions/utils.h b/samples/openvino/ov_versions/utils.h
index 2cd27bbee23a9..44cf61faa8cba 100644
--- a/samples/openvino/ov_versions/utils.h
+++ b/samples/openvino/ov_versions/utils.h
@@ -42,13 +42,13 @@ void IdentifyConnectedNodes(
 std::vector<std::vector<size_t>> GetConnectedClusters(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer, const std::vector<std::vector<size_t>>& clusters);
 
-//void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api,
-//                               const GraphViewer& graph_viewer,
-//                               const std::vector<NodeIndex>& cluster,
-//                               const std::unordered_set<std::string>& ng_required_initializers,
-//                               /*out*/ std::vector<std::string>& cluster_graph_inputs,
-//                               /*out*/ std::vector<std::string>& cluster_inputs,
-//                               /*out*/ 
std::vector& cluster_outputs); +void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api, + const OrtGraphViewer* graph_viewer, + const std::vector& cluster, + const std::unordered_set& ng_required_initializers, + /*out*/ std::vector& cluster_graph_inputs, + /*out*/ std::vector& cluster_inputs, + /*out*/ std::vector& cluster_outputs); } // namespace openvino_ep } // namespace onnxruntime From 0fe5f0129673df0db507a89a6bc799cf4a8f7c36 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:41:37 +0800 Subject: [PATCH 60/81] refine compile of openvino ep (#22689) --- samples/openvino/backend_manager.cc | 84 +++++++++++-------- samples/openvino/backend_manager.h | 4 +- .../openvino/openvino_execution_provider.cc | 56 ++++++------- .../openvino/openvino_execution_provider.h | 17 ++-- samples/openvino/openvino_utils.cc | 4 +- 5 files changed, 89 insertions(+), 76 deletions(-) diff --git a/samples/openvino/backend_manager.cc b/samples/openvino/backend_manager.cc index 67af0bd589952..944cb33b2edfd 100644 --- a/samples/openvino/backend_manager.cc +++ b/samples/openvino/backend_manager.cc @@ -213,41 +213,53 @@ OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* g // return has_batched_inputs; //} // -//bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { -// bool has_sym_dims = false; -// auto graph_inputs = subgraph.GetInputs(); -// for (auto input : graph_inputs) { -// if (input->Shape() == nullptr) { -// has_sym_dims = true; -// break; -// } -// for (auto& dim : input->Shape()->dim()) { -// if (dim.value_case() != dim.kDimValue) { -// has_sym_dims = true; -// break; -// } -// } -// if (has_sym_dims) { -// break; -// } -// } -// return has_sym_dims; -//} -// -//// Check to see if the graph is QDQ -//static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { -// std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; -// const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); -// -// for (size_t i = 0; i < node_indices.size(); i++) { -// gsl::not_null node(graph_viewer.GetNode(node_indices[i])); -// if (qdq_ops.find(node->OpType()) != qdq_ops.end()) { -// return true; -// } -// } -// return false; -//} -// +bool BackendManager::ModelHasSymbolicInputDims(const OrtGraphViewer* subgraph) const { + bool has_sym_dims = false; + const char** required_inputs = nullptr; + size_t input_count = 0; + graph_api_->OrtGraph_GetRequiredInputs(subgraph, &required_inputs, &input_count); + for (int i = 0; i < input_count; i++) { + OrtValueInfoRef* value_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(subgraph, required_inputs[i], &value_info); + if (value_info->shape == nullptr) { + has_sym_dims = true; + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + break; + } + for (size_t j = 0; j < value_info->shape_len; j++) { + // if (dim.value_case() != dim.kDimValue) { TODO:yang + // has_sym_dims = true; + // graph_api_->OrtGraph_ReleaseValueInfo(value_info); + // break; + // } + } + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + if (has_sym_dims) { + break; + } + } + return has_sym_dims; +} + +// Check to see if the graph is QDQ +static bool IsQDQGraph(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer) { + std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; + const size_t* nodes = nullptr; + size_t num_nodes; + graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_viewer, 0, &nodes, &num_nodes); + + 
for(size_t i = 0; i < num_nodes; i++) { + const OrtNode* node = nullptr; + graph_api->OrtGraph_GetOrtNode(graph_viewer, nodes[i], &node); + const char* optype = nullptr; + graph_api->OrtNode_GetOpType(node, &optype); + if (qdq_ops.find(optype) != qdq_ops.end()) { + return true; + } + } + return false; +} + //static void DumpOpenVINOEPModel(std::string onnx_model_path_name, // ONNX_NAMESPACE::ModelProto* model_proto, // const onnxruntime::Node& fused_node) { @@ -296,7 +308,7 @@ OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* g // // QDQ stripping enabled only for the NPU // if (global_context_.device_type.find("NPU") != std::string::npos && // global_context_.enable_qdq_optimizer && -// IsQDQGraph(subgraph)) { +// IsQDQGraph(_graph_api, subgraph)) { // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 1"; // std::unique_ptr model; // Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, model); diff --git a/samples/openvino/backend_manager.h b/samples/openvino/backend_manager.h index aa705db94c8cb..58c8db4d312c2 100644 --- a/samples/openvino/backend_manager.h +++ b/samples/openvino/backend_manager.h @@ -23,7 +23,7 @@ class BackendManager { const OrtNode* fused_node, const OrtGraphViewer* subgraph, EPCtxHandler& ctx_handle); -// void Compute(OrtKernelContext* context); + void Compute(OrtKernelContext* context); void ShutdownBackendManager(); // void SetGlobalCotext(const GlobalContext& global_context); GlobalContext& GetGlobalContext(); @@ -35,7 +35,7 @@ class BackendManager { // const onnxruntime::GraphViewer& subgraph, // const logging::Logger& logger) const; // -// bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; + bool ModelHasSymbolicInputDims(const OrtGraphViewer* subgraph) const; // bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; // // std::shared_ptr diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc index cc016e0eb1c05..7c9211bca3b44 100644 --- a/samples/openvino/openvino_execution_provider.cc +++ b/samples/openvino/openvino_execution_provider.cc @@ -79,37 +79,31 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const backend_manager->ExportCompiledBlobAsEPCtxNode(graph[i]); } - node_compute_info[i].CreateFunctionStateFunc = nullptr; - node_compute_info[i].ComputeFunc = nullptr; - node_compute_info[i].DestroyFunctionStateFunc = nullptr; -// compute_info.create_state_func = -// [backend_manager](ComputeContext* context, FunctionState* state) { -// OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); -// p->allocate_func = context->allocate_func; -// p->destroy_func = context->release_func; -// p->allocator_handle = context->allocator_handle; -// p->backend_manager = backend_manager; -// *state = static_cast(p); -// return 0; -// }; -// compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { -// auto function_state = static_cast(state); -// try { -// function_state->backend_manager->Compute(context); -// } catch (const std::exception& ex) { -// return common::Status(common::ONNXRUNTIME, common::FAIL, ex.what()); -// } -// return Status::OK(); -// }; -// -// compute_info.release_state_func = -// [](FunctionState state) { -// if (state) { -// OpenVINOEPFunctionState* function_state = static_cast(state); -// delete function_state; -// } -// }; -// node_compute_funcs.push_back(compute_info); + 
node_compute_info[i].CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int {
+        std::unique_ptr<OpenVINOEPFunctionState> p = std::make_unique<OpenVINOEPFunctionState>();
+        p->allocate_func = context->AllocateFunc;
+        p->destroy_func = context->DestroyFunc;
+        p->allocator_handle = context->allocator_handle;
+        // p->backend_manager = static_cast<BackendManager*>(extra_param);
+        // p->backend_manager = backend_manager; TODO:yang
+        *state = p.release();
+        return 0;
+      };
+      node_compute_info[i].ComputeFunc = [](void* state, void* extra_param, const OrtApi* api, OrtKernelContext* context) -> OrtStatusPtr {
+        auto function_state = static_cast<OpenVINOEPFunctionState*>(state);
+        try {
+          function_state->backend_manager->Compute(context);
+        } catch (const std::exception& ex) {
+          return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, ex.what());
+        }
+        return nullptr;
+      };
+      node_compute_info[i].DestroyFunctionStateFunc = [](void* state) {
+        if (state) {
+          OpenVINOEPFunctionState* function_state = static_cast<OpenVINOEPFunctionState*>(state);
+          delete function_state;
+        }
+      };
     }
     return nullptr;
   };
diff --git a/samples/openvino/openvino_execution_provider.h b/samples/openvino/openvino_execution_provider.h
index eea702145d5a8..baae12a65a843 100644
--- a/samples/openvino/openvino_execution_provider.h
+++ b/samples/openvino/openvino_execution_provider.h
@@ -18,6 +18,11 @@
 #endif
 
 namespace onnxruntime {
+
+using AllocateFunc = void* (*)(void*, size_t, size_t);
+using DestroyFunc = void (*)(void*, void*);
+using AllocatorHandle = void*;
+
 static void print_build_options() {
   std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl;
   std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority "
@@ -139,12 +144,12 @@ struct OpenVINOExecutionProviderInfo {
   }
 };
 
-//struct OpenVINOEPFunctionState {
-//  AllocateFunc allocate_func = nullptr;
-//  DestroyFunc destroy_func = nullptr;
-//  AllocatorHandle allocator_handle = nullptr;
-//  std::shared_ptr<BackendManager> backend_manager;
-//};
+struct OpenVINOEPFunctionState {
+  AllocateFunc allocate_func = nullptr;
+  DestroyFunc destroy_func = nullptr;
+  AllocatorHandle allocator_handle = nullptr;
+  std::shared_ptr<BackendManager> backend_manager;
+};
 
 // Logical device representation.
 class OpenVINOExecutionProvider : public OrtExecutionProvider {
diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc
index 9c7b16e684f0a..f0b641b102f53 100644
--- a/samples/openvino/openvino_utils.cc
+++ b/samples/openvino/openvino_utils.cc
@@ -1,10 +1,11 @@
-#include <Windows.h>
+// #include <Windows.h>
 #include "openvino_utils.h"
 
 namespace onnxruntime {
   std::string GetEnvironmentVar(const std::string& var_name) {
     // TODO(leca): #ifdef _WIN32
     //#endif
+#if defined(_WIN32)
     constexpr DWORD kBufferSize = 32767;
 
     // Create buffer to hold the result
@@ -19,6 +20,7 @@ namespace onnxruntime {
       buffer.resize(char_count);
       return buffer;
     }
+#endif
     return std::string();
   }
 

From 6bae1b984f841f9a3068b6dc780ef28be634ae12 Mon Sep 17 00:00:00 2001
From: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Date: Fri, 1 Nov 2024 11:17:19 -0700
Subject: [PATCH 61/81] Add utility files (#22650)

Add some utility files for the plugin EP to include and compile.
- provider options map -> `provider_options.h`
- provider options parser -> `provider_options_utils.h`
- some macro definitions, classes and functions from include/onnxruntime/core/common
---
 samples/tensorRTEp/CMakeLists.txt      |   3 +-
 samples/utils/code_location.h          |  58 ++++++++
 samples/utils/common.h                 | 169 ++++++++++++++++++++++
 samples/utils/cuda/cuda_call.h         |  69 +++++++++
 samples/utils/cuda/cuda_common.h       |  14 ++
 samples/utils/exceptions.h             |  91 ++++++++++++
 samples/utils/make_string.h            | 126 ++++++++++++++++
 samples/utils/parse_string.h           |  85 +++++++++++
 samples/utils/provider_options.h       |  18 +++
 samples/utils/provider_options_utils.h | 164 +++++++++++++++++++++
 samples/utils/status.cc                |  91 ++++++++++++
 samples/utils/status.h                 | 192 +++++++++++++++++++++++++
 12 files changed, 1079 insertions(+), 1 deletion(-)
 create mode 100644 samples/utils/code_location.h
 create mode 100644 samples/utils/common.h
 create mode 100644 samples/utils/cuda/cuda_call.h
 create mode 100644 samples/utils/cuda/cuda_common.h
 create mode 100644 samples/utils/exceptions.h
 create mode 100644 samples/utils/make_string.h
 create mode 100644 samples/utils/parse_string.h
 create mode 100644 samples/utils/provider_options.h
 create mode 100644 samples/utils/provider_options_utils.h
 create mode 100644 samples/utils/status.cc
 create mode 100644 samples/utils/status.h

diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt
index 641d05d1b1ad2..ebf5448b80a93 100644
--- a/samples/tensorRTEp/CMakeLists.txt
+++ b/samples/tensorRTEp/CMakeLists.txt
@@ -12,9 +12,10 @@ find_package(CUDAToolkit REQUIRED)
 add_definitions(-DONNX_NAMESPACE=onnx)
 add_definitions(-DONNX_ML)
 add_definitions(-DNV_TENSORRT_MAJOR=10)
-file(GLOB tensorrt_src "./*.cc")
+file(GLOB tensorrt_src "./*.cc" "../utils/status.cc")
 add_library(TensorRTEp SHARED ${tensorrt_src})
 target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime"
+                                             "../utils"
                                              "/usr/local/cuda/include"
                                              ${TENSORRT_HOME}/include
                                              "../../build/tensorrt/Debug/_deps/flatbuffers-src/include"
diff --git a/samples/utils/code_location.h b/samples/utils/code_location.h
new file mode 100644
index 0000000000000..dbff69099ba78
--- /dev/null
+++ b/samples/utils/code_location.h
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include <sstream>
+#include <vector>
+
+namespace onnxruntime {
+/**
+   CodeLocation captures information on where in the source code a message came from.
+*/
+struct CodeLocation {
+  /**
+   @param file_path Usually the value of __FILE__
+   @param line Usually the value of __LINE__
+   @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
+   */
+  CodeLocation(const char* file_path, const int line, const char* func)
+      : file_and_path{file_path}, line_num{line}, function{func} {
+  }
+
+  /**
+   @param file_path Usually the value of __FILE__
+   @param line Usually the value of __LINE__
+   @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
+   @param stacktrace Stacktrace from source of message.
+   */
+  CodeLocation(const char* file_path, const int line, const char* func, const std::vector<std::string>& stacktrace)
+      : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) {
+  }
+
+  std::string FileNoPath() const {
+    // assuming we always have work to do, so not trying to avoid creating a new string if
+    // no path was removed.
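+    // find_last_of returns npos when no separator is present; npos + 1 wraps to
+    // index 0, so a bare file name comes back unchanged.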
+ return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); + } + + enum Format { + kFilename, + kFilenameAndPath + }; + + std::string ToString(Format format = Format::kFilename) const { + std::ostringstream out; + out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; + return out.str(); + } + // utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8. + const std::string file_and_path; + const int line_num; + // utf-8 + const std::string function; + const std::vector stacktrace; +}; + +} // namespace onnxruntime diff --git a/samples/utils/common.h b/samples/utils/common.h new file mode 100644 index 0000000000000..eaf000a56ce24 --- /dev/null +++ b/samples/utils/common.h @@ -0,0 +1,169 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "code_location.h" +#include "exceptions.h" +#include "make_string.h" +#include "status.h" + +namespace onnxruntime { + +// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER +// so we only define it as one for MSVC +#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) +#define __PRETTY_FUNCTION__ __FUNCTION__ +#endif + +// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__ +#define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__FUNCTION__)) + +#define ORT_WHERE_WITH_STACK \ + ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace()) + +// Throw an exception with optional message. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +/* +#define ORT_THROW(...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) +*/ +#define ORT_THROW(...) \ + throw ::onnxruntime::OnnxRuntimeException(::onnxruntime::MakeString(__VA_ARGS__)) + +// Just in order to mark things as not implemented. Do not use in final code. +#define ORT_NOT_IMPLEMENTED(...) \ + throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +#define ORT_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw ::onnxruntime::OnnxRuntimeException(#condition, \ + ::onnxruntime::MakeString(__VA_ARGS__)); \ + } \ + } while (false) + +#define ORT_THROW_EX(ex, ...) \ + throw ex(__VA_ARGS__) + +#define ORT_MAKE_STATUS(category, code, ...) 
\ + ::onnxruntime::common::Status(::onnxruntime::common::category, \ + ::onnxruntime::common::code, \ + ::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. if met, return status. +#define ORT_RETURN_IF(condition, ...) \ + do { \ + if (condition) { \ + return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME, \ + ::onnxruntime::common::FAIL, \ + ::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \ + } \ + } while (false) + +// Check condition. if not met, return status. +#define ORT_RETURN_IF_NOT(condition, ...) \ + ORT_RETURN_IF(!(condition), __VA_ARGS__) + +// Macros to disable the copy and/or move ctor and assignment methods +// These are usually placed in the private: declarations for a class. + +#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete + +#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete + +#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \ + ORT_DISALLOW_COPY(TypeName); \ + ORT_DISALLOW_ASSIGNMENT(TypeName) + +#define ORT_DISALLOW_MOVE(TypeName) \ + TypeName(TypeName&&) = delete; \ + TypeName& operator=(TypeName&&) = delete + +#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \ + ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ + ORT_DISALLOW_MOVE(TypeName) + +#define ORT_RETURN_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + return _status; \ + } \ + } while (0) + +#define ORT_THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + ORT_THROW(_status); \ + } \ + } while (0) + +// use this macro when cannot early return +#define ORT_CHECK_AND_SET_RETVAL(expr) \ + do { \ + if (retval.IsOK()) { \ + retval = (expr); \ + } \ + } while (0) + +struct null_type {}; +inline std::string ToUTF8String(const std::string& s) { return s; } +#ifdef _WIN32 +/** + * Convert a wide character string to a UTF-8 string + */ +std::string ToUTF8String(const std::wstring& s); + +std::wstring ToWideString(const std::string& s); +inline std::wstring ToWideString(const std::wstring& s) { return s; } +#else +inline std::string ToWideString(const std::string& s) { return s; } +#endif + +constexpr size_t kMaxStrLen = 2048; + +// Returns whether `key` is in `container`. +// Like C++20's map/set contains() member function. +template typename AssociativeContainer, + typename LookupKey> +inline bool Contains(const AssociativeContainer& container, LookupKey&& key) { + return container.find(std::forward(key)) != container.end(); +} + +} // namespace onnxruntime diff --git a/samples/utils/cuda/cuda_call.h b/samples/utils/cuda/cuda_call.h new file mode 100644 index 0000000000000..81d5975c408b9 --- /dev/null +++ b/samples/utils/cuda/cuda_call.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
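+//
+// CudaCall compares a CUDA runtime return code against the expected success
+// code. On failure it assembles a diagnostic string (library name, CUDA error
+// string, current device, file/line and the failing expression) and, depending
+// on the THRW template argument, either throws via ORT_THROW or returns a
+// failed Status. The CUDA_CALL macro below is the form consumed by
+// CUDA_RETURN_IF_ERROR in cuda_common.h.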
+ +#pragma once +#include "../common.h" + +namespace onnxruntime { + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- +// +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + if (retCode != successCode) { + try { +//#ifdef _WIN32 + //std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + //if (hostname_str.empty()) { + //hostname_str = "?"; + //} + //const char* hostname = hostname_str.c_str(); +//#else + //char hostname[HOST_NAME_MAX]; + //if (gethostname(hostname, HOST_NAME_MAX) != 0) + //strcpy(hostname, "?"); +//#endif + int currentCudaDevice = -1; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=? ; file=%s ; line=%d ; expr=%s; %s", + libName, (int)retCode, CudaErrString(retCode), currentCudaDevice, + //hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + //LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + //LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +//template +//std::conditional_t CudaCall( + //ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + +#define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) + +} // namespace onnxruntime diff --git a/samples/utils/cuda/cuda_common.h b/samples/utils/cuda/cuda_common.h new file mode 100644 index 0000000000000..b00ef3f92674e --- /dev/null +++ b/samples/utils/cuda/cuda_common.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_call.h" + +namespace onnxruntime { +namespace cuda { + +#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) + +} // namespace cuda +} // namespace onnxruntime diff --git a/samples/utils/exceptions.h b/samples/utils/exceptions.h new file mode 100644 index 0000000000000..19c1586aeca07 --- /dev/null +++ b/samples/utils/exceptions.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
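+//
+// Trimmed copy of the in-tree exception types: OnnxRuntimeException keeps the
+// failed-condition and message formatting, but the CodeLocation / stack-trace
+// capture is stubbed out for now (the commented-out constructors preserve the
+// original shape for reference).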
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common.h" +//#include "code_location.h" + +namespace onnxruntime { + +class NotImplementedException : public std::logic_error { + public: + explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; + explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; +}; + +class TypeMismatchException : public std::logic_error { + public: + TypeMismatchException() noexcept : logic_error("Type mismatch"){}; +}; + +class OnnxRuntimeException : public std::exception { + public: + // code location is not provided for now + /* + OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept + : OnnxRuntimeException(location, nullptr, msg) { + } + */ + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + */ + /* + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) + : location_{location} { + std::ostringstream ss; + + ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous + if (failed_condition != nullptr) { + ss << " " << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + if (!location.stacktrace.empty()) { + ss << "Stacktrace:\n"; + // skip the first entry in the stacktrace as we have that information from location.ToString() + std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); + } + + what_ = ss.str(); + } + */ + + OnnxRuntimeException(const std::string& msg) noexcept + : OnnxRuntimeException(nullptr, msg) { + } + + OnnxRuntimeException(const char* failed_condition, const std::string& msg) { + std::ostringstream ss; + + if (failed_condition != nullptr) { + ss << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + what_ = ss.str(); + } + + const char* what() const noexcept override { + return what_.c_str(); + } + + private: + //const CodeLocation location_; + const std::vector stacktrace_; + std::string what_; +}; + +} // namespace onnxruntime diff --git a/samples/utils/make_string.h b/samples/utils/make_string.h new file mode 100644 index 0000000000000..826898de852a8 --- /dev/null +++ b/samples/utils/make_string.h @@ -0,0 +1,126 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include + +namespace onnxruntime { + +namespace detail { + +inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept { + MakeStringImpl(ss, t); + MakeStringImpl(ss, args...); +} + +// see MakeString comments for explanation of why this is necessary +template +inline std::string MakeStringImpl(const Args&... args) noexcept { + std::ostringstream ss; + MakeStringImpl(ss, args...); + return ss.str(); +} + +// +// Infrastructure to convert char[n] to char* to reduce binary size +// + +// default is to leave the type as is +template +struct if_char_array_make_ptr { + using type = T; +}; + +// specialization that matches an array reference, which is what the char array from a string literal +// used in a call to MakeString will be. +// if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded. +template +struct if_char_array_make_ptr { + // remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x], + // and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched. + using element_type = typename std::remove_const::type>::type; + using type = typename std::conditional::value, T*, T (&)[N]>::type; +}; + +// helper to make usage simpler in MakeString +template +using if_char_array_make_ptr_t = typename if_char_array_make_ptr::type; +} // namespace detail + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses the current locale. + */ +template +std::string MakeString(const Args&... args) { + // We need to update the types from the MakeString template instantiation to decay any char[n] to char*. + // e.g. MakeString("in", "out") goes from MakeString to MakeStringImpl + // so that MakeString("out", "in") will also match MakeStringImpl instead of requiring + // MakeStringImpl. + // + // We have to do the type processing before any actual work, so this function purely implements the type processing. + // If we do not do it this way we do not get the full binary size reduction. + // + // See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover + // the need to do the type processing as a separate step. + + return detail::MakeStringImpl(detail::if_char_array_make_ptr_t(args)...); +} + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses std::locale::classic(). + */ +template +std::string MakeStringWithClassicLocale(const Args&... args) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + detail::MakeStringImpl(ss, args...); + return ss.str(); +} + +// MakeString versions for already-a-string types. 
+ +inline std::string MakeString(const std::string& str) { + return str; +} + +inline std::string MakeString(const char* cstr) { + return cstr; +} + +inline std::string MakeStringWithClassicLocale(const std::string& str) { + return str; +} + +inline std::string MakeStringWithClassicLocale(const char* cstr) { + return cstr; +} + +} // namespace onnxruntime diff --git a/samples/utils/parse_string.h b/samples/utils/parse_string.h new file mode 100644 index 0000000000000..ce404607120f4 --- /dev/null +++ b/samples/utils/parse_string.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "common.h" + +namespace onnxruntime { + +/** + * Tries to parse a value from an entire string. + */ +template +bool TryParseStringWithClassicLocale(std::string_view str, T& value) { + if constexpr (std::is_integral::value && std::is_unsigned::value) { + // if T is unsigned integral type, reject negative values which will wrap + if (!str.empty() && str[0] == '-') { + return false; + } + } + + // don't allow leading whitespace + if (!str.empty() && std::isspace(str[0], std::locale::classic())) { + return false; + } + + std::istringstream is{std::string{str}}; + is.imbue(std::locale::classic()); + T parsed_value{}; + + const bool parse_successful = + is >> parsed_value && + is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters + if (!parse_successful) { + return false; + } + + value = std::move(parsed_value); + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) { + value = str; + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { + if (str == "0" || str == "False" || str == "false") { + value = false; + return true; + } + + if (str == "1" || str == "True" || str == "true") { + value = true; + return true; + } + + return false; +} + +/** + * Parses a value from an entire string. + */ +template +Status ParseStringWithClassicLocale(std::string_view s, T& value) { + ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\""); + return Status::OK(); +} + +/** + * Parses a value from an entire string. + */ +template +T ParseStringWithClassicLocale(std::string_view s) { + T value{}; + ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value)); + return value; +} + +} // namespace onnxruntime diff --git a/samples/utils/provider_options.h b/samples/utils/provider_options.h new file mode 100644 index 0000000000000..aab13e808e3b6 --- /dev/null +++ b/samples/utils/provider_options.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { + +// data types for execution provider options + +using ProviderOptions = std::unordered_map; +using ProviderOptionsVector = std::vector; +using ProviderOptionsMap = std::unordered_map; + +} // namespace onnxruntime diff --git a/samples/utils/provider_options_utils.h b/samples/utils/provider_options_utils.h new file mode 100644 index 0000000000000..c7380b3629709 --- /dev/null +++ b/samples/utils/provider_options_utils.h @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
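+//
+// Typical use (illustrative names): register one handler per recognized option,
+// then parse the whole ProviderOptions map in one call:
+//
+//   ProviderOptionsParser parser;
+//   parser.AddAssignmentToReference("device_id", device_id);
+//   parser.AddAssignmentToEnumReference("precision", precision_mapping, precision);
+//   ORT_THROW_IF_ERROR(parser.Parse(provider_options));
+//
+// Unknown option names and unparsable values both surface as failed Statuses
+// from Parse.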
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "common.h" +#include "parse_string.h" +#include "provider_options.h" + +namespace onnxruntime { + +template +using EnumNameMapping = std::vector>; + +/** + * Given a mapping and an enumeration value, gets the corresponding name. + */ +template +Status EnumToName(const EnumNameMapping& mapping, TEnum value, std::string& name) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&value](const std::pair& entry) { + return entry.first == value; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum value to name: ", static_cast::type>(value)); + name = it->second; + return Status::OK(); +} + +template +std::string EnumToName(const EnumNameMapping& mapping, TEnum value) { + std::string name; + ORT_THROW_IF_ERROR(EnumToName(mapping, value, name)); + return name; +} + +/** + * Given a mapping and a name, gets the corresponding enumeration value. + */ +template +Status NameToEnum( + const EnumNameMapping& mapping, const std::string& name, TEnum& value) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&name](const std::pair& entry) { + return entry.second == name; + }); + ORT_RETURN_IF( + it == mapping.end(), + "Failed to map enum name to value: ", name); + value = it->first; + return Status::OK(); +} + +template +TEnum NameToEnum(const EnumNameMapping& mapping, const std::string& name) { + TEnum value; + ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value)); + return value; +} + +class ProviderOptionsParser { + public: + /** + * Adds a parser for a particular provider option value. + * + * @param name The provider option name. + * @param value_parser An object that parses the option value. + * It should be callable with the following signature and return + * whether the parsing was successful: + * Status value_parser(const std::string&) + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddValueParser( + const std::string& name, ValueParserType value_parser) { + ORT_ENFORCE( + value_parsers_.emplace(name, ValueParser{value_parser}).second, + "Provider option \"", name, "\" already has a value parser."); + return *this; + } + + /** + * Adds a parser for a particular provider option value which converts a + * value to the right type and assigns it to the given reference. + * + * IMPORTANT: This function stores a reference to the destination variable. + * The caller must ensure that the reference is valid when Parse() is called! + * + * @param name The provider option name. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddAssignmentToReference( + const std::string& name, ValueType& dest) { + return AddValueParser( + name, + [&dest](const std::string& value_str) -> Status { + return ParseStringWithClassicLocale(value_str, dest); + }); + } + + /** + * Adds a parser for a particular provider option value which maps an + * enumeration name to a value and assigns it to the given reference. + * + * IMPORTANT: This function stores references to the mapping and destination + * variables. The caller must ensure that the references are valid when + * Parse() is called! + * + * @param name The provider option name. + * @param mapping The enumeration value to name mapping. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. 
+ */ + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddValueParser( + name, + [&mapping, &dest](const std::string& value_str) -> Status { + return NameToEnum(mapping, value_str, dest); + }); + } + + /** + * Parses the given provider options. + */ + Status Parse(const ProviderOptions& options) const { + for (const auto& option : options) { + const auto& name = option.first; + const auto& value_str = option.second; + const auto value_parser_it = value_parsers_.find(name); + ORT_RETURN_IF( + value_parser_it == value_parsers_.end(), + "Unknown provider option: \"", name, "\"."); + + const auto parse_status = value_parser_it->second(value_str); + ORT_RETURN_IF_NOT( + parse_status.IsOK(), + "Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + } + + return Status::OK(); + } + + private: + using ValueParser = std::function; + std::unordered_map value_parsers_; +}; + +} // namespace onnxruntime diff --git a/samples/utils/status.cc b/samples/utils/status.cc new file mode 100644 index 0000000000000..b3a89c8c13f43 --- /dev/null +++ b/samples/utils/status.cc @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. + +#include "status.h" +#include "common.h" + +namespace onnxruntime { +namespace common { +Status::Status(StatusCategory category, int code, const std::string& msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code, const char* msg) { + // state_ will be allocated here causing the status to be treated as a failure + ORT_ENFORCE(code != static_cast(common::OK)); + + state_ = std::make_unique(category, code, msg); +} + +Status::Status(StatusCategory category, int code) + : Status(category, code, "") { +} + +StatusCategory Status::Category() const noexcept { + return IsOK() ? common::NONE : state_->category; +} + +int Status::Code() const noexcept { + return IsOK() ? static_cast(common::OK) : state_->code; +} + +const std::string& Status::ErrorMessage() const noexcept { + return IsOK() ? EmptyString() : state_->msg; +} + +std::string Status::ToString() const { + if (state_ == nullptr) { + return std::string("OK"); + } + + std::string result; + + if (common::SYSTEM == state_->category) { + result += "SystemError"; + result += " : "; + result += std::to_string(errno); + } else if (common::ONNXRUNTIME == state_->category) { + result += "[ONNXRuntimeEPError]"; + result += " : "; + result += std::to_string(Code()); + result += " : "; + result += StatusCodeToString(static_cast(Code())); + result += " : "; + result += state_->msg; + } + + return result; +} + +// GSL_SUPRESS(i.22) is broken. 
Ignore the warnings for the static local variables that are trivial +// and should not have any destruction order issues via pragmas instead. +// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26426) +#endif + +const std::string& Status::EmptyString() noexcept { + static std::string s_empty; + return s_empty; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +} // namespace common +} // namespace onnxruntime diff --git a/samples/utils/status.h b/samples/utils/status.h new file mode 100644 index 0000000000000..80bf7caf87867 --- /dev/null +++ b/samples/utils/status.h @@ -0,0 +1,192 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. + +#pragma once + +#include +#include +#include +#ifdef _WIN32 +#include +#endif + +namespace onnxruntime { +namespace common { + +enum StatusCategory { + NONE = 0, + SYSTEM = 1, + ONNXRUNTIME = 2, +}; + +/** + Error code for ONNXRuntime. +*/ +enum StatusCode { + OK = 0, + FAIL = 1, + INVALID_ARGUMENT = 2, + NO_SUCHFILE = 3, + NO_MODEL = 4, + ENGINE_ERROR = 5, + RUNTIME_EXCEPTION = 6, + INVALID_PROTOBUF = 7, + MODEL_LOADED = 8, + NOT_IMPLEMENTED = 9, + INVALID_GRAPH = 10, + EP_FAIL = 11 +}; + +constexpr const char* StatusCodeToString(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return "SUCCESS"; + case StatusCode::FAIL: + return "FAIL"; + case StatusCode::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case StatusCode::NO_SUCHFILE: + return "NO_SUCHFILE"; + case StatusCode::NO_MODEL: + return "NO_MODEL"; + case StatusCode::ENGINE_ERROR: + return "ENGINE_ERROR"; + case StatusCode::RUNTIME_EXCEPTION: + return "RUNTIME_EXCEPTION"; + case StatusCode::INVALID_PROTOBUF: + return "INVALID_PROTOBUF"; + case StatusCode::MODEL_LOADED: + return "MODEL_LOADED"; + case StatusCode::NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case StatusCode::INVALID_GRAPH: + return "INVALID_GRAPH"; + case StatusCode::EP_FAIL: + return "EP_FAIL"; + default: + return "GENERAL ERROR"; + } +} + +#ifdef _WIN32 +constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return S_OK; + case StatusCode::FAIL: + return E_FAIL; + case StatusCode::INVALID_ARGUMENT: + return E_INVALIDARG; + case StatusCode::NO_SUCHFILE: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::NO_MODEL: + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::ENGINE_ERROR: + return E_FAIL; + case StatusCode::RUNTIME_EXCEPTION: + return E_FAIL; + case StatusCode::INVALID_PROTOBUF: + return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::MODEL_LOADED: + return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::NOT_IMPLEMENTED: + return E_NOTIMPL; + case StatusCode::INVALID_GRAPH: + return 
HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::EP_FAIL: + return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +class [[nodiscard]] Status { + public: + Status() noexcept = default; + + Status(StatusCategory category, int code, const std::string& msg); + + Status(StatusCategory category, int code, const char* msg); + + Status(StatusCategory category, int code); + + Status(const Status& other) + : state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {} + Status& operator=(const Status& other) { + if (state_ != other.state_) { + if (other.state_ == nullptr) { + state_.reset(); + } else { + state_.reset(new State(*other.state_)); + } + } + return *this; + } + + Status(Status&&) = default; + Status& operator=(Status&&) = default; + ~Status() = default; + + bool IsOK() const { + return (state_ == nullptr); + } + + int Code() const noexcept; + + StatusCategory Category() const noexcept; + + const std::string& ErrorMessage() const noexcept; + + std::string ToString() const; + + bool operator==(const Status& other) const { + return (this->state_ == other.state_) || (ToString() == other.ToString()); + } + + bool operator!=(const Status& other) const { + return !(*this == other); + } + + static Status OK() { + return Status(); + } + + private: + static const std::string& EmptyString() noexcept; + + struct State { + State(StatusCategory cat0, int code0, const std::string& msg0) + : category(cat0), code(code0), msg(msg0) {} + + State(StatusCategory cat0, int code0, const char* msg0) + : category(cat0), code(code0), msg(msg0) {} + + const StatusCategory category; + const int code; + const std::string msg; + }; + + // As long as Code() is OK, state_ == nullptr. + std::unique_ptr state_; +}; + +inline std::ostream& operator<<(std::ostream& out, const Status& status) { + return out << status.ToString(); +} +} // namespace common + +// make Status directly available in the onnxruntime namespace as it is widely used +using common::Status; + +} // namespace onnxruntime From ab75d985e5395e8afec1a7ca5ad32b29b3a8aa57 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Sat, 2 Nov 2024 00:22:57 -0700 Subject: [PATCH 62/81] OpenVino, compile() is done --- samples/openvino/CMakeLists.txt | 11 +- samples/openvino/backend_manager.cc | 456 +++++++------- samples/openvino/backend_manager.h | 25 +- samples/openvino/backend_utils.cc | 76 +-- samples/openvino/backend_utils.h | 9 +- samples/openvino/backends/backend_factory.cc | 35 ++ samples/openvino/backends/basic_backend.cc | 569 ++++++++++++++++++ samples/openvino/backends/basic_backend.h | 111 ++++ samples/openvino/ibackend.h | 41 +- .../openvino/openvino_execution_provider.cc | 28 +- .../openvino/openvino_execution_provider.h | 10 +- samples/openvino/openvino_utils.cc | 7 +- 12 files changed, 1056 insertions(+), 322 deletions(-) create mode 100644 samples/openvino/backends/backend_factory.cc create mode 100644 samples/openvino/backends/basic_backend.cc create mode 100644 samples/openvino/backends/basic_backend.h diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt index 1a1deba629b35..8724219705181 100644 --- a/samples/openvino/CMakeLists.txt +++ b/samples/openvino/CMakeLists.txt @@ -9,11 +9,20 @@ set(CMAKE_CXX_STANDARD 17) find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime) -file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc") +add_definitions(-DONNX_NAMESPACE=onnx) +add_definitions(-DONNX_ML) 
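+# ONNX_NAMESPACE / ONNX_ML mirror the definitions the onnxruntime build itself
+# uses, so the EP's ModelProto code resolves in the same onnx namespace as the
+# onnx/protobuf libraries linked below.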
+file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc" "./backends/*.cc") add_library(OpenVINOEp SHARED ${openvino_src}) target_include_directories(OpenVINOEp PUBLIC "../../include/onnxruntime" ${OPENVINO_HOME}/include + "../../build/Windows/Debug/_deps/onnx-src" + "../../build/Windows/Debug/_deps/onnx-build" + "../../build/Windows/Debug/_deps/protobuf-src/src" ) target_link_libraries(OpenVINOEp PUBLIC "C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug/onnxruntime.lib" ${OPENVINO_LIB_LIST} + "C:/Users/leca/source/onnxruntime/build/Windows/Debug/_deps/onnx-build/Debug/onnx.lib" + "C:/Users/leca/source/onnxruntime/build/Windows/Debug/_deps/onnx-build/Debug/onnx_proto.lib" + "C:/Users/leca/source/onnxruntime/build/Windows/Debug/_deps/protobuf-build/Debug/libprotobufd.lib" + "C:/Users/leca/source/onnxruntime/build/Windows/Debug/_deps/protobuf-build/Debug/libprotocd.lib" ) diff --git a/samples/openvino/backend_manager.cc b/samples/openvino/backend_manager.cc index 944cb33b2edfd..0f45b0e511362 100644 --- a/samples/openvino/backend_manager.cc +++ b/samples/openvino/backend_manager.cc @@ -70,70 +70,72 @@ BackendManager::BackendManager(const GlobalContext& global_context, const char* subgraph_name = nullptr; graph_api_->OrtNode_GetName(fused_node, &subgraph_name); subgraph_context_.subgraph_name = std::string(subgraph_name); -// model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger); -// std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; -// -// if (ModelHasSymbolicInputDims(subgraph)) { -// subgraph_context_.has_dynamic_input_shape = true; + model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, &model_proto_len_); + std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; + + if (ModelHasSymbolicInputDims(subgraph)) { + subgraph_context_.has_dynamic_input_shape = true; // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; -// ORT_ENFORCE(!global_context_.enable_qdq_optimizer, -// "QDQ stripping should not be enabled for models with dynamic input shapes. " -// "Set enable_qdq_optimizer to False"); -// if (GetGlobalContext().device_type.find("CPU") != std::string::npos || -// GetGlobalContext().device_type.find("GPU") != std::string::npos) { -// if (!GetGlobalContext().disable_dynamic_shapes) { + assert((!global_context_.enable_qdq_optimizer) && + "QDQ stripping should not be enabled for models with dynamic input shapes. Set enable_qdq_optimizer to False"); + if (GetGlobalContext().device_type.find("CPU") != std::string::npos || + GetGlobalContext().device_type.find("GPU") != std::string::npos) { + if (!GetGlobalContext().disable_dynamic_shapes) { // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " // << "Creating backend Dynamic Shapes"; -// try { -// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, -// GetGlobalContext(), -// subgraph_context_, -// ep_ctx_handle_); -// } catch (std::string const& msg) { -// ORT_THROW(msg); -// } + try { + concrete_backend_ = BackendFactory::MakeBackend(model_proto_, + model_proto_len_, + GetGlobalContext(), + subgraph_context_, + ep_ctx_handle_); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " // << "Backend created for graph " << subgraph_context_.subgraph_name; -// } -// } -// } else { + } + } + } else { // LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. 
" // << "Initializing backend for graph " // << subgraph_context_.subgraph_name; -// -// subgraph_context_.has_dynamic_input_shape = false; -// -// // OV NPU plugin is supported with fallback to OV CPU upon compilation failures. -// try { -// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, -// GetGlobalContext(), -// subgraph_context_, -// ep_ctx_handle_); -// } catch (const OnnxRuntimeException& ex) { -//#if defined(OPENVINO_DISABLE_NPU_FALLBACK) -// ORT_THROW(ex.what()); -//#else -// if (device_type.find("NPU") != std::string::npos && -// !GetGlobalContext().disable_cpu_fallback) { + + subgraph_context_.has_dynamic_input_shape = false; + + // OV NPU plugin is supported with fallback to OV CPU upon compilation failures. + try { + concrete_backend_ = BackendFactory::MakeBackend(model_proto_, + model_proto_len_, + GetGlobalContext(), + subgraph_context_, + ep_ctx_handle_); + } catch (const std::exception& ex) { +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + throw std::runtime_error(ex.what()); +#else + if (device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback) { // LOGS_DEFAULT(WARNING) << ex.what(); // LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." // << "Falling back to OV CPU for execution"; -// GetGlobalContext().device_type = "CPU"; -// GetGlobalContext().precision_str = "FP32"; -// try { -// concrete_backend_ = BackendFactory::MakeBackend(*model_proto_, -// GetGlobalContext(), -// subgraph_context_, -// ep_ctx_handle_); -// } catch (std::string const& msg) { -// ORT_THROW(msg); -// } -// } else { -// ORT_THROW(ex.what()); -// } -//#endif -// } -// } + GetGlobalContext().device_type = "CPU"; + GetGlobalContext().precision_str = "FP32"; + try { + concrete_backend_ = BackendFactory::MakeBackend(model_proto_, + model_proto_len_, + GetGlobalContext(), + subgraph_context_, + ep_ctx_handle_); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } + } else { + throw std::runtime_error(ex.what()); + } +#endif + } + } } // Call EPContext model exporter here if the provider option for exporting @@ -149,28 +151,28 @@ OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* g throw std::runtime_error(exception_str); } -// std::string model_blob_str; -// auto compiled_model = concrete_backend_->GetOVCompiledModel(); -// auto graph_name = global_context_.onnx_model_path_name; -// // Remove extension so we can append suffix to form the complete name of output graph -// graph_name = [&]() { -// size_t dot = graph_name.find_last_of("."); -// if (dot == std::string::npos) return graph_name; -// return graph_name.substr(0, dot); -// }(); -// // If embed_mode, then pass on the serialized blob -// // If not embed_mode, dump the blob here and only pass on the path to the blob -// if (global_context_.ep_context_embed_mode) { -// std::ostringstream model_blob_stream; -// compiled_model.export_model(model_blob_stream); -// model_blob_str = model_blob_stream.str(); -// ORT_ENFORCE(model_blob_str.size() != 0); -// } else { -// std::ofstream f(graph_name + ".blob", std::ios::out | std::ios::trunc | std::ios::binary); -// compiled_model.export_model(f); -// model_blob_str = graph_name + ".blob"; -// } -// + std::string model_blob_str; + auto compiled_model = concrete_backend_->GetOVCompiledModel(); + auto graph_name = global_context_.onnx_model_path_name; + // Remove extension so we can append suffix to form the complete name of output graph + graph_name = [&]() { + size_t dot = graph_name.find_last_of("."); 
+ if (dot == std::string::npos) return graph_name; + return graph_name.substr(0, dot); + }(); + // If embed_mode, then pass on the serialized blob + // If not embed_mode, dump the blob here and only pass on the path to the blob + if (global_context_.ep_context_embed_mode) { + std::ostringstream model_blob_stream; + compiled_model.export_model(model_blob_stream); + model_blob_str = model_blob_stream.str(); + assert(model_blob_str.size() != 0); + } else { + std::ofstream f(graph_name + ".blob", std::ios::out | std::ios::trunc | std::ios::binary); + compiled_model.export_model(f); + model_blob_str = graph_name + ".blob"; + } + // ORT_RETURN_IF_ERROR(ep_ctx_handle_.ExportEPCtxModel(graph_body_viewer, // graph_name, // logger, @@ -181,7 +183,7 @@ OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* g // return nullptr; } -// + //bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { // bool has_batched_inputs = true; // @@ -212,33 +214,27 @@ OrtStatus* BackendManager::ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* g // } // return has_batched_inputs; //} -// + bool BackendManager::ModelHasSymbolicInputDims(const OrtGraphViewer* subgraph) const { - bool has_sym_dims = false; const char** required_inputs = nullptr; size_t input_count = 0; graph_api_->OrtGraph_GetRequiredInputs(subgraph, &required_inputs, &input_count); - for (int i = 0; i < input_count; i++) { - OrtValueInfoRef* value_info = nullptr; - graph_api_->OrtGraph_GetValueInfo(subgraph, required_inputs[i], &value_info); - if (value_info->shape == nullptr) { - has_sym_dims = true; - graph_api_->OrtGraph_ReleaseValueInfo(value_info); - break; - } - for (size_t j = 0; j < value_info->shape_len; j++) { - // if (dim.value_case() != dim.kDimValue) { TODO:yang - // has_sym_dims = true; - // graph_api_->OrtGraph_ReleaseValueInfo(value_info); - // break; - // } - } - graph_api_->OrtGraph_ReleaseValueInfo(value_info); - if (has_sym_dims) { - break; - } - } - return has_sym_dims; + for (int i = 0; i < input_count; i++) { + OrtValueInfoRef* value_info = nullptr; + graph_api_->OrtGraph_GetValueInfo(subgraph, required_inputs[i], &value_info); + if (value_info->shape == nullptr) { + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return true; + } + for (size_t j = 0; j < value_info->shape_len; j++) { + if (value_info->shape[j] == -1) { // symbolic dimensions are represented as -1 in onnxruntime + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + return true; + } + } + graph_api_->OrtGraph_ReleaseValueInfo(value_info); + } + return false; } // Check to see if the graph is QDQ @@ -284,31 +280,29 @@ static bool IsQDQGraph(const OrtGraphApi* graph_api, const OrtGraphViewer* graph // model_proto->SerializeToOstream(dump); // } //} -// -//std::unique_ptr -//BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, -// const onnxruntime::GraphViewer& subgraph, -// const logging::Logger& logger) const { -// std::chrono::time_point model_proto_create_start_, model_proto_create_end_; -// if (openvino_ep::backend_utils::IsDebugEnabled()) { -// model_proto_create_start_ = std::chrono::high_resolution_clock::now(); -// } -// -// auto print_model_proto_duration = [&]() { -// if (openvino_ep::backend_utils::IsDebugEnabled()) { -// model_proto_create_end_ = std::chrono::high_resolution_clock::now(); -// auto model_proto_create_duration = -// std::chrono::duration_cast( -// model_proto_create_end_ - model_proto_create_start_) -// .count(); + +void* 
BackendManager::GetModelProtoFromFusedNode(const OrtNode* fused_node,
+                                           const OrtGraphViewer* subgraph, size_t* model_proto_len) const {
+  std::chrono::time_point<std::chrono::high_resolution_clock> model_proto_create_start_, model_proto_create_end_;
+  if (openvino_ep::backend_utils::IsDebugEnabled()) {
+    model_proto_create_start_ = std::chrono::high_resolution_clock::now();
+  }
+
+  auto print_model_proto_duration = [&]() {
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      model_proto_create_end_ = std::chrono::high_resolution_clock::now();
+      auto model_proto_create_duration =
+          std::chrono::duration_cast<std::chrono::milliseconds>(
+              model_proto_create_end_ - model_proto_create_start_)
+              .count();
// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model Proto creation took: " << model_proto_create_duration << " ms.";
+    }
+  };
+
+  // QDQ stripping enabled only for the NPU
+  if (global_context_.device_type.find("NPU") != std::string::npos &&
+      global_context_.enable_qdq_optimizer &&
+      IsQDQGraph(graph_api_, subgraph)) {
// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 1";
// std::unique_ptr<onnxruntime::Model> model;
// Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, model);
@@ -318,17 +312,15 @@ static bool IsQDQGraph(const OrtGraphApi* graph_api, const OrtGraphViewer* graph
// DumpOpenVINOEPModel(global_context_.onnx_model_path_name, model_proto.get(), fused_node);
// ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
// return model_proto;
-// } else {
+    return nullptr;
+  }
// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 0";
-// auto model = subgraph.CreateModel(logger);
-// auto model_proto = model->ToProto();
-// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
-// subgraph.ToProto(*model_proto->mutable_graph(), true, true);
-// print_model_proto_duration();
-// DumpOpenVINOEPModel(global_context_.onnx_model_path_name, model_proto.get(), fused_node);
-// return model_proto;
-// }
-//}
+  void* ret = nullptr;
+  graph_api_->OrtGraph_SerializeToArray(subgraph, &ret, model_proto_len);
+  print_model_proto_duration();
+// DumpOpenVINOEPModel(global_context_.onnx_model_path_name, model_proto.get(), fused_node);
+  return ret;
+}

 std::vector<std::vector<int64_t>> GetInputTensorShapes(const Ort::KernelContext& context) {
   const auto input_count = context.GetInputCount();
@@ -359,29 +351,28 @@ std::string MakeMapKeyString(const std::vector<std::vector<int64_t>>& shapes,
   return key;
 }

-//std::shared_ptr<ONNX_NAMESPACE::ModelProto>
-//BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_proto,
-//                                      const std::vector<std::vector<int64_t>>& input_shapes) {
-//  auto model_copy = std::shared_ptr<ONNX_NAMESPACE::ModelProto>(ONNX_NAMESPACE::ModelProto::Create());
-//  std::string proto_str;
-//  model_proto.SerializeToString(proto_str);
-//  model_copy->ParseFromString(proto_str);
-//  auto graph_proto = model_copy->mutable_graph();
-//
-//  for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) {
-//    auto g_in_shape = graph_proto->mutable_input(static_cast<int>(i))
-//                          ->mutable_type()
-//                          ->mutable_tensor_type()
-//                          ->mutable_shape();
-//    g_in_shape->clear_dim();
-//    const auto& shape = input_shapes[i];
-//    for (size_t dim = 0, end = shape.size(); dim < end; dim++) {
-//      g_in_shape->add_dim()->set_dim_value(shape[dim]);
-//    }
-//  }
-//  return model_copy;
-//}
-//
+std::unique_ptr<ONNX_NAMESPACE::ModelProto>
+BackendManager::ReWriteInputShapeInfo(void* model_proto, size_t model_proto_len,
+                                      const std::vector<std::vector<int64_t>>& input_shapes) {
+  auto model_copy = std::make_unique<ONNX_NAMESPACE::ModelProto>();
+  std::string proto_str(static_cast<const char*>(model_proto), model_proto_len);
+  model_copy->ParseFromString(proto_str);
+  auto graph_proto = model_copy->mutable_graph();
+
+  for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) {
+    auto g_in_shape = graph_proto->mutable_input(static_cast<int>(i))
+                          ->mutable_type()
+                          ->mutable_tensor_type()
+                          ->mutable_shape();
+    g_in_shape->clear_dim();
+    const auto& shape = input_shapes[i];
+    for (size_t dim = 0, end = shape.size(); dim < end; dim++) {
+      g_in_shape->add_dim()->set_dim_value(shape[dim]);
+    }
+  }
+  return model_copy;
+}
+
 //std::shared_ptr<ONNX_NAMESPACE::ModelProto>
 //BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto) {
 //  auto model_copy = std::shared_ptr<ONNX_NAMESPACE::ModelProto>(ONNX_NAMESPACE::ModelProto::Create());
@@ -401,90 +392,91 @@ std::string MakeMapKeyString(const std::vector<std::vector<int64_t>>& shapes,
 //  }
 //  return model_copy;
 //}
-//
-//void BackendManager::Compute(OrtKernelContext* context) {
-//  Ort::KernelContext ctx(context);
-//  std::chrono::high_resolution_clock::time_point start_compute, end_compute;
-//#ifdef OPENVINO_FIL_ENABLED
-//  static bool fil_enabled = true;
-//  if (fil_enabled) {
-//    start_compute = std::chrono::high_resolution_clock::now();
+
+void BackendManager::Compute(OrtKernelContext* context) {
+  Ort::KernelContext ctx(context);
+  std::chrono::high_resolution_clock::time_point start_compute, end_compute;
+#ifdef OPENVINO_FIL_ENABLED
+  static bool fil_enabled = true;
+  if (fil_enabled) {
+    start_compute = std::chrono::high_resolution_clock::now();
// LOGS_DEFAULT(INFO) << "Start Compute";
+  }
+#endif
+  // OV NPU doesn't support dynamic shaped model inference.
+  // if disable_dynamic_shapes is set to true then execution of dynamic model is done
+  // by rewriting the model to static shaped model at runtime based on input shape.
+  // disable_dynamic_shapes is always set to true for OV NPU plugin.
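+  // Dispatch summary (annotation, not in the original patch):
+  //   dynamic input shape + CPU/GPU + dynamic shapes enabled -> run concrete_backend_ as-is;
+  //   dynamic input shape otherwise (e.g. NPU)               -> rewrite to the concrete shapes
+  //                                                             and cache one backend per shape key;
+  //   static input shape                                     -> run the backend built at compile time.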
+  bool use_dynamic_backend = true;
+  if (subgraph_context_.has_dynamic_input_shape &&
+      !GetGlobalContext().disable_dynamic_shapes &&
+      (GetGlobalContext().device_type.find("CPU") != std::string::npos ||
+       GetGlobalContext().device_type.find("GPU") != std::string::npos)) {
+    concrete_backend_->Infer(context);
+    use_dynamic_backend = false;
+  } else if (use_dynamic_backend && subgraph_context_.has_dynamic_input_shape) {
+    std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
+    auto key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type);
+    std::shared_ptr<IBackend> dynamic_backend;
+    auto search = backend_map_.find(key);
+    if (search == backend_map_.end()) {
// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
//                    << "Creating dynamic backend for key: " << key;
// LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
//                    << "Backend created for graph " << subgraph_context_.subgraph_name;
+      auto modelproto_with_concrete_shapes = ReWriteInputShapeInfo(model_proto_, model_proto_len_, tensor_shapes);
+      const std::string model_with_concrete_shapes = modelproto_with_concrete_shapes->SerializeAsString();
+      try {
+        dynamic_backend = BackendFactory::MakeBackend(const_cast<char*>(model_with_concrete_shapes.c_str()), model_with_concrete_shapes.length(),
+                                                      GetGlobalContext(),
+                                                      subgraph_context_,
+                                                      ep_ctx_handle_);
+      } catch (const std::exception& ex) {
+        // Build option disables fallback to CPU on compilation failures with NPU.
+#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
+        LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU.";
+        throw std::runtime_error(ex.what());
+#else
+        if (GetGlobalContext().device_type.find("NPU") != std::string::npos &&
+            !GetGlobalContext().disable_cpu_fallback) {
// LOGS_DEFAULT(WARNING) << ex.what();
// LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
// << "Falling back to OV CPU for execution"; -// GetGlobalContext().device_type = "CPU"; -// GetGlobalContext().precision_str = "FP32"; -// key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type); -// try { -// dynamic_backend = BackendFactory::MakeBackend(*modelproto_with_concrete_shapes, -// GetGlobalContext(), -// subgraph_context_, -// ep_ctx_handle_); -// } catch (std::string const& msg) { -// ORT_THROW(msg); -// } -// } else { -// ORT_THROW(ex.what()); -// } -//#endif -// } -// backend_map_.insert({key, dynamic_backend}); -// } else { -// dynamic_backend = search->second; -// } -// -// dynamic_backend->Infer(context); -// } else { -// concrete_backend_->Infer(context); -// } -//#ifdef OPENVINO_FIL_ENABLED -// if (fil_enabled) { -// end_compute = std::chrono::high_resolution_clock::now(); -// LOGS_DEFAULT(INFO) << "End Compute"; -// std::chrono::duration compute_time = end_compute - start_compute; -// std::cout << "Compute Time: " << compute_time.count() << " s" << std::endl; -// fil_enabled = false; // calculating compute time for first run only -// } -//#endif -//} + GetGlobalContext().device_type = "CPU"; + GetGlobalContext().precision_str = "FP32"; + key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type); + try { + dynamic_backend = BackendFactory::MakeBackend(const_cast(model_with_concrete_shapes.c_str()), model_with_concrete_shapes.length(), + GetGlobalContext(), + subgraph_context_, + ep_ctx_handle_); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } + } else { + throw std::runtime_error(ex.what()); + } +#endif + } + backend_map_.insert({key, dynamic_backend}); + } else { + dynamic_backend = search->second; + } + + dynamic_backend->Infer(context); + } else { + concrete_backend_->Infer(context); + } +#ifdef OPENVINO_FIL_ENABLED + if (fil_enabled) { + end_compute = std::chrono::high_resolution_clock::now(); + LOGS_DEFAULT(INFO) << "End Compute"; + std::chrono::duration compute_time = end_compute - start_compute; + std::cout << "Compute Time: " << compute_time.count() << " s" << std::endl; + fil_enabled = false; // calculating compute time for first run only + } +#endif +} void BackendManager::ShutdownBackendManager() { } diff --git a/samples/openvino/backend_manager.h b/samples/openvino/backend_manager.h index 58c8db4d312c2..1b754fbcab402 100644 --- a/samples/openvino/backend_manager.h +++ b/samples/openvino/backend_manager.h @@ -12,6 +12,7 @@ #include "contexts.h" #include "onnx_ctx_model_helper.h" #include "ibackend.h" +#include "onnx/onnx_pb.h" namespace onnxruntime { namespace openvino_ep { @@ -30,24 +31,24 @@ class BackendManager { OrtStatus* ExportCompiledBlobAsEPCtxNode(const OrtGraphViewer* subgraph); private: -// std::unique_ptr GetModelProtoFromFusedNode( -// const onnxruntime::Node& fused_node, -// const onnxruntime::GraphViewer& subgraph, -// const logging::Logger& logger) const; -// + void* GetModelProtoFromFusedNode( + const OrtNode* fused_node, + const OrtGraphViewer* subgraph, size_t* model_proto_len) const; + bool ModelHasSymbolicInputDims(const OrtGraphViewer* subgraph) const; // bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; // // std::shared_ptr // ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto); // -// std::shared_ptr -// ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_proto, -// const std::vector>& input_shapes); -// -// std::unique_ptr model_proto_; -// std::shared_ptr concrete_backend_; -// std::map> backend_map_; + std::unique_ptr + 
ReWriteInputShapeInfo(void* model_proto, size_t model_proto_len, + const std::vector>& input_shapes); + + void* model_proto_; // TODO(leca): release + size_t model_proto_len_; + std::shared_ptr concrete_backend_; + std::map> backend_map_; SubGraphContext subgraph_context_; GlobalContext global_context_; EPCtxHandler ep_ctx_handle_{}; diff --git a/samples/openvino/backend_utils.cc b/samples/openvino/backend_utils.cc index 62386e9fe4b7a..68f1c49d6d789 100644 --- a/samples/openvino/backend_utils.cc +++ b/samples/openvino/backend_utils.cc @@ -39,44 +39,44 @@ struct static_cast_int64 { int64_t operator()(const T1& x) const { return static_cast(x); } }; -//std::shared_ptr -//CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, -// std::map>& const_outputs_map) { -// if (IsCILogEnabled()) { -// std::cout << "CreateNgraphFunc" << std::endl; -// } -// const std::string model = model_proto.SerializeAsString(); -// try { -// auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); -// -// // Check for Constant Folding -// if (!global_context.is_wholly_supported_graph) { -// ov::pass::ConstantFolding pass_const_obj; -// pass_const_obj.run_on_model(cnn_network); -// auto& results = const_cast(cnn_network.get()->get_results()); -// size_t index = results.size() - 1; -// -// for (auto it = results.rbegin(); it != results.rend(); ++it) { -// if (auto const_node = -// std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { -// const_outputs_map[(*it)->get_friendly_name()] = const_node; -// results.erase(results.begin() + index); -// } -// --index; -// } -// } -//#ifndef NDEBUG -// if (IsDebugEnabled()) { -// std::string name = cnn_network->get_friendly_name(); -// ov::pass::Serialize serializer(name + ".xml", name + ".bin"); -// serializer.run_on_model(cnn_network); -// } -//#endif -// return cnn_network; -// } catch (std::string const& msg) { -// throw std::runtime_error(msg); -// } -//} +std::shared_ptr +CreateOVModel(void* model_proto, size_t model_proto_len, const GlobalContext& global_context, + std::map>& const_outputs_map) { + if (IsCILogEnabled()) { + std::cout << "CreateNgraphFunc" << std::endl; + } + const std::string model(static_cast(model_proto), model_proto_len); + try { + auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); + + // Check for Constant Folding + if (!global_context.is_wholly_supported_graph) { + ov::pass::ConstantFolding pass_const_obj; + pass_const_obj.run_on_model(cnn_network); + auto& results = const_cast(cnn_network.get()->get_results()); + size_t index = results.size() - 1; + + for (auto it = results.rbegin(); it != results.rend(); ++it) { + if (auto const_node = + std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { + const_outputs_map[(*it)->get_friendly_name()] = const_node; + results.erase(results.begin() + index); + } + --index; + } + } +#ifndef NDEBUG + if (IsDebugEnabled()) { + std::string name = cnn_network->get_friendly_name(); + ov::pass::Serialize serializer(name + ".xml", name + ".bin"); + serializer.run_on_model(cnn_network); + } +#endif + return cnn_network; + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } +} Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, size_t batch_size, diff --git a/samples/openvino/backend_utils.h b/samples/openvino/backend_utils.h index c700f86f9c0f7..3d126b2334ed1 100644 --- a/samples/openvino/backend_utils.h +++ 
b/samples/openvino/backend_utils.h
@@ -60,10 +60,11 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
 void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
                     size_t batch_slice_idx);
-//std::shared_ptr<ov::Model>
-//CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto,
-//              const GlobalContext& global_context,
-//              std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
+std::shared_ptr<ov::Model>
+CreateOVModel(void* model_proto,
+              size_t model_proto_len,
+              const GlobalContext& global_context,
+              std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
 void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
                             std::ostream& stream, std::string deviceName);
diff --git a/samples/openvino/backends/backend_factory.cc b/samples/openvino/backends/backend_factory.cc
new file mode 100644
index 0000000000000..b4fd4988b880d
--- /dev/null
+++ b/samples/openvino/backends/backend_factory.cc
@@ -0,0 +1,35 @@
+// Copyright (C) Intel Corporation
+// Licensed under the MIT License
+
+#include <memory>
+#include "../contexts.h"
+#include "../ibackend.h"
+#include "basic_backend.h"
+
+namespace onnxruntime {
+namespace openvino_ep {
+
+std::shared_ptr<IBackend>
+BackendFactory::MakeBackend(void* model_proto,
+                            size_t model_proto_len,
+                            GlobalContext& global_context,
+                            const SubGraphContext& subgraph_context,
+                            EPCtxHandler& ep_ctx_handle) {
+  std::string type = global_context.device_type;
+  if (type == "CPU" || type.find("GPU") != std::string::npos ||
+      type.find("NPU") != std::string::npos ||
+      type.find("HETERO") != std::string::npos ||
+      type.find("MULTI") != std::string::npos ||
+      type.find("AUTO") != std::string::npos) {
+    std::shared_ptr<BasicBackend> concrete_backend_;
+    try {
+      concrete_backend_ = std::make_shared<BasicBackend>(model_proto, model_proto_len, global_context, subgraph_context, ep_ctx_handle);
+    } catch (std::string const& msg) {
+      throw std::runtime_error(msg);
+    }
+    return concrete_backend_;
+  }
+  throw std::runtime_error("[OpenVINO-EP] Backend factory error: Unknown backend type: " + type);
+}
+}  // namespace openvino_ep
+}  // namespace onnxruntime
diff --git a/samples/openvino/backends/basic_backend.cc b/samples/openvino/backends/basic_backend.cc
new file mode 100644
index 0000000000000..da7f5889a3d10
--- /dev/null
+++ b/samples/openvino/backends/basic_backend.cc
@@ -0,0 +1,569 @@
+// Copyright (C) Intel Corporation
+// Licensed under the MIT License
+
+#include <map>
+#include <string>
+#include <memory>
+#include <sstream>
+#include <fstream>
+#include <utility>
+
+#include "../backend_utils.h"
+#include "basic_backend.h"
+#include "../onnx_ctx_model_helper.h"
+#include "../backend_manager.h"
+#include "../openvino_utils.h"
+
+namespace onnxruntime {
+
+namespace openvino_ep {
+
+using namespace backend_utils;
+
+BasicBackend::BasicBackend(void* model_proto,
+                           size_t model_proto_len,
+                           GlobalContext& global_context,
+                           const SubGraphContext& subgraph_context,
+                           EPCtxHandler& ep_ctx_handle)
+    : global_context_(global_context), subgraph_context_(subgraph_context) {
+  std::string& hw_target = global_context_.device_type;
+
+  is_ep_ctx_graph_ = ep_ctx_handle.IsValidOVEPCtxGraph();
+
+  if (ValidateSubgraph(const_outputs_map_))
+    return;
+
+  // OV Config
+  ov::AnyMap device_config;
+  PopulateConfigValue(device_config);
+
+  // Enable caching
+  EnableCaching();
+
+  // Setting OpenCL queue throttling for GPU
+  EnableGPUThrottling(device_config);
+
+  // Enable streams; default=1 unless overridden by user config
+  EnableStreams();
+
+  // Set the inference_num_threads property of the CPU
+  SetNumThreads(device_config);
+
+#ifndef NDEBUG
+  if (IsDebugEnabled()) {
+    std::string file_name = subgraph_context.subgraph_name + "_static.onnx";
+    std::fstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary);
+    outfile.write(static_cast<const char*>(model_proto), model_proto_len);  // dump the serialized proto bytes, not the pointer value
+  }
+#endif
+
+  try {
+    std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str;
+
+    if (global_context.is_wholly_supported_graph) {  // Full graph is supported
+#if defined(IO_BUFFER_ENABLED)
+      if (is_ep_ctx_graph_) {
+        std::istringstream model_stream(ep_ctx_handle.GetModelBlobString());
+        exe_network_ = global_context_.ie_core.ImportModel(model_stream,
+                                                           remote_context_,
+                                                           subgraph_context_.subgraph_name);
+        ie_cnn_network_ = exe_network_.Get().get_runtime_model();
+      } else if ((global_context.device_type.find("GPU") != std::string::npos) &&
+                 (global_context_.context != nullptr)) {
+// LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled";
+        cl_context ctx = static_cast<cl_context>(global_context_.context);
+        remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx);
+        ie_cnn_network_ = CreateOVModel(model_proto, model_proto_len, global_context_, const_outputs_map_);
+        exe_network_ = global_context_.ie_core.CompileModel(
+            ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name);
+        ie_cnn_network_ = exe_network_.Get().get_runtime_model();
+      } else {
+        ie_cnn_network_ = CreateOVModel(model_proto, model_proto_len, global_context_, const_outputs_map_);
+        exe_network_ = global_context_.ie_core.CompileModel(
+            ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
+      }
+#else  // !IO_BUFFER_ENABLED
+      if (is_ep_ctx_graph_) {
+        // If the blob is held in an EPContext node, then skip FE+Compile
+        // and directly move on to creating a backend with the executable blob
+        exe_network_ = global_context_.ie_core.ImportModel(ep_ctx_handle.GetModelBlobStream(),
+                                                           hw_target,
+                                                           device_config,
+                                                           subgraph_context_.subgraph_name);
+        ie_cnn_network_ = exe_network_.Get().get_runtime_model();
+      } else if (!subgraph_context_.has_dynamic_input_shape) {
+        // Inputs with static dimensions
+        std::string prec_str = (global_context_.precision_str != "ACCURACY") ? global_context_.precision_str : global_context_.model_precision;
+        const std::string model(static_cast<const char*>(model_proto), model_proto_len);
+        exe_network_ = global_context_.ie_core.CompileModel(model,
+                                                            hw_target,
+                                                            prec_str,
+                                                            global_context_.cache_dir,
+                                                            device_config,
+                                                            subgraph_context_.subgraph_name);
+        ie_cnn_network_ = exe_network_.Get().get_runtime_model();
+      } else {  // Inputs with dynamic dimensions
+        ie_cnn_network_ = CreateOVModel(model_proto, model_proto_len, global_context_, const_outputs_map_);
+        exe_network_ = global_context_.ie_core.CompileModel(
+            ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
+      }
+#endif
+    } else {  // Full graph is not supported
+      ie_cnn_network_ = CreateOVModel(model_proto, model_proto_len, global_context_, const_outputs_map_);
+      exe_network_ = global_context_.ie_core.CompileModel(
+          ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
+    }
// LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
+  } catch (const char* msg) {
+    throw std::runtime_error(msg);
+  }
+
+  inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, 1));
+}
+
+bool BasicBackend::ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
+  if (const_outputs_map.size() == subgraph_context_.output_names.size())
+    subgraph_context_.is_constant = true;
+  if (subgraph_context_.is_constant) {
// LOGS_DEFAULT(INFO) << log_tag << "The subgraph is a const. Directly moving to Infer stage.";
+    return true;
+  }
+  return false;
+}
+
+void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
+  device_config = {};
+  // Set inference precision based on device precision for OV backend
+  if (global_context_.precision_str.find("FP16") != std::string::npos &&
+      global_context_.device_type == "GPU") {
+    device_config.emplace(ov::hint::inference_precision("f16"));
+  }
+  if (global_context_.precision_str.find("FP32") != std::string::npos) {
+    device_config.emplace(ov::hint::inference_precision("f32"));
+  }
+  if (global_context_.precision_str.find("ACCURACY") != std::string::npos &&
+      global_context_.device_type == "GPU") {
+    if (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) >= 1) {
+      device_config.emplace(ov::hint::inference_precision(ov::element::undefined));
+      device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));
+    } else {
+      if (global_context_.model_precision != "")
+        device_config.emplace(ov::hint::inference_precision(global_context_.model_precision));
+    }
+  }
+#ifndef NDEBUG
+  if (openvino_ep::backend_utils::IsDebugEnabled()) {
+    device_config.emplace(ov::enable_profiling(true));
+  }
+#endif
+
+  // Set a priority level for the current workload for preemption; default priority is "DEFAULT"
+  // CPU Plugin doesn't support workload priority
+  if (global_context_.device_type.find("CPU") == std::string::npos)
+    device_config.emplace(ov::hint::model_priority(global_context_.model_priority));
+
+  if (global_context_.device_type.find("NPU") != std::string::npos) {
+    std::pair<std::string, ov::Any> device_property;
+    device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER");
+
+    const std::string env_npu_compiler_type = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_NPU_COMPILER_TYPE");
+    if (!env_npu_compiler_type.empty()) {
+      device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type);
+    }
+    device_config.emplace(ov::device::properties("NPU", device_property));
+  }
+}
+
+void BasicBackend::EnableCaching() {
+  // cache_dir argument has no effect when working 
with an embed-mode EPContext Graph + if (is_ep_ctx_graph_) return; + + if (!global_context_.cache_dir.empty()) { +// LOGS_DEFAULT(INFO) << log_tag << "Enables Caching"; + global_context_.ie_core.SetCache(global_context_.cache_dir, global_context_.device_type); + } +} + +void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { + if (global_context_.enable_opencl_throttling == true && + global_context_.device_type.find("GPU") != std::string::npos) { +// LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device"; + std::pair device_property; + device_property = std::make_pair("PLUGIN_THROTTLE", "1"); + device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property)); + } +} + +void BasicBackend::EnableStreams() { + // Return silently for NPU as it's currently treated as a read-only flag by the NPU plugin + // and throws an exception for the same + if (global_context_.device_type.find("NPU") != std::string::npos) + return; + + // Streams can be set only if the device is not one of AUTO, MULTI, or HETERO + // Throw an exception if the user tries to set num_streams for these devices + if ((global_context_.device_type.find("MULTI") != std::string::npos) || + (global_context_.device_type.find("HETERO") != std::string::npos) || + (global_context_.device_type.find("AUTO") != std::string::npos)) { + if (global_context_.num_streams != 1) { + throw std::runtime_error(log_tag + "Cannot set NUM_STREAMS to " + + std::to_string(global_context_.num_streams) + " for device " + global_context_.device_type); + } + // Do nothing + } else { + global_context_.ie_core.SetStreams(global_context_.device_type, global_context_.num_streams); + } +} + +void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { + // inference_num_threads is applicable only for the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) + device_config.emplace(ov::inference_num_threads(global_context_.num_of_threads)); +} + +// Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on +// an Infer Request indexed by infer_req_idx +void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { + try { + auto graph_input_info = exe_network_.Get().inputs(); + int input_idx = 0; + for (auto input_info_iter = graph_input_info.begin(); + input_info_iter != graph_input_info.end(); ++input_info_iter) { + auto input_names = input_info_iter->get_names(); + std::string onnx_input_name; + std::string input_name; + // use names retrieved from original ONNX model to assign the right onnx input name for the graph + for (auto it = subgraph_context_.input_names.begin(); it != subgraph_context_.input_names.end(); ++it) { + if (it->second == input_idx) { + onnx_input_name = it->first; + break; + } + } + // using the input name retrieved from ONNX original to match with the input names returned by OV tensors + if (input_names.find(onnx_input_name) != input_names.end()) { + input_name = std::move(onnx_input_name); + } else { + throw std::runtime_error(log_tag + + "Input names mismatch between OpenVINO and ONNX. 
" + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); + } + size_t batch_slice_idx = 0; + if (subgraph_context_.has_dynamic_input_shape && + !global_context_.disable_dynamic_shapes && + (global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); + auto tensor_shape = tensor_info.GetShape(); + auto tensor_size = tensor_shape.size(); + const char* tensor_data = tensor.GetTensorData(); + auto tensor_iter = 0; + ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); + for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { + input_tensor_shape[tensor_iter] = *i; + tensor_iter += 1; + } + auto input = ie_cnn_network_->get_parameters().at(input_idx); + OVTensorPtr tensor_ptr; + // avoid input copies on the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape, + (void*)tensor_data); + } else { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + } + + try { + infer_request->SetTensor(input_name, tensor_ptr); + } catch (const char* msg) { + throw std::runtime_error(msg); + } + } else { + OVTensorPtr graph_input_blob; + try { + graph_input_blob = infer_request->GetTensor(input_name); + } catch (const char* msg) { + throw std::runtime_error(msg); + } + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); + } + input_idx++; + } + // Start Async inference + infer_request->StartAsync(); + } catch (const char* msg) { + throw std::runtime_error(msg); + } +} + +#ifdef IO_BUFFER_ENABLED +// Wait for Remote Aynchronous inference completion +void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { + try { + auto graph_input_info = exe_network_.Get().inputs(); + int input_idx = 0; + for (auto input_info_iter = graph_input_info.begin(); + input_info_iter != graph_input_info.end(); ++input_info_iter) { + auto input_names = input_info_iter->get_names(); + std::string onnx_input_name; + std::string input_name; + // use names retrieved from original ONNX model to assign the right onnx input name for the graph + for (auto it = subgraph_context_.input_names.begin(); it != subgraph_context_.input_names.end(); ++it) { + if (it->second == input_idx) { + onnx_input_name = it->first; + break; + } + } + // using the input name retrieved from ONNX original to match with the input names returned by OV tensors + if (input_names.find(onnx_input_name) != input_names.end()) { + input_name = onnx_input_name; + } else { + throw std::runtime_error(log_tag + + "Input names mismatch between OpenVINO and ONNX. 
" + + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); + } + input_idx++; + // Kernel Context Input Buffer + const auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + // If the ORTValue wraps a device pointer + auto mem_info = tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + // Get the shared buffer pointer + const void* tensor_data = tensor.GetTensorRawData(); + const cl::Buffer* shared_buffer_const = static_cast(tensor_data); + // Create an Input Remote Blob + auto input = ie_cnn_network_->get_parameters().at(0); + auto remote_blob = remote_context_->create_tensor( + input->get_element_type(), input->get_shape(), *shared_buffer_const); + ov::Tensor tensor_remote = static_cast(remote_blob); + OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); + infer_request->SetTensor(input_name, tensor_ptr); + } else { + OVTensorPtr graph_input_blob; + graph_input_blob = infer_request->GetTensor(input_name); + size_t batch_slice_idx = 0; + FillInputBlob(graph_input_blob, batch_slice_idx, input_name, context, subgraph_context_); + } + } + + // Set the output blob as remote blob + auto graph_output_info = exe_network_.Get().outputs(); + for (auto output_info_iter = graph_output_info.begin(); + output_info_iter != graph_output_info.end(); ++output_info_iter) { + auto output_names = output_info_iter->get_names(); + std::string onnx_output_name; + std::string output_name; + bool output_name_found = false; + // using the output name retrieved from ONNX original to match with the output names returned by OV tensors + for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { + onnx_output_name = it->first; + if (output_names.find(onnx_output_name) != output_names.end()) { + // Assigning the output_name + output_name = it->first; + output_name_found = true; + break; + } + } + if (!output_name_found) { + throw std::runtime_error( + log_tag + + "Output names mismatch between OpenVINO and ONNX. 
[ONNX Output: ] " + + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); + } + + size_t batch_size = 1; + Ort::UnownedValue tensor = GetOutputTensor(context, + batch_size, + infer_request, + output_name, + subgraph_context_.output_names); + auto mem_info = tensor.GetTensorMemoryInfo(); + // Check if ORT Value wraps a device pointer + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + const void* tensor_data = tensor.GetTensorRawData(); + const cl::Buffer* shared_buffer_const = static_cast(tensor_data); + // Create a shared Blob, set the Infer Request Output Blob + auto output = ie_cnn_network_->get_results().at(0); + auto remote_tensor = + remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); + ov::Tensor tensor_t = static_cast(remote_tensor); + OVTensorPtr tensor_ptr = std::make_shared(tensor_t); + try { + infer_request->SetTensor(output_name, tensor_ptr); + } catch (const char* msg) { + throw std::runtime_error(msg); + } + } + } + + // Start Async inference + infer_request->StartAsync(); + } catch (const char* msg) { + throw std::runtime_error(msg); + } +} +#endif + +// Wait for asynchronous inference completion on an Infer Request object indexed by infer_req_idx +// and copy the results into a slice location within the batched output buffer indexed by batch_slice_idx +void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { + // Wait for Async inference completion + try { + infer_request->WaitRequest(); + auto graph_output_info = exe_network_.Get().outputs(); + for (auto output_info_iter = graph_output_info.begin(); + output_info_iter != graph_output_info.end(); ++output_info_iter) { + OVTensorPtr graph_output_blob; + auto output_names = output_info_iter->get_names(); + std::string onnx_output_name; + std::string output_name; + bool output_name_found = false; + // using the output name retrieved from ONNX original to match with the output names returned by OV tensors + for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { + onnx_output_name = it->first; + if (output_names.find(onnx_output_name) != output_names.end()) { + // Assigning the output_name + output_name = it->first; + output_name_found = true; + break; + } + } + if (!output_name_found) { + throw std::runtime_error( + log_tag + + "Output names mismatch between OpenVINO and ONNX. 
" + "[ONNX Output: ] " + + onnx_output_name + + " doesn't exist in the " + "list of OpenVINO output tensor names"); + } + try { + graph_output_blob = infer_request->GetTensor(output_name); + } catch (const char* msg) { + throw std::runtime_error(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + return; + } else { + size_t batch_slice = 0; + FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + } + } + + if (!const_outputs_map_.empty()) { + for (const auto& item : const_outputs_map_) { + const auto& out_name = item.first; + auto node = item.second; + Ort::UnownedValue output_tensor = GetOutputTensor(context, + std::move(out_name), + subgraph_context_.output_names, + node); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + throw std::runtime_error(log_tag + "IO Buffering is not supported for constant subgraphs"); + } else { + FillOutputsWithConstantData(std::move(node), output_tensor); + } + } + } + } catch (const char* msg) { + throw std::runtime_error(msg); + } +} + +void BasicBackend::Infer(OrtKernelContext* ctx) { + // Preliminary Thread safety mechanism + // currently allows a maximum of 8 Infer request's to parallel execute at the same time + Ort::KernelContext context(ctx); + +// LOGS_DEFAULT(INFO) << log_tag << "Running graph " << subgraph_context_.subgraph_name; +// LOGS_DEFAULT(INFO) << log_tag << "In Infer"; + + if (subgraph_context_.is_constant) { + for (const auto& item : const_outputs_map_) { + std::string out_name = item.first; + std::shared_ptr node = item.second; + try { + Ort::UnownedValue output_tensor = GetOutputTensor(context, + std::move(out_name), + subgraph_context_.output_names, + node); + FillOutputsWithConstantData(std::move(node), output_tensor); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } + } + // Get Output tensors +// LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; + // Enable CI Logs + if (IsCILogEnabled()) { + std::cout << "Inference successful" << std::endl; + } + + } else { + // Requesting for an idle infer_request from a pool of infer_requests_ + OVInferRequestPtr infer_request; + infer_request = inferRequestsQueue_->getIdleRequest(); + +#ifdef IO_BUFFER_ENABLED + if ((global_context_.device_type.find("GPU") != std::string::npos) && + (global_context_.context != nullptr) && global_context_.is_wholly_supported_graph) { + try { + StartRemoteAsyncInference(context, infer_request); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } + } else { + try { + StartAsyncInference(context, infer_request); + } catch (std::string const& msg) { + throw std::runtime_error(msg); + } + } +#else + try { + StartAsyncInference(context, infer_request); + } catch (const std::runtime_error& e) { + throw std::runtime_error(log_tag + " Exception at StartAsyncInference: " + e.what()); + } +#endif + try { + CompleteAsyncInference(context, infer_request); + } catch (const std::runtime_error& e) { + throw std::runtime_error(log_tag + " Exception at CompleteAsyncInference: " + e.what()); + } + + // Get Output tensors +// LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; + // Enable CI Logs + if (IsCILogEnabled()) { + std::cout << "Inference successful" << std::endl; + } + + // Create a duplicate 
infer_request_ shared ptr on the stack in the current local scope,
+    // as the infer_request gets freed in the next stage the reference count for the infer_request decrements &
+    // thus we don't have any dangling ptr leading to seg faults in the debug mode subsequent execution call
+    OVInferRequestPtr infer_request_ = infer_request;
+
+    // Once the inference is completed, the infer_request becomes free and is placed back into pool of infer_requests_
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+#ifndef NDEBUG
+#ifndef IO_BUFFER_ENABLED  // Printing performance counts is disabled when IO_BUFFER_ENABLED
+    if (openvino_ep::backend_utils::IsDebugEnabled()) {
+      inferRequestsQueue_->printstatus();  // Printing the elements of infer_requests_ vector pool only in debug mode
+      std::string& hw_target = global_context_.device_type;
+      printPerformanceCounts(std::move(infer_request_), std::cout, hw_target);
+    }
+#endif
+#endif
+  }
+}
+
+}  // namespace openvino_ep
+}  // namespace onnxruntime
diff --git a/samples/openvino/backends/basic_backend.h b/samples/openvino/backends/basic_backend.h
new file mode 100644
index 0000000000000..22731de26a0db
--- /dev/null
+++ b/samples/openvino/backends/basic_backend.h
@@ -0,0 +1,111 @@
+// Copyright (C) Intel Corporation
+// Licensed under the MIT License
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <map>
+#include <vector>
+#include <mutex>
+#include <condition_variable>
+
+#include "core/session/onnxruntime_cxx_api.h"
+#include "../contexts.h"
+#include "../ibackend.h"
+#include "../ov_interface.h"
+
+namespace onnxruntime {
+namespace openvino_ep {
+
+class InferRequestsQueue;
+class BasicBackend : public IBackend {
+ public:
+  BasicBackend(void* model_proto,
+               size_t model_proto_len,
+               GlobalContext& global_context,
+               const SubGraphContext& subgraph_context,
+               EPCtxHandler& ep_ctx_handle);
+
+  void Infer(OrtKernelContext* context) override;
+  ov::CompiledModel& GetOVCompiledModel() override {
+    return exe_network_.Get();
+  }
+
+ private:
+  void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);
+  bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
+  void PopulateConfigValue(ov::AnyMap& device_config);
+  void EnableCaching();
+  void EnableGPUThrottling(ov::AnyMap& device_config);
+  void EnableStreams();
+  void SetNumThreads(ov::AnyMap& device_config);
+  void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+
+#ifdef IO_BUFFER_ENABLED
+  void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+#endif
+
+  void CompleteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+
+  GlobalContext& global_context_;
+  SubGraphContext subgraph_context_;
+  mutable std::mutex compute_lock_;
+  std::shared_ptr<const ov::Model> ie_cnn_network_;
+  OVExeNetwork exe_network_;
+  std::map<std::string, std::shared_ptr<ov::Node>> const_outputs_map_;
+  std::unique_ptr<InferRequestsQueue> inferRequestsQueue_;
+  bool is_ep_ctx_graph_{false};
+#if defined IO_BUFFER_ENABLED
+  OVRemoteContextPtr remote_context_;
+#endif
+};
+
+class InferRequestsQueue {
+ public:
+  InferRequestsQueue(OVExeNetwork& net, size_t nireq) {
+    OVInferRequestPtr infer_request;
+    for (size_t id = 0; id < nireq; id++) {
+      infer_request = std::make_shared<OVInferRequest>(net.CreateInferRequest());
+      infer_requests_.push_back(infer_request);
+    }
+  }
+
+  ~InferRequestsQueue() {
+    // clearing out the infer_requests_ vector pool in the class's destructor
+    for (auto& pointer : infer_requests_) {
+      pointer = nullptr;
+    }
+    infer_requests_.erase(std::remove(infer_requests_.begin(), infer_requests_.end(), nullptr),
infer_requests_.end()); + } + + void printstatus() { + std::cout << "printing elements of the vector (infer_requests_): " << std::endl; + for (auto i = infer_requests_.begin(); i != infer_requests_.end(); ++i) { + i->get()->QueryStatus(); + } + std::cout << '\n'; + } + + void putIdleRequest(OVInferRequestPtr infer_request) { + std::unique_lock lock(_mutex); + infer_requests_.push_back(infer_request); + _cv.notify_one(); + } + + OVInferRequestPtr getIdleRequest() { + std::unique_lock lock(_mutex); + _cv.wait(lock, [this] { return infer_requests_.size() > 0; }); + auto request = infer_requests_.at(0); + infer_requests_.erase(infer_requests_.begin()); + return request; + } + + private: + std::mutex _mutex; + std::condition_variable _cv; + std::vector infer_requests_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/ibackend.h b/samples/openvino/ibackend.h index 8e54f7d1cb5d4..d391abc2a6cca 100644 --- a/samples/openvino/ibackend.h +++ b/samples/openvino/ibackend.h @@ -7,23 +7,24 @@ #include "core/session/onnxruntime_cxx_api.h" #include "onnx_ctx_model_helper.h" -//namespace onnxruntime { -//namespace openvino_ep { -// -//class IBackend { -// public: -// virtual void Infer(OrtKernelContext* context) = 0; -// virtual ov::CompiledModel& GetOVCompiledModel() = 0; -//}; -// -//class BackendFactory { -// public: -// static std::shared_ptr -// MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, -// GlobalContext& global_context, -// const SubGraphContext& subgraph_context, -// EPCtxHandler& ctx_handle); -//}; -// -//} // namespace openvino_ep -//} // namespace onnxruntime +namespace onnxruntime { +namespace openvino_ep { + +class IBackend { + public: + virtual void Infer(OrtKernelContext* context) = 0; + virtual ov::CompiledModel& GetOVCompiledModel() = 0; +}; + +class BackendFactory { + public: + static std::shared_ptr + MakeBackend(void* model_proto, + size_t model_proto_len, + GlobalContext& global_context, + const SubGraphContext& subgraph_context, + EPCtxHandler& ctx_handle); +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc index 7c9211bca3b44..e9c57ae5b58ba 100644 --- a/samples/openvino/openvino_execution_provider.cc +++ b/samples/openvino/openvino_execution_provider.cc @@ -62,6 +62,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const OrtExecutionProvider::Compile = [](OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info) -> OrtStatusPtr { OpenVINOExecutionProvider* p = static_cast(this_); + this_->extra_param_for_create_state_func = p; for (int i = 0; i < cnt; i++) { p->global_context_->use_api_2 = true; @@ -69,8 +70,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const // For precompiled blob, directly load the model instead of compiling the model // For original model, check if the user wants to export a model with pre-compiled blob - std::shared_ptr backend_manager = - std::make_shared(*p->global_context_, + std::unique_ptr backend_manager = + std::make_unique(*p->global_context_, node[i], graph[i], p->ep_ctx_handle_); @@ -78,14 +79,18 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const if (p->global_context_->export_ep_ctx_blob && !p->ep_ctx_handle_.IsValidOVEPCtxGraph()) { backend_manager->ExportCompiledBlobAsEPCtxNode(graph[i]); } + const char* 
fused_node_name = nullptr; + graph_api_->OrtNode_GetName(node[i], &fused_node_name); + p->backend_managers_.emplace(fused_node_name, std::move(backend_manager)); node_compute_info[i].CreateFunctionStateFunc = [](OrtComputeContext* context, void* extra_param, void** state) -> int { + OpenVINOExecutionProvider* this_ = reinterpret_cast(extra_param); std::unique_ptr p = std::make_unique(); - p->allocate_func = context->AllocateFunc; - p->destroy_func = context->DestroyFunc; + p->AllocateFunc = context->AllocateFunc; + p->DestroyFunc = context->DestroyFunc; p->allocator_handle = context->allocator_handle; - // p->backend_manager = static_cast(extra_param); - // p->backend_manager = backend_manager; TODO:yang + p->node_name = context->node_name; + p->backend_manager = this_->backend_managers_[context->node_name].get(); *state = p.release(); return 0; }; @@ -108,7 +113,16 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const return nullptr; }; - //OrtExecutionProvider::ReleaseIndexedSubGraphs + OrtExecutionProvider::ReleaseIndexedSubGraphs = [](OrtIndexedSubGraph** indexed_sub_graphs, size_t num_sub_graph) { + if (indexed_sub_graphs == nullptr) return; + for (size_t i = 0; i < num_sub_graph; i++) { + OrtIndexedSubGraph* sub_graph = indexed_sub_graphs[i]; + delete[] sub_graph->node_index; + delete sub_graph->meta_def; + delete sub_graph; + } + delete[] indexed_sub_graphs; + }; } OpenVINOExecutionProviderFactory::OpenVINOExecutionProviderFactory() { diff --git a/samples/openvino/openvino_execution_provider.h b/samples/openvino/openvino_execution_provider.h index baae12a65a843..d013005ff1825 100644 --- a/samples/openvino/openvino_execution_provider.h +++ b/samples/openvino/openvino_execution_provider.h @@ -145,10 +145,11 @@ struct OpenVINOExecutionProviderInfo { }; struct OpenVINOEPFunctionState { - AllocateFunc allocate_func = nullptr; - DestroyFunc destroy_func = nullptr; - AllocatorHandle allocator_handle = nullptr; - std::shared_ptr backend_manager; + void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); + void(ORT_API_CALL* DestroyFunc)(void*, void*); + void* allocator_handle; + const char* node_name; + openvino_ep::BackendManager* backend_manager; }; // Logical device representation. 
@@ -160,6 +161,7 @@ class OpenVINOExecutionProvider : public OrtExecutionProvider { private: std::unique_ptr global_context_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; + std::unordered_map> backend_managers_; static const OrtApi* api_; static const OrtGraphApi* graph_api_; }; diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc index f0b641b102f53..8bf306005c414 100644 --- a/samples/openvino/openvino_utils.cc +++ b/samples/openvino/openvino_utils.cc @@ -1,10 +1,10 @@ -// #include +#if defined(_WIN32) +#include +#endif #include "openvino_utils.h" namespace onnxruntime { std::string GetEnvironmentVar(const std::string& var_name) { -// TODO(leca): #ifdef _WIN32 -//#endif #if defined(_WIN32) constexpr DWORD kBufferSize = 32767; @@ -21,7 +21,6 @@ namespace onnxruntime { return buffer; } #endif - return std::string(); } From 08e3f20c16fc0ba1d4a32d8d6a16f7b5b0217a98 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:33:01 -0700 Subject: [PATCH 63/81] Add unit test for TRT EP plugin (#22548) Hook TRT EP plugin to run the existing unit test in CI - Migrate from `onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc` - Replace internal APIs with new EP APIs - Add unit test in `onnxruntime_shared_lib_test` (which links against onnxruntime dll) - Build ORT with `--test_tensorrt_ep_plugin` to run `onnxruntime_shared_lib_test` Note: The unit test doesn't cover all the cases since current TRT EP plugin hasn't added all the features yet, will update later. --- cmake/onnxruntime_unittests.cmake | 4 + .../test/shared_lib/test_trt_ep_plugin.cc | 370 ++++++++++++++++++ .../trt_ep_test_model_static_input_shape.onnx | Bin 0 -> 449 bytes .../tensorRTEp/tensorrt_execution_provider.cc | 1 + .../tensorrt_execution_provider_info.cc | 140 +++---- .../tensorrt_execution_provider_info.h | 3 +- tools/ci_build/build.py | 2 + 7 files changed, 450 insertions(+), 70 deletions(-) create mode 100644 onnxruntime/test/shared_lib/test_trt_ep_plugin.cc create mode 100644 onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d3acd1718dd87..4c4092d92b7da 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -483,6 +483,10 @@ if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) endif() +if (onnxruntime_TEST_TENSORRT_EP_PLUGIN) + list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_trt_ep_plugin.cc) +endif() + if(onnxruntime_RUN_ONNX_TESTS) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_io_types.cc) endif() diff --git a/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc b/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc new file mode 100644 index 0000000000000..da9bb503095ab --- /dev/null +++ b/onnxruntime/test/shared_lib/test_trt_ep_plugin.cc @@ -0,0 +1,370 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
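+//
+// Reviewer note (not in the original patch): ep_plugin_lib below is a hardcoded
+// absolute path to one developer's libTensorRTEp.so, as its inline comment says;
+// point it at your own plugin build (or read it from an environment variable)
+// before running these tests locally.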
+
+#include "core/session/onnxruntime_c_api.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "gtest/gtest.h"
+
+#include <filesystem>
+#include <iostream>
+#include <string>
+#include <thread>
+#include <vector>
+
+namespace onnxruntime {
+
+const ORTCHAR_T* ep_plugin_lib = "/home/lochi/repos/ort_for_docker_ep_plugin_2/samples/tensorRTEp/build/libTensorRTEp.so";  // hardcode path for now
+const ORTCHAR_T* ep_plugin_name = "tensorrtEp";
+const ORTCHAR_T* model_path = "testdata/trt_ep_test_model_static_input_shape.onnx";
+const ORTCHAR_T* model_path_2 = "testdata/trt_ep_test_model_dynamic_input_shape.onnx";
+
+inline void THROW_ON_ERROR(OrtStatus* status, const OrtApi* api) {
+  if (status != nullptr && api != nullptr) {
+    std::cout << "ErrorMessage:" << api->GetErrorMessage(status) << "\n";
+    abort();
+  }
+}
+
+bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") {
+  std::filesystem::path target_dir;
+  if (file_dir.empty()) {
+    target_dir = std::filesystem::current_path();
+  } else {
+    target_dir = std::filesystem::path(file_dir);
+  }
+
+  for (const auto& entry : std::filesystem::directory_iterator(target_dir)) {
+    if (entry.is_regular_file()) {
+      std::string filename = entry.path().filename().string();
+      if (filename.rfind(prefix, 0) == 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void ValidateOutputs(std::vector<Ort::Value>& ort_outputs,
+                     std::vector<int64_t>& expected_dims,
+                     std::vector<float>& expected_values) {
+  auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(type_info.GetShape(), expected_dims);
+  size_t total_len = type_info.GetElementCount();
+  ASSERT_EQ(expected_values.size(), total_len);
+
+  float* f = ort_outputs[0].GetTensorMutableData<float>();
+  for (size_t i = 0; i != total_len; ++i) {
+    ASSERT_EQ(expected_values[i], f[i]);
+  }
+}
+
+void RunSession(Ort::Session& session,
+                std::vector<const char*>& input_names,
+                std::vector<Ort::Value>& ort_inputs,
+                const char* const* output_names,
+                std::vector<int64_t>& expected_dims,
+                std::vector<float>& expected_values) {
+  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, 1);
+  ValidateOutputs(ort_outputs, expected_dims, expected_values);
+}
+
+void CreateSessionAndRunInference() {
+  // Use C API here since EP plugin only supports C API for now
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  THROW_ON_ERROR(api->RegisterPluginExecutionProviderLibrary(ep_plugin_lib, env, ep_plugin_name), api);
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+  std::vector<const char*> keys{"trt_engine_cache_enable", "trt_engine_cache_prefix", "trt_dump_ep_context_model", "trt_ep_context_file_path"};
+  std::vector<const char*> values{"1", "TRTEP_Cache_Test", "1", "EP_Context_model.onnx"};
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so, ep_plugin_name, env, keys.data(), values.data(), keys.size()), api);
+
+  // Use C++ Wrapper
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+
+  // Run inference
+  // TRT engine will be created and cached
+  // TRT profile will be created and cached only for dynamic input shape
+  // Data in profile,
+  // X: 1, 3, 3, 2, 2, 2
+  // Y: 1, 3, 3, 2, 2, 2
+  // Z: 1, 3, 3, 2, 2, 2
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
+
+  // Verify on cache with customized prefix
+  ASSERT_TRUE(HasCacheFileWithPrefix("TRTEP_Cache_Test"));
+
+  // Verify EP context model with user provided name
+  ASSERT_TRUE(HasCacheFileWithPrefix("EP_Context_model.onnx"));
+}
+
+/*
+ * Create one session and run by multiple threads
+ */
+void CreateSessionAndRunInference2() {
+  // Use C API
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  THROW_ON_ERROR(api->RegisterPluginExecutionProviderLibrary(ep_plugin_lib, env, ep_plugin_name), api);
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+  std::vector<const char*> keys{"trt_engine_cache_enable", "trt_engine_cache_prefix", "trt_dump_ep_context_model", "trt_ep_context_file_path"};
+  std::vector<const char*> values{"1", "TRTEP_Cache_Test", "1", "EP_Context_model.onnx"};
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so, ep_plugin_name, env, keys.data(), values.data(), keys.size()), api);
+
+  // Use C++ Wrapper
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+  std::vector<int64_t> y_dims = {1, 3, 2};
+  std::vector<float> values_y = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
+
+  std::vector<std::thread> threads;
+  int num_thread = 5;
+  for (int i = 0; i < num_thread; ++i) {
+    threads.push_back(std::thread(RunSession, std::ref(session), std::ref(input_names), std::ref(ort_inputs), std::ref(output_names), std::ref(y_dims), std::ref(values_y)));
+  }
+
+  for (auto& th : threads)
+    th.join();
+
+  // Verify on cache with customized prefix
+  ASSERT_TRUE(HasCacheFileWithPrefix("TRTEP_Cache_Test"));
+}
+
+TEST(TensorrtExecutionProviderPluginTest, SmallModel) {
+  // Use C API
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  THROW_ON_ERROR(api->RegisterPluginExecutionProviderLibrary(ep_plugin_lib, env, ep_plugin_name), api);
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+  std::vector<const char*> keys;
+  std::vector<const char*> values;
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so, ep_plugin_name, env, keys.data(), values.data(), keys.size()), api);
+
+  // Use C++ Wrapper
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+
+  // Run inference
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
+
+  // Validate results
+  std::vector<int64_t> y_dims = {1, 3, 2};
+  std::vector<float> values_y = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
+  ValidateOutputs(ort_outputs, y_dims, values_y);
+}
+
+TEST(TensorrtExecutionProviderPluginTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
+  std::vector<std::thread> threads;
+  std::vector<int64_t> dims = {1, 3, 2};
+  int num_thread = 5;
+
+  for (int i = 0; i < num_thread; ++i)
+    threads.push_back(std::thread(CreateSessionAndRunInference));
+
+  for (auto& th : threads)
+    th.join();
+}
+
+TEST(TensorrtExecutionProviderPluginTest, SessionCreationWithSingleThreadAndInferenceWithMultiThreads) {
+  std::vector<int64_t> dims = {1, 3, 2};
+
+  CreateSessionAndRunInference2();
+}
+
+TEST(TensorrtExecutionProviderPluginTest, EPContextNode) {
+  // Use C API
+  OrtEnv* env = nullptr;
+  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+  OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+  THROW_ON_ERROR(api->CreateEnv(log_level, "", &env), api);
+  THROW_ON_ERROR(api->RegisterPluginExecutionProviderLibrary(ep_plugin_lib, env, ep_plugin_name), api);
+
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  // input 0, 1, 2
+  std::vector<float> input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  std::vector<int64_t> input_dims = {1, 3, 2};
+  input_names.emplace_back("X");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Y");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+  input_names.emplace_back("Z");
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()),
+                                      input_data.size(), input_dims.data(), input_dims.size()));
+
+  // output 0
+  const char* output_names[] = {"M"};
+  std::vector<int64_t> y_dims = {1, 3, 2};
+  std::vector<float> values_y = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
+
+  /*
+   * Test case 1: Dump context model
+   *
+   * provider options =>
+   *   trt_ep_context_file_path = "EP_Context_model.onnx"
+   *
+   * expected result =>
+   *   context model "EP_Context_model.onnx" should be created in current directory
+   */
+  OrtSessionOptions* so = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so), api);
+  std::vector<const char*> keys{"trt_engine_cache_enable", "trt_dump_ep_context_model", "trt_ep_context_file_path"};
+  std::vector<const char*> values{"1", "1", "EP_Context_model.onnx"};
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so, ep_plugin_name, env, keys.data(), values.data(), keys.size()), api);
+
+  Ort::SessionOptions ort_so{so};
+  Ort::Env ort_env{env};
+  Ort::Session session(ort_env, model_path, ort_so);
+
+  ASSERT_TRUE(HasCacheFileWithPrefix("EP_Context_model.onnx"));
+
+  /*
+   * Test case 2: Dump context model
+   *
+   * provider options =>
+   *   trt_engine_cache_prefix = "TRT_engine_cache"
+   *   trt_ep_context_file_path = "context_model_folder"
+   *   trt_engine_cache_path = "engine_cache_folder"
+   *
+   * expected result =>
+   *   engine cache "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created
+   *   context model "./context_model_folder/trt_ep_test_model_static_input_shape_ctx.onnx" should be created
+   */
+  OrtSessionOptions* so2 = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so2), api);
+  std::vector<const char*> keys2{"trt_engine_cache_enable", "trt_dump_ep_context_model", "trt_engine_cache_prefix", "trt_engine_cache_path", "trt_ep_context_file_path"};
+  std::vector<const char*> values2{"1", "1", "TRT_engine_cache", "engine_cache_folder", "context_model_folder"};
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so2, ep_plugin_name, env, keys2.data(), values2.data(), keys2.size()), api);
+
+  Ort::SessionOptions ort_so2{so2};
+  Ort::Session session2(ort_env, model_path, ort_so2);
+
+  auto new_engine_cache_path = std::filesystem::path("context_model_folder").append("engine_cache_folder").string();
+  // Test engine cache path:
+  // "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created
+  ASSERT_TRUE(HasCacheFileWithPrefix("TRT_engine_cache", new_engine_cache_path));
+  // Test context model path:
+  // "./context_model_folder/trt_ep_test_model_static_input_shape_ctx.onnx" should be created
+  ASSERT_TRUE(HasCacheFileWithPrefix("trt_ep_test_model_static_input_shape_ctx.onnx", "context_model_folder"));
+
+  /*
+   * Test case 3: Run the dumped context model
+   *
+   * context model path = "./EP_Context_model.onnx" (created from case 1)
+   *
+   * expected result =>
+   *   engine cache is also in the same current directory as "./xxxxx.engine"
+   *   and the "ep_cache_context" attribute node of the context model should point to that.
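+   *   (loading the context model lets the EP reuse the previously dumped engine
+   *   instead of rebuilding it from the ONNX graph)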
+   *
+   */
+  OrtSessionOptions* so3 = nullptr;
+  THROW_ON_ERROR(api->CreateSessionOptions(&so3), api);
+  std::vector<const char*> keys3{"trt_engine_cache_enable"};
+  std::vector<const char*> values3{"1"};
+  THROW_ON_ERROR(api->SessionOptionsAppendPluginExecutionProvider(so3, ep_plugin_name, env, keys3.data(), values3.data(), keys3.size()), api);
+
+  Ort::SessionOptions ort_so3{so3};
+  Ort::Session session3(ort_env, "EP_Context_model.onnx", ort_so3);
+
+  // Run inference
+  auto ort_outputs3 = session3.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+                                   output_names, 1);
+  // Validate results
+  ValidateOutputs(ort_outputs3, y_dims, values_y);
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx b/onnxruntime/test/testdata/trt_ep_test_model_static_input_shape.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..4286222dd05bce8d4a7028cf9a5f8f83d7366227
GIT binary patch
literal 449
zcmZ{eJx{|h5QgiPkLXm8F-5H^giuF@lbVsO*}E{&ERk_cW0agD$I$Y}nQ?$hjTD9(
z?w;qpy9ge?7(4)b2DTAnvboNdnSJ-!a(?#PEk>(6kI&oYeu=^DSin-j)_-n%?8YcW
zt<8

literal 0
HcmV?d00001

diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index 2e61d1cb1b92e..50e306a345b87 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -1676,6 +1676,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
     };

     info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info);
+    if (ep_info.size() > 0) info_.has_trt_options = true;
     device_id_ = info_.device_id;
     api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device);
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.cc b/samples/tensorRTEp/tensorrt_execution_provider_info.cc
index 4d3030a365f18..8a34cf0c7b2c0 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider_info.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider_info.cc
@@ -1,7 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include <cuda_runtime.h>  //#include "core/providers/cuda/cuda_pch.h"
+
 #include "tensorrt_execution_provider_info.h"
+#include "provider_options_utils.h"
+#include "cuda/cuda_common.h"

 namespace onnxruntime {
 namespace tensorrt {
@@ -54,75 +58,73 @@ constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";

 TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
   TensorrtExecutionProviderInfo info{};
-  for (const auto& [k, v] : options) {
-    if (k == "device_id") info.device_id = std::atoi(v.c_str());
-  }
-//  void* user_compute_stream = nullptr;
-//  ORT_THROW_IF_ERROR(
-//      ProviderOptionsParser{}
-//          .AddValueParser(
-//              tensorrt::provider_option_names::kDeviceId,
-//              [&info](const std::string& value_str) -> Status {
-//                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
-//                int num_devices{};
-//                CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices));
-//                ORT_RETURN_IF_NOT(
-//                    0 <= info.device_id && info.device_id < num_devices,
-//                    "Invalid device ID: ", info.device_id,
-//                    ", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
-//                return Status::OK();
-//              })
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
-//          .AddValueParser(
-//              tensorrt::provider_option_names::kUserComputeStream,
-//              [&user_compute_stream](const std::string& value_str) -> Status {
-//                size_t address;
-//                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
-//                user_compute_stream = reinterpret_cast<void*>(address);
-//                return Status::OK();
-//              })
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kSparsityEnable, info.sparsity_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
-//          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible)
-//          .Parse(options));  // add new provider option here.
-//
-//  info.user_compute_stream = user_compute_stream;
-//  info.has_user_compute_stream = (user_compute_stream != nullptr);
+
+  void* user_compute_stream = nullptr;
+  ORT_THROW_IF_ERROR(
+      ProviderOptionsParser{}
+          .AddValueParser(
+              tensorrt::provider_option_names::kDeviceId,
+              [&info](const std::string& value_str) -> Status {
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
+                int num_devices{};
+                CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices));
+                ORT_RETURN_IF_NOT(
+                    0 <= info.device_id && info.device_id < num_devices,
+                    "Invalid device ID: ", info.device_id,
+                    ", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
+                return Status::OK();
+              })
+          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
+          .AddValueParser(
+              tensorrt::provider_option_names::kUserComputeStream,
+              [&user_compute_stream](const std::string& value_str) -> Status {
+                size_t address;
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+                user_compute_stream = reinterpret_cast<void*>(address);
+                return Status::OK();
+              })
+          .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kSparsityEnable, info.sparsity_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
+          .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible)
+          .Parse(options));  // add new provider option here.
+
+  info.user_compute_stream = user_compute_stream;
+  info.has_user_compute_stream = (user_compute_stream != nullptr);
   return info;
 }
diff --git a/samples/tensorRTEp/tensorrt_execution_provider_info.h b/samples/tensorRTEp/tensorrt_execution_provider_info.h
index 92a14daf539e8..5ca1d6df0a6e2 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider_info.h
+++ b/samples/tensorRTEp/tensorrt_execution_provider_info.h
@@ -4,7 +4,8 @@
 #pragma once

 #include <string>
-#include "core/framework/provider_options.h"
+#include "provider_options.h"
+#include "common.h"

 #define TRT_DEFAULT_OPTIMIZER_LEVEL 3
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 75fbf5d0851ae..8d49aaf37402b 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -579,6 +579,7 @@ def convert_arg_line_to_args(self, arg_line):
     parser.add_argument(
         "--use_tvm_hash", action="store_true", help="Build ipp-crypto for hash generation. It is used by TVM EP only"
     )
+    parser.add_argument("--test_tensorrt_ep_plugin", action="store_true", help="Build with TensorRT EP Plugin Test App")
     parser.add_argument("--use_tensorrt", action="store_true", help="Build with TensorRT")
     parser.add_argument(
         "--use_tensorrt_builtin_parser", action="store_true", default=True, help="Use TensorRT builtin parser"
@@ -1027,6 +1028,7 @@ def generate_build_tree(
         "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"),
         "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"),
         "-Donnxruntime_USE_VITISAI=" + ("ON" if args.use_vitisai else "OFF"),
+        "-Donnxruntime_TEST_TENSORRT_EP_PLUGIN=" + ("ON" if args.test_tensorrt_ep_plugin else "OFF"),
         "-Donnxruntime_USE_TENSORRT=" + ("ON" if args.use_tensorrt else "OFF"),
         "-Donnxruntime_USE_TENSORRT_BUILTIN_PARSER="
        + ("ON" if args.use_tensorrt_builtin_parser and not args.use_tensorrt_oss_parser else "OFF"),

From b0b3123d5ef9a1bbd9db8a100a96d735fb9edb46 Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Tue, 5 Nov 2024 22:30:55 +0800
Subject: [PATCH 64/81] add test for openvino plugin ep and fix bugs (#22734)

---
 samples/c_test/test.cpp                         | 10 ++++-
 samples/openvino/CMakeLists.txt                 |  1 +
 .../openvino/openvino_execution_provider.cc     | 40 ++++++++++++++++++-
 .../openvino/openvino_execution_provider.h      |  5 ++-
 4 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
index cbbd568a418d0..004c7b416ca28 100644
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -40,6 +40,12 @@ void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions*
   g_ort->ReleaseCUDAProviderOptions(cuda_options);
 }

+void TestOpenVinoEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/yangu/work/onnxruntime/samples/openvino/build/libOpenVINOEp.so", env, "openvinoEp"));
+  std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
+  THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "openvinoEp", env, keys.data(), values.data(), keys.size()));
+}
+
 void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) {
   OrtTensorRTProviderOptionsV2* tensorrt_options = nullptr;
   THROW_ON_ERROR(g_ort->CreateTensorRTProviderOptions(&tensorrt_options));
@@ -275,7 +281,7 @@ void RunControlFlow(OrtEnv* p_env, OrtSessionOptions* so) {
   for (size_t i = 0; i < 2; i++) std::cout<
diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc
@@ -1,5 +1,6 @@
 #include
 #include
+#include "provider_options_utils.h"
 #include "openvino_execution_provider.h"
 #include "openvino_utils.h"
 #include "ov_versions/capability.h"

 namespace onnxruntime {

+OpenVINOExecutionProviderInfo OpenVINOExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
+  OpenVINOExecutionProviderInfo info{};
+  ORT_THROW_IF_ERROR(
+      ProviderOptionsParser{}
+          .AddAssignmentToReference("device_type", info.device_type_)
+          .AddAssignmentToReference("precision", info.precision_)
+          .AddAssignmentToReference("enable_npu_fast_compile", info.enable_npu_fast_compile_)
+          .AddAssignmentToReference("cache_dir", info.cache_dir_)
+          .AddAssignmentToReference("model_priority", info.model_priority_)
+          .AddAssignmentToReference("num_streams", info.num_streams_)
+          .AddAssignmentToReference("context", info.context_)
+          .AddAssignmentToReference("enable_opencl_throttling", info.enable_opencl_throttling_)
+          .AddAssignmentToReference("disable_dynamic_shapes", info.disable_dynamic_shapes_)
+          .AddAssignmentToReference("num_of_threads", info.num_of_threads_)
+          .AddAssignmentToReference("export_ep_ctx_blob", info.export_ep_ctx_blob_)
+          .AddAssignmentToReference("enable_qdq_optimizer", info.enable_qdq_optimizer_)
+          .AddAssignmentToReference("disable_cpu_fallback", info.disable_cpu_fallback_)
+          .Parse(options));
+  return info;
+}
+
 const OrtApi* OpenVINOExecutionProvider::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
 const OrtGraphApi* OpenVINOExecutionProvider::graph_api_ = OpenVINOExecutionProvider::api_->GetGraphApi(ORT_API_VERSION);
@@ -123,8 +145,24 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const
     }
     delete[] indexed_sub_graphs;
   };
+  type = ep_type;
+  info_ = OpenVINOExecutionProviderInfo::FromProviderOptions(provider_options);
+  global_context_ = std::make_unique<openvino_ep::GlobalContext>();
+  global_context_->device_type = info_.device_type_;
+  global_context_->precision_str = info_.precision_;
+  global_context_->enable_npu_fast_compile = info_.enable_npu_fast_compile_;
+  global_context_->cache_dir = info_.cache_dir_;
+  global_context_->model_priority = info_.model_priority_;
+  global_context_->num_streams = info_.num_streams_;
+  global_context_->context = info_.context_;
+  global_context_->enable_opencl_throttling = info_.enable_opencl_throttling_;
+  global_context_->disable_dynamic_shapes = info_.disable_dynamic_shapes_;
+  global_context_->num_of_threads = info_.num_of_threads_;
+  global_context_->OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR};
+  global_context_->export_ep_ctx_blob = info_.export_ep_ctx_blob_;
+  global_context_->enable_qdq_optimizer = info_.enable_qdq_optimizer_;
+  global_context_->disable_cpu_fallback = info_.disable_cpu_fallback_;
 }
-
 OpenVINOExecutionProviderFactory::OpenVINOExecutionProviderFactory() {
   OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size) -> OrtExecutionProvider* {
     ProviderOptions options;
diff --git a/samples/openvino/openvino_execution_provider.h b/samples/openvino/openvino_execution_provider.h
index d013005ff1825..7364e7c48b0bd 100644
--- a/samples/openvino/openvino_execution_provider.h
+++ b/samples/openvino/openvino_execution_provider.h
@@ -83,7 +83,9 @@ struct OpenVINOExecutionProviderInfo {
   bool enable_qdq_optimizer_{false};
   bool disable_cpu_fallback_{false};

-  OpenVINOExecutionProviderInfo() = delete;
+  OpenVINOExecutionProviderInfo(){};
+
+  static OpenVINOExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);

   explicit OpenVINOExecutionProviderInfo(std::string dev_type, std::string precision, bool enable_npu_fast_compile,
                                          size_t num_of_threads, std::string cache_dir, std::string model_priority,
@@ -159,6 +161,7 @@ class OpenVINOExecutionProvider : public OrtExecutionProvider {
   ~OpenVINOExecutionProvider() = default;

 private:
+  mutable OpenVINOExecutionProviderInfo info_;
   std::unique_ptr<openvino_ep::GlobalContext> global_context_;
   openvino_ep::EPCtxHandler ep_ctx_handle_{};
   std::unordered_map<std::string, std::unique_ptr<openvino_ep::BackendManager>> backend_managers_;

From 9dbb0b12832000c50e5fa2cb9952b63540f47715 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Wed, 6 Nov 2024 01:45:59 +0000
Subject: [PATCH 65/81] add missing mutex to plugin trt ep

---
 samples/tensorRTEp/tensorrt_execution_provider.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc
index 50e306a345b87..108e915d0856a 100644
--- a/samples/tensorRTEp/tensorrt_execution_provider.cc
+++ b/samples/tensorRTEp/tensorrt_execution_provider.cc
@@ -3455,7 +3455,8 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi
                                  this_->input_info_[context->node_name],
                                  this_->output_info_[context->node_name],
                                  this_->context_memory_sharing_enable_,
-                                 &this_->max_ctx_mem_size_};
+                                 &this_->max_ctx_mem_size_,
+                                 &this_->tensorrt_mu_};
     *state = p.release();
     return 0;
   };

From 5a598033290d216fc56be5fda33fea579f51c504 Mon Sep 17 00:00:00 2001
From: jslhcl
Date: Tue, 5 Nov 2024 17:55:27 -0800
Subject: [PATCH 66/81] merge code

---
 samples/c_test/CMakeLists.txt                   | 18 ++++-
 samples/c_test/test.cpp                         | 71 ++++++++++++++---
 samples/openvino/CMakeLists.txt                 |  2 +-
 .../openvino/openvino_execution_provider.cc     | 34 ++++-----
 4 files changed, 96 insertions(+), 29 deletions(-)

diff --git a/samples/c_test/CMakeLists.txt b/samples/c_test/CMakeLists.txt
index 9a460ecb72560..4dc73a76d7e34 100644
--- a/samples/c_test/CMakeLists.txt
+++ b/samples/c_test/CMakeLists.txt
@@ -2,10 +2,26 @@
 # cd build/
 # cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug
 # cmake --build ./
+# NOTE: On Windows, copy onnxruntime.dll and onnxruntime.pdb into the same folder as TestOutTreeEp.exe; otherwise, at runtime,
+# the default system path (C:\Windows\System32) will be searched for onnxruntime.dll
 cmake_minimum_required(VERSION 3.26)
 project(TestOutTreeEp)
 add_executable(TestOutTreeEp test.cpp)
 target_include_directories(TestOutTreeEp PUBLIC "../../include/onnxruntime")
-#target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so")
+if (WIN32)
+#find_library(ORT_LIB NAMES onnxruntime PATHS "C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug" NO_DEFAULT_PATH)
+#target_link_libraries(TestOutTreeEp PUBLIC ${ORT_LIB})
+
+target_link_libraries(TestOutTreeEp PUBLIC "C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug/onnxruntime.lib")
+
+#add_library(onnxruntime SHARED IMPORTED)
+#set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION "C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug/onnxruntime.dll")
+#target_link_libraries(TestOutTreeEp PUBLIC onnxruntime)
+
+#link_directories("C:/Users/leca/source/onnxruntime/build/Windows/Debug/Debug")
+#target_link_libraries(TestOutTreeEp PUBLIC onnxruntime.lib)
+
+elseif (LINUX)
 target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so")
+endif()
diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
index 004c7b416ca28..dba9684e5ccf4 100644
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -12,25 +12,41 @@ inline void THROW_ON_ERROR(OrtStatus* status) {
 }

 void TestCompileBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary(L"/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", env, "outTreeEp"));
+#else
   THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", env, "outTreeEp"));
+#endif
   std::vector<const char*> keys{"int_property", "str_property"}, values{"3", "strvalue"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "outTreeEp", env, keys.data(), values.data(), keys.size()));
 }

 void TestKernelBasedEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary(L"/home/leca/code/onnxruntime/samples/outTreeEp_kernel/build/libkernelEp.so", env, "kernelEp"));
+#else
   THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp_kernel/build/libkernelEp.so", env, "kernelEp"));
+#endif
   std::vector<const char*> keys{"int_property", "str_property"}, values{"3", "strvalue"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "kernelEp", env, keys.data(), values.data(), keys.size()));
 }

 void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary(L"/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp"));
+#else
   THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp"));
+#endif
   std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size()));
 }

 void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary(L"/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp"));
+#else
   THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp"));
+#endif
   std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size()));
@@ -40,12 +56,6 @@ void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions*
   g_ort->ReleaseCUDAProviderOptions(cuda_options);
 }

-void TestOpenVinoEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) {
-  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/yangu/work/onnxruntime/samples/openvino/build/libOpenVINOEp.so", env, "openvinoEp"));
-  std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
-  THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "openvinoEp", env, keys.data(), values.data(), keys.size()));
-}
-
 void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) {
   OrtTensorRTProviderOptionsV2* tensorrt_options = nullptr;
   THROW_ON_ERROR(g_ort->CreateTensorRTProviderOptions(&tensorrt_options));
@@ -59,11 +69,25 @@ void TestOriginalTensorRTEp(const OrtApi* g_ort, OrtSessionOptions* so) {
   g_ort->ReleaseTensorRTProviderOptions(tensorrt_options);
 }

+void TestOpenVINOEp(OrtEnv* env, OrtSessionOptions* so) {
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary(L"C:/Users/leca/source/onnxruntime/samples/openvino/build/Debug/OpenVINOEp.dll", env, "OpenVINOEp"));
+#else
+  THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/yangu/work/onnxruntime/samples/openvino/build/libOpenVINOEp.so", env, "OpenVINOEp"));
+#endif
+  std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
+  THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "OpenVINOEp", env, keys.data(), values.data(), keys.size()));
+}
+
 void RunResnet18v1_7(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {
   // download resnet18-v1-7 model at:
   // https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet18-v1-7.tar.gz
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session));
+#else
   THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session));
+#endif

   const int input_data_cnt = 3 * 224 * 224;
   float input_data[input_data_cnt];
@@ -92,7 +116,11 @@ void RunResnet18v1_7(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so)

 void RunRelu(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->CreateSession(p_env, L"C:/share/models/relu/Relu.onnx", so, &session));
+#else
   THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/code/onnxruntime/samples/c_test/Relu.onnx", so, &session));
+#endif

   OrtMemoryInfo* memory_info = nullptr;
   THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
@@ -117,7 +145,11 @@ void RunRelu(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {

 void RunDecoder(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/decoder/decoder.onnx", so, &session));
+#else
   THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/decoder/decoder.onnx", so, &session));
+#endif

   OrtMemoryInfo* memory_info = nullptr;
   THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
@@ -175,7 +207,11 @@ void RunDecoder(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {

 void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/faster_rcnn/faster_rcnn_R_50_FPN_1x.onnx", so, &session));
+#else
   THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/faster_rcnn/faster_rcnn_R_50_FPN_1x.onnx", so, &session));
+#endif

   OrtMemoryInfo* memory_info = nullptr;
   THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
@@ -204,9 +240,13 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {

 void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so, const char* model) {
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  if (!strcmp(model, "tyolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session));
+  else if (!strcmp(model, "yolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/yolov3/yolov3.onnx", so, &session));
+#else
   if (!strcmp(model, "tyolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session));
   else if (!strcmp(model, "yolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/yolov3/yolov3.onnx", so, &session));
-
+#endif
   OrtMemoryInfo* memory_info = nullptr;
   THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
@@ -246,7 +286,11 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so, const char* model) {

 void RunControlFlow(OrtEnv* p_env, OrtSessionOptions* so) {
   OrtSession* session = nullptr;
+#ifdef _WIN32
+  THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/control_flow/control_flow_model.onnx", so, &session));
+#else
   THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/control_flow/control_flow_model.onnx", so, &session));
+#endif

   OrtMemoryInfo* memory_info = nullptr;
   THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
@@ -281,8 +325,13 @@ void RunControlFlow(OrtEnv* p_env, OrtSessionOptions* so) {
   for (size_t i = 0; i < 2; i++) std::cout<

 int main(int argc, char *argv[]) {
+#ifdef _WIN32
+  char a;
+  std::cin >> a;
+#endif
   OrtEnv* p_env = nullptr;
   OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;//OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO;
   THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env));
@@ -339,8 +388,8 @@ int main(int argc, char *argv[]) {
     TestTensorRTAndCudaEp(g_ort, p_env, so);
   } else if (strcmp(argv[1], "otc") == 0) {
     TestOriginalTensorRTEp(g_ort, so);
-  } else if (strcmp(argv[1], "ov") == 0) {
-    TestOpenVinoEp(g_ort, p_env, so);
+  } else if (!strcmp(argv[1], "o")) {
+    TestOpenVINOEp(p_env, so);
   }

   if (!strcmp(argv[2], "relu")) {
@@ -361,6 +410,8 @@ int main(int argc, char *argv[]) {
     g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "kernelEp");
   } else if (!strcmp(argv[1], "t") || !strcmp(argv[1], "tc")) {
     g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "tensorrtEp");
+  } else if (!strcmp(argv[1], "o")) {
+    g_ort->UnregisterPluginExecutionProviderLibrary(p_env, "OpenVINOEp");
   }

   g_ort->ReleaseEnv(p_env);
diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt
index 0a5ad8ee435b0..e80019071032b 100644
--- a/samples/openvino/CMakeLists.txt
+++ b/samples/openvino/CMakeLists.txt
@@ -14,7 +14,7 @@ add_definitions(-DONNX_ML)
 file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc" "./backends/*.cc")
 add_library(OpenVINOEp SHARED ${openvino_src})
 target_include_directories(OpenVINOEp PUBLIC "../../include/onnxruntime"
-  "../utils"
+#  "../utils"
   ${OPENVINO_HOME}/include
   "../../build/Windows/Debug/_deps/onnx-src"
   "../../build/Windows/Debug/_deps/onnx-build"
diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc
index 40f1348f4af9a..bfe4982e875df 100644
--- a/samples/openvino/openvino_execution_provider.cc
+++ b/samples/openvino/openvino_execution_provider.cc
@@ -1,6 +1,6 @@
 #include
 #include
-#include "provider_options_utils.h"
+//#include "provider_options_utils.h"
 #include "openvino_execution_provider.h"
 #include "openvino_utils.h"
 #include "ov_versions/capability.h"
@@ -9,22 +9,22 @@ namespace onnxruntime {

 OpenVINOExecutionProviderInfo OpenVINOExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
   OpenVINOExecutionProviderInfo info{};
-  ORT_THROW_IF_ERROR(
-      ProviderOptionsParser{}
-          .AddAssignmentToReference("device_type", info.device_type_)
-          .AddAssignmentToReference("precision", info.precision_)
-          .AddAssignmentToReference("enable_npu_fast_compile", info.enable_npu_fast_compile_)
-          .AddAssignmentToReference("cache_dir", info.cache_dir_)
-          .AddAssignmentToReference("model_priority", info.model_priority_)
-          .AddAssignmentToReference("num_streams", info.num_streams_)
-          .AddAssignmentToReference("context", info.context_)
-          .AddAssignmentToReference("enable_opencl_throttling", info.enable_opencl_throttling_)
-          .AddAssignmentToReference("disable_dynamic_shapes", info.disable_dynamic_shapes_)
-          .AddAssignmentToReference("num_of_threads", info.num_of_threads_)
-          .AddAssignmentToReference("export_ep_ctx_blob", info.export_ep_ctx_blob_)
-          .AddAssignmentToReference("enable_qdq_optimizer", info.enable_qdq_optimizer_)
-          .AddAssignmentToReference("disable_cpu_fallback", info.disable_cpu_fallback_)
-          .Parse(options));
+// ORT_THROW_IF_ERROR(
+// ProviderOptionsParser{}
+// .AddAssignmentToReference("device_type", info.device_type_)
+// .AddAssignmentToReference("precision", info.precision_)
+// .AddAssignmentToReference("enable_npu_fast_compile", info.enable_npu_fast_compile_)
+// .AddAssignmentToReference("cache_dir", info.cache_dir_)
+// .AddAssignmentToReference("model_priority", info.model_priority_)
+// .AddAssignmentToReference("num_streams", info.num_streams_)
+// .AddAssignmentToReference("context", info.context_)
+// .AddAssignmentToReference("enable_opencl_throttling", info.enable_opencl_throttling_)
+// .AddAssignmentToReference("disable_dynamic_shapes", info.disable_dynamic_shapes_)
+// .AddAssignmentToReference("num_of_threads", info.num_of_threads_)
+// .AddAssignmentToReference("export_ep_ctx_blob", info.export_ep_ctx_blob_)
+// .AddAssignmentToReference("enable_qdq_optimizer", info.enable_qdq_optimizer_)
+// .AddAssignmentToReference("disable_cpu_fallback", info.disable_cpu_fallback_)
+// .Parse(options));
   return info;
 }

From 084f7350d1445a44008b9399fe3329235a24265c Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Wed, 6 Nov 2024 20:22:58 +0800
Subject: [PATCH 67/81] fix bugs (#22744)

---
 samples/c_test/test.cpp                         |  2 +-
 samples/openvino/CMakeLists.txt                 |  4 +--
 .../openvino/openvino_execution_provider.cc     | 34 +++++++++----------
 samples/openvino/openvino_utils.cc              |  4 ++-
 samples/openvino/ov_versions/utils.cc           |  2 ++
 5 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp
index dba9684e5ccf4..d0bf543f72b8f 100644
--- a/samples/c_test/test.cpp
+++ b/samples/c_test/test.cpp
@@ -75,7 +75,7 @@ void TestOpenVINOEp(OrtEnv* env, OrtSessionOptions* so) {
 #else
   THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/yangu/work/onnxruntime/samples/openvino/build/libOpenVINOEp.so", env, "OpenVINOEp"));
 #endif
-  std::vector<const char*> keys{"device_id", "str_property"}, values{"0", "strvalue"};
+  std::vector<const char*> keys{"device_type", "precision"}, values{"CPU", "FP32"};
   THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "OpenVINOEp", env, keys.data(), values.data(), keys.size()));
diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt
index e80019071032b..5ec99e56f6a36 100644
--- a/samples/openvino/CMakeLists.txt
+++ b/samples/openvino/CMakeLists.txt
@@ -11,10 +11,10 @@ list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime)
 add_definitions(-DONNX_NAMESPACE=onnx)
 add_definitions(-DONNX_ML)
-file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc" "./backends/*.cc")
+file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc" "./backends/*.cc" "../utils/status.cc")
 add_library(OpenVINOEp SHARED ${openvino_src})
 target_include_directories(OpenVINOEp PUBLIC "../../include/onnxruntime"
-#  "../utils"
+  "../utils"
   ${OPENVINO_HOME}/include
   "../../build/Windows/Debug/_deps/onnx-src"
   "../../build/Windows/Debug/_deps/onnx-build"
diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc
index bfe4982e875df..38b4e5864bafc 100644
--- a/samples/openvino/openvino_execution_provider.cc
+++ b/samples/openvino/openvino_execution_provider.cc
@@ -1,6 +1,6 @@
 #include
 #include
-//#include "provider_options_utils.h"
+#include "provider_options_utils.h"
 #include "openvino_execution_provider.h"
 #include "openvino_utils.h"
 #include "ov_versions/capability.h"
@@ -9,22 +9,22 @@ namespace onnxruntime {

 OpenVINOExecutionProviderInfo OpenVINOExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
   OpenVINOExecutionProviderInfo info{};
-// ORT_THROW_IF_ERROR(
-// ProviderOptionsParser{}
-// .AddAssignmentToReference("device_type", info.device_type_)
-// .AddAssignmentToReference("precision", info.precision_)
-// .AddAssignmentToReference("enable_npu_fast_compile", info.enable_npu_fast_compile_)
-// .AddAssignmentToReference("cache_dir", info.cache_dir_)
-// .AddAssignmentToReference("model_priority", info.model_priority_)
-// .AddAssignmentToReference("num_streams", info.num_streams_)
-// .AddAssignmentToReference("context", info.context_)
-// .AddAssignmentToReference("enable_opencl_throttling", info.enable_opencl_throttling_)
-// .AddAssignmentToReference("disable_dynamic_shapes", info.disable_dynamic_shapes_)
-// .AddAssignmentToReference("num_of_threads", info.num_of_threads_)
-// .AddAssignmentToReference("export_ep_ctx_blob", info.export_ep_ctx_blob_)
-// .AddAssignmentToReference("enable_qdq_optimizer", info.enable_qdq_optimizer_)
-// .AddAssignmentToReference("disable_cpu_fallback", info.disable_cpu_fallback_)
-// .Parse(options));
+  ORT_THROW_IF_ERROR(
+      ProviderOptionsParser{}
+          .AddAssignmentToReference("device_type", info.device_type_)
+          .AddAssignmentToReference("precision", info.precision_)
+          .AddAssignmentToReference("enable_npu_fast_compile", info.enable_npu_fast_compile_)
+          .AddAssignmentToReference("cache_dir", info.cache_dir_)
+          .AddAssignmentToReference("model_priority", info.model_priority_)
+          .AddAssignmentToReference("num_streams", info.num_streams_)
+          .AddAssignmentToReference("context", info.context_)
+          .AddAssignmentToReference("enable_opencl_throttling", info.enable_opencl_throttling_)
+          .AddAssignmentToReference("disable_dynamic_shapes", info.disable_dynamic_shapes_)
+          .AddAssignmentToReference("num_of_threads", info.num_of_threads_)
+          .AddAssignmentToReference("export_ep_ctx_blob", info.export_ep_ctx_blob_)
+          .AddAssignmentToReference("enable_qdq_optimizer", info.enable_qdq_optimizer_)
+          .AddAssignmentToReference("disable_cpu_fallback", info.disable_cpu_fallback_)
+          .Parse(options));
   return info;
 }
diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc
index 8bf306005c414..b3d4a37e4ca7e 100644
--- a/samples/openvino/openvino_utils.cc
+++ b/samples/openvino/openvino_utils.cc
@@ -20,8 +20,10 @@ namespace onnxruntime {
             buffer.resize(char_count);
             return buffer;
         }
+#else
+        char* val = getenv(var_name.c_str());
+        return val == NULL ? std::string() : std::string(val);
std::string() : std::string(val); #endif - return std::string(); } OrtStatus* ForEachNodeDef(const OrtGraphApi* graph_api, const OrtGraphViewer* graph, const OrtNode* node, diff --git a/samples/openvino/ov_versions/utils.cc b/samples/openvino/ov_versions/utils.cc index 5db4458edc0eb..877f7c10a4a33 100644 --- a/samples/openvino/ov_versions/utils.cc +++ b/samples/openvino/ov_versions/utils.cc @@ -78,11 +78,13 @@ void AppendClusterToSubGraph(const size_t* node_index, size_t node_count, meta_def->input_len = inputs.size(); meta_def->inputs = new char* [inputs.size()]; for (int i = 0; i < inputs.size(); i++) { + meta_def->inputs[i] = new char [inputs[i].length() + 1]; strcpy(meta_def->inputs[i], inputs[i].c_str()); } meta_def->output_len = outputs.size(); meta_def->outputs = new char* [outputs.size()]; for (int i = 0; i < outputs.size(); i++) { + meta_def->outputs[i] = new char [outputs[i].length() + 1]; strcpy(meta_def->outputs[i], outputs[i].c_str()); } From 2b1cfdf905216d74038ee751a9df547f3ace0b10 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Wed, 6 Nov 2024 17:40:52 -0800 Subject: [PATCH 68/81] relu and resnet works in OpenVINO plugin --- samples/c_test/test.cpp | 2 +- samples/openvino/CMakeLists.txt | 1 + samples/openvino/openvino_execution_provider.cc | 1 + samples/openvino/openvino_utils.cc | 3 ++- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index d0bf543f72b8f..22704953101bf 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -84,7 +84,7 @@ void RunResnet18v1_7(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) // https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet18-v1-7.tar.gz OrtSession* session = nullptr; #ifdef _WIN32 - THROW_ON_ERROR(g_ort->CreateSession(p_env, L"/home/leca/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session)); + THROW_ON_ERROR(g_ort->CreateSession(p_env, L"C:/share/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session)); #else THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/resnet18-v1-7/resnet18-v1-7.onnx", so, &session)); #endif diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt index 5ec99e56f6a36..c5c7225ca127a 100644 --- a/samples/openvino/CMakeLists.txt +++ b/samples/openvino/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime) add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) +add_definitions(-DOPENVINO_CONFIG_CPU) file(GLOB openvino_src "./*.cc" "./ov_versions/*.cc" "./backends/*.cc" "../utils/status.cc") add_library(OpenVINOEp SHARED ${openvino_src}) target_include_directories(OpenVINOEp PUBLIC "../../include/onnxruntime" diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc index 38b4e5864bafc..d7d5331037742 100644 --- a/samples/openvino/openvino_execution_provider.cc +++ b/samples/openvino/openvino_execution_provider.cc @@ -120,6 +120,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const char* ep_type, const auto function_state = static_cast(state); try { function_state->backend_manager->Compute(context); + std::cout << "In OpenVinoEP::Compile()'s ComputeFunc\n"; } catch (const std::exception& ex) { return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, ex.what()); } diff --git a/samples/openvino/openvino_utils.cc b/samples/openvino/openvino_utils.cc index b3d4a37e4ca7e..ffcccb5104fba 100644 --- a/samples/openvino/openvino_utils.cc +++ 
b/samples/openvino/openvino_utils.cc @@ -22,8 +22,9 @@ namespace onnxruntime { } #else char* val = getenv(var_name.c_str()); - return val == NULL ? std::string() : std::string(val); + return val == nullptr ? std::string() : std::string(val); #endif + return std::string(); } OrtStatus* ForEachNodeDef(const OrtGraphApi* graph_api, const OrtGraphViewer* graph, const OrtNode* node, From e337d8f1371b7084add5300a37c9c899b4e3892a Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:44:17 -0800 Subject: [PATCH 69/81] Add OrtGraphApis::OrtNode_GetAttributeStrWithSize to handle the case where an attribute might contain a null character (#22769) When running an EP Context model, the EP might call `OrtGraphApis::OrtNode_GetAttributeStr` to get the string-based content of the attribute. However, the API returns the c_str() of the string, and the cache context may contain a null character, so the string can be cut off and the caller ends up with the wrong string. Add a new OrtGraphApis::OrtNode_GetAttributeStrWithSize that returns the const char* pointer together with the string size. --- .../core/session/onnxruntime_c_api_ep.h | 21 +++++++++++++++++++ .../core/session/onnxruntime_c_api_ep.cc | 16 ++++++++++++++ onnxruntime/core/session/ort_apis_ep.h | 4 ++++ samples/tensorRTEp/onnx_ctx_model_helper.cc | 5 +++-- 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index df6ad6ecb85f9..43cbba444ad1d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -556,6 +556,17 @@ ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* k */ ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out); +/** \brief Gets the i-th string in the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[in] i The index of the string + * \param[out] out The i-th string in the attribute + * \param[out] size The length of the string + * + */ +ORT_API2_STATUS(OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size); + /** \brief Gets the string value of the attribute with the given key. * * \param[in] node The node to query @@ -565,6 +576,16 @@ ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key */ ORT_API2_STATUS(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out); +/** \brief Gets the string value of the attribute with the given key. + * + * \param[in] node The node to query + * \param[in] key The attribute key + * \param[out] out The string value of the attribute + * \param[out] size The length of the string + * + */ +ORT_API2_STATUS(OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size); + /** \brief Gets the int value of the attribute with the given key.
* * \param[in] node The node to query diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 30740b9773b83..e834072cc36eb 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -752,12 +752,26 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthStr, const OrtNode* nod return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size) { + const ::onnxruntime::Node* n = reinterpret_cast(node); + *size = n->GetAttributes().at(key).strings(i).size(); + *out = n->GetAttributes().at(key).strings(i).c_str(); + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); *out = n->GetAttributes().at(key).s().c_str(); return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size) { + const ::onnxruntime::Node* n = reinterpret_cast(node); + *size = n->GetAttributes().at(key).s().size(); + *out = n->GetAttributes().at(key).s().c_str(); + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out) { const ::onnxruntime::Node* n = reinterpret_cast(node); *out = n->GetAttributes().at(key).i(); @@ -841,7 +855,9 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtNode_GetAttributeIthInt, &OrtGraphApis::OrtNode_GetAttributeIthFloat, &OrtGraphApis::OrtNode_GetAttributeIthStr, + &OrtGraphApis::OrtNode_GetAttributeIthStrWithSize, &OrtGraphApis::OrtNode_GetAttributeStr, + &OrtGraphApis::OrtNode_GetAttributeStrWithSize, &OrtGraphApis::OrtNode_GetAttributeInt, &OrtGraphApis::OrtNode_GetAttributeFloat, &OrtGraphApis::OrtNode_GetSubgraphs, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index 7e010e8f8a2c4..a448807573373 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -105,8 +105,12 @@ ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const cha ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size); + ORT_API_STATUS_IMPL(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out); +ORT_API_STATUS_IMPL(OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size); + ORT_API_STATUS_IMPL(OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out); ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out); diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index 426d2484ef98d..65311f56f3f91 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -127,8 +127,9 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView if (embed_mode) { // Get engine from byte stream. 
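+      // NOTE: the cached engine is raw binary data that can contain embedded null bytes, so the attribute must be read as a (pointer, size) pair rather than through c_str() alone.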
const char* context_binary_cstr = nullptr; - graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr); - std::string context_binary(context_binary_cstr); + size_t size; + graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &szie); + std::string context_binary(context_binary_cstr, size); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; From afe92e16d63b6aa4869541983d64b0e19d2d32b9 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:44:36 -0800 Subject: [PATCH 70/81] Make EP plugin able to create and update EP Context graph (#22740) This PR supports several features: - Add new graph API to create and update EP Context graph, and dump EP Context model. 1. OrtGraph_CreateOrUpdateEpCtxGraph 2. OrtGraph_DumpOnnxModel 3. OrtGraph_ReleaseGraph - Add new graph API to dump the ONNX model - The APIs provided by this PR can dump the EP Context model when the whole model can be run by one EP; the APIs also aim to support the case where the whole model is partitioned into multiple EP's subgraphs. (Note: I haven't fully tested the partitioning case, please help review it) - Modify TRT EP plugin to use those APIs. --- .../core/session/onnxruntime_c_api_ep.h | 61 +++++- .../core/session/onnxruntime_c_api_ep.cc | 194 +++++++++++++++++- onnxruntime/core/session/ort_apis_ep.h | 19 +- .../tensorRTEp/tensorrt_execution_provider.cc | 118 ++++++----- .../tensorRTEp/tensorrt_execution_provider.h | 9 + 5 files changed, 346 insertions(+), 55 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 43cbba444ad1d..f20607a893d20 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -333,6 +333,48 @@ ORT_API2_STATUS(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); * */ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); // TODO(leca): review and discuss + +/** \brief Serialize the graph(model) to disk. + * + * \param[in] graph The graph to be serialized + * \param[in] onnx_model_path The file path to save to + * + */ +ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path); + +/** \brief Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise: + * 1. if the given node name can't be found in the graph, add a new "EP Context" node to the existing graph + * 2. if a node with the given name is found, update the node attributes only + * + * Please see https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for more details about EP Context design + * + * \param[in] graph The graph to create or add + * \param[in] node_name The node to be added or updated + * \param[in] main_context The attribute of EP Context op + * \param[in] embed_mode The attribute of EP Context op + * \param[in] cache_path The cache or binary file path. It's for setting the ep_cache_context attribute if embed_mode is 0 + * \param[in] cache_data The cache or binary data.
It's for setting the ep_cache_context attribute if embed_mode is 1 + * \param[in] size The size of cache data. + * \param[in] extra_attr_keys The other attribute names + * \param[in] extra_attr_values The other attribute values, as strings + * \param[in] extra_attr_num Number of other attributes + * \param[out] ep_context_graph The constructed or updated ep context graph + * + * \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraph. + * + */ +ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph, + const OrtGraphViewer* graph, + const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num, + _Outptr_ OrtGraph** ep_context_graph); /** \brief Construct a subgraph from the Graph with the given node indices. * @@ -345,17 +387,28 @@ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ vo * */ ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss + +/** \brief Release the graph instance. + * + * NOTE!!: Invoke this function after the use of OrtGraph_CreateOrUpdateEpCtxGraph. As OrtGraph_CreateOrUpdateEpCtxGraph allocates a model instead of + * a graph, this API releases the graph's owning_model explicitly which in turn will release the graph + * (because the graph is hosted in a unique_ptr in the Model class) + * + * \param[in] graph The graph to release + * + */ +ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph); -/** \brief Release the graph. +/** \brief Release the graph viewer instance. * - * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocate model instead of - * graph, this API release graph's owning_model explicitly which in turn will release the graph + * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocates a model instead of + * a graph, this API releases the graph's owning_model explicitly which in turn will release the graph * (because the graph is hosted in a unique_ptr in the Model class) * * \param[in] graph The graph to release * */ -ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); +ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); /** \brief Gets the name of the node * diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index e834072cc36eb..af884ef353d22 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -9,6 +9,9 @@ #include "core/framework/tensorprotoutils.h" #include "core/session/ort_apis.h" +#include +#include + using namespace onnxruntime; ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetName, const OrtGraphViewer* graph, _Out_ const char** out) { @@ -477,6 +480,184 @@ static void SetAllGraphInputs(Graph& graph, std::unordered_map(graph); + auto model = &(internal_graph->GetModel()); + + // Two options to generate model proto: + // 1. directly call model->ToProto() + // 2.
new model ---> model->ToProto ---> update graph proto in model proto with GraphViewerToProto() + // + // TODO: (Chi) Need more thinking on which to choose + + // option 1 + std::unique_ptr model_proto = std::make_unique(model->ToProto()); + + // option 2 + //auto model_proto = model->ToProto(); + //graph->ToProto(*model_proto->mutable_graph(), true, true); + //model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + std::fstream dump(onnx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); + model_proto->SerializeToOstream(&dump); + //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path; + return nullptr; +} + +/* Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise: + * 1. if the given node name can't be found in the graph, add a new "EP Context" node to the existing graph + * 2. if the node already exists, update the node attributes only + */ +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, + const OrtGraphViewer* graph, + const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num, + _Outptr_ OrtGraph** ep_context_graph) { + + const std::string EPCONTEXT_OP = "EPContext"; + const std::string MAIN_CONTEXT = "main_context"; + const std::string EMBED_MODE = "embed_mode"; + const std::string EP_CACHE_CONTEXT = "ep_cache_context"; + const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; + const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; + const std::string EPCONTEXT_WARNING = + "It's suggested to set the ORT graph optimization level to 0 and \ + make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ + for the best model loading time"; + + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + ::onnxruntime::Graph* graph_build; + + if (!graph_viewer && !(*ep_context_graph)) return nullptr; + + std::unordered_map attr_keys_values; + for (size_t i = 0; i < extra_attr_num; i++) { + attr_keys_values[extra_attr_keys[i]] = extra_attr_values[i]; + } + + // Create a new graph or use the existing one + if (*ep_context_graph == nullptr) { + Model* model_build = new Model (graph_viewer->Name(), true, ModelMetaData(), PathString(), +#if !defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), +#else + IOnnxRuntimeOpSchemaRegistryList(), graph_viewer->DomainToVersionMap(), +#endif // ORT_MINIMAL_BUILD + std::vector(), graph_viewer->GetGraph().GetLogger()); + graph_build = &(model_build->MainGraph()); + *ep_context_graph = reinterpret_cast(graph_build); + } else { + graph_build = reinterpret_cast<::onnxruntime::Graph*>(*ep_context_graph); + } + + // Get graph inputs and outputs + std::vector inputs, outputs; + if (graph_viewer) { + for (auto input : graph_viewer->GetInputs()) { + auto& n_input = graph_build->GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + } + + for (auto output : graph_viewer->GetOutputs()) { + auto& n_output = graph_build->GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + } + } + + // locate specific node if any + auto get_node_index = [&](Graph* graph, const char* node_name) -> size_t { + std::string name = node_name; + for (auto& node : graph->Nodes()) { + if (name == node.Name()) { + return
node.Index(); + } + } + // return an impossible value to indicate that the node does not exist + return std::numeric_limits::max(); + }; + size_t node_idx = get_node_index(graph_build, node_name); + bool node_existed = node_idx != std::numeric_limits::max(); + + // Create or get EP context node attributes + auto new_node_attributes = NodeAttributes(); // using NodeAttributes = std::unordered_map + NodeAttributes* node_attributes; + if (node_existed) { + node_attributes = &graph_build->GetNode(node_idx)->GetMutableAttributes(); + } else { + new_node_attributes.reserve(3 + extra_attr_num); + node_attributes = &new_node_attributes; + } + std::unique_ptr attr_0 = std::make_unique(); // main_context + std::unique_ptr attr_1 = std::make_unique(); // embed_mode + std::unique_ptr attr_2 = std::make_unique(); // ep_cache_context + + std::string cache_data_str = ""; + std::string cache_path_str = cache_path; + + // main_context + attr_0->set_name(MAIN_CONTEXT); + attr_0->set_type(onnx::AttributeProto_AttributeType_INT); + attr_0->set_i(main_context); + + // embed_mode + attr_1->set_name(EMBED_MODE); + attr_1->set_type(onnx::AttributeProto_AttributeType_INT); + attr_1->set_i(embed_mode); + + // ep_cache_context + attr_2->set_name(EP_CACHE_CONTEXT); + attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); + if (embed_mode) { + if (size > 0) { + cache_data_str.assign(cache_data, size); + } + attr_2->set_s(cache_data_str); + //LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + } else { + attr_2->set_s(cache_path_str); + } + + (*node_attributes)[MAIN_CONTEXT] = *attr_0; + (*node_attributes)[EMBED_MODE] = *attr_1; + (*node_attributes)[EP_CACHE_CONTEXT] = *attr_2; + + // other attributes + std::unordered_map::iterator it; + for (it = attr_keys_values.begin(); it != attr_keys_values.end(); ++it) { + std::string key = it->first; + std::string value = it->second; + if (key == ONNX_MODEL_FILENAME) value = std::filesystem::path(value).filename().string(); + + std::unique_ptr attr = std::make_unique(); + attr->set_name(key); + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s(value); + (*node_attributes)[key] = *attr; + } + + if (!node_existed && graph_viewer) { + std::string name = node_name; + graph_build->AddNode(name, EPCONTEXT_OP, "", inputs, outputs, node_attributes, EPCONTEXT_OP_DOMAIN); + } + + common::Status status = graph_build->Resolve(); + if (status != Status::OK()) return onnxruntime::ToOrtStatus(status); + + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); // Get parent graph output names @@ -595,7 +776,15 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_graph) { + if (ort_graph) { + const ::onnxruntime::Graph* graph = reinterpret_cast(ort_graph); + delete &(graph->GetModel()); + } + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); delete &(graph_viewer->GetGraph()).GetModel(); @@ -830,8 +1019,11 @@ static constexpr OrtGraphApi ort_graph_api = {
&OrtGraphApis::OrtGraph_GetValueInfo, &OrtGraphApis::OrtGraph_ReleaseValueInfo, &OrtGraphApis::OrtGraph_SerializeToArray, + &OrtGraphApis::OrtGraph_DumpOnnxModel, + &OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, &OrtGraphApis::OrtGraph_GetSubGraph, &OrtGraphApis::OrtGraph_ReleaseGraph, + &OrtGraphApis::OrtGraph_ReleaseGraphViewer, &OrtGraphApis::OrtNode_GetName, &OrtGraphApis::OrtNode_GetDescription, &OrtGraphApis::OrtNode_GetDomain, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index a448807573373..945539e2db07b 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -55,9 +55,26 @@ ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); +ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path); + +ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph, + const OrtGraphViewer* graph, + const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num, + _Outptr_ OrtGraph** ep_context_graph); + ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); -ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph); + +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 108e915d0856a..ae44bb761a301 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -2097,6 +2097,19 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const auto lock = GetApiLock(); runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); } + + // EP Context setting + if (dump_ep_context_model_) { + extra_attr_keys_.push_back(k_ep_ctx_hardware_architecture.c_str()); + extra_attr_keys_.push_back(k_ep_ctx_onnx_model_filename.c_str()); + + if (engine_cache_enable_ && engine_hw_compatible_) { + extra_attr_values_.push_back(k_cc_hw_compatible.c_str()); + } else { + extra_attr_values_.push_back(compute_capability_.c_str()); + } + extra_attr_values_.push_back(model_path_); + } } TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory() { @@ -2518,10 +2531,36 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort weight_stripped_engine_refit_ = true; } - // Generate file name for dumping ep context model - if (dump_ep_context_model_ && ctx_model_path_.empty()) { - ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); - } + auto create_ep_context_model = [this] (const OrtGraphViewer* graph_body_viewer, + std::string& engine_cache_path, + std::string& engine_cache_relative_path_to_context_model_dir, + const char* ep_context_node_name, + char* serialized_engine, + size_t serialized_engine_size) { + // if ep context model name is not given, create a model name based on original model name + if 
(ctx_model_path_.empty()) { + ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); + } + + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + + graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(graph_body_viewer, + ep_context_node_name, + 1, // main_context + ep_context_embed_mode_, + ep_cache_context_attr_.c_str(), + serialized_engine, + serialized_engine_size, + extra_attr_keys_.data(), + extra_attr_values_.data(), + extra_attr_keys_.size(), + &ep_ctx_graph_); + + }; if (!has_dynamic_shape) { std::string timing_cache_path = ""; @@ -2655,26 +2694,12 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort //LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } } - // dump EP context node model + + // create and dump ep context model if (dump_ep_context_model_) { - // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); - } - std::string compute_capability_hw_compat = compute_capability_; - if (engine_cache_enable_ && engine_hw_compatible_) { - compute_capability_hw_compat = "80+"; - } -// std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, -// ep_cache_context_attr_, -// reinterpret_cast(serialized_engine->data()), -// serialized_engine->size(), -// ep_context_embed_mode_, -// compute_capability_hw_compat, -// model_path_, -// GetLogger())}; -// DumpCtxModel(model_proto.get(), ctx_model_path_); + create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); } } } @@ -2753,30 +2778,14 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort input_shape_ranges_[node_name] = input_implicit_shape_ranges; profiles_.emplace(node_name, std::move(trt_profiles)); - // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty engine. - // TRT EP will serialize the model at inference time due to engine can be updated and the updated engine should be included in the model. - // However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here. - if (dump_ep_context_model_ && has_dynamic_shape) { - // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + // Create ep context model if the model has dynamic shape, + // dump the model if embed_mode is 0, otherwise update and dump the model at runtime.
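+      // (With embed_mode 1 the serialized engine only becomes available once the engine is built at inference time, so in that case the EP Context node is updated and the model is dumped from ComputeFunc.)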
+ if (has_dynamic_shape && dump_ep_context_model_) { + create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, nullptr, 0); + if (ep_context_embed_mode_ == 0) { + graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); } - std::string compute_capability_hw_compat = compute_capability_; - if (engine_cache_enable_ && engine_hw_compatible_) { - compute_capability_hw_compat = "80+"; - } -// model_proto_.reset(CreateCtxModel(graph_body_viewer, -// ep_cache_context_attr_, -// nullptr, -// 0, -// ep_context_embed_mode_, -// compute_capability_hw_compat, -// model_path_, -// GetLogger())); -// if (ep_context_embed_mode_ == 0) { -// DumpCtxModel(model_proto_.get(), ctx_model_path_); -// } } // Create function state @@ -3136,8 +3145,19 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // dump ep context model if (this_->dump_ep_context_model_ && this_->ep_context_embed_mode_) { - //UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); // TODO(leca) - //DumpCtxModel(model_proto_.get(), ctx_model_path_); + graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(nullptr, + fused_node_name.c_str(), + 1, // main_context + this_->ep_context_embed_mode_, + this_->ep_cache_context_attr_.c_str(), + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + this_->extra_attr_keys_.data(), + this_->extra_attr_values_.data(), + this_->extra_attr_keys_.size(), + &this_->ep_ctx_graph_); + graph_api_->OrtGraph_DumpOnnxModel(this_->ep_ctx_graph_, this_->ctx_model_path_.c_str()); + graph_api_->OrtGraph_ReleaseGraph(this_->ep_ctx_graph_); } context_update = true; @@ -3733,7 +3753,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } nodes_list_output.push_back(next_nodes_list[i]); } - graph_api_->OrtGraph_ReleaseGraph(sub_graph_viewer); + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer); } } } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 9ba05e951615b..65f049657b88f 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -214,6 +214,10 @@ struct TensorrtShortFuncState { using DDSOutputAllocatorMap = std::unordered_map>; std::string GetWeightRefittedEnginePath(std::string engine_cache_path); +static const std::string k_cc_hw_compatible = "80+"; +static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture"; +static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; + struct TensorrtExecutionProvider : public OrtExecutionProvider { TensorrtExecutionProvider(const char* ep_type, const ProviderOptions& provider_options); bool IsGraphCaptured(int graph_annotation_id) const { return false; } @@ -309,6 +313,11 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::string ctx_model_path_; std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; + + OrtGraph* ep_ctx_graph_ = nullptr; + std::vector extra_attr_keys_; + std::vector extra_attr_values_; + // std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; From 63f8774ddc4266c6f2a4899f9ab26b7dc713884f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 14 Nov 2024 18:47:17 +0000 
Subject: [PATCH 71/81] [TensorRT EP Plugin] use new graph api for ep context model generation --- samples/tensorRTEp/onnx_ctx_model_helper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index 65311f56f3f91..56b52395ad04f 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -128,7 +128,7 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView // Get engine from byte stream. const char* context_binary_cstr = nullptr; size_t size; - graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &szie); + graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size); std::string context_binary(context_binary_cstr, size); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); From bf359a1ce2327bc0c693c646e658ad71f0f768b1 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Sat, 16 Nov 2024 00:24:14 +0000 Subject: [PATCH 72/81] use cuda's preferred allocator for plugin trt and builtin cuda combined case --- onnxruntime/core/framework/session_state.cc | 21 ++++++++++++------- onnxruntime/core/framework/session_state.h | 7 +++++++ .../test/perftest/command_args_parser.cc | 5 ++++- onnxruntime/test/perftest/ort_test_session.cc | 11 ++++++++-- onnxruntime/test/perftest/ort_test_session.h | 1 + .../test/perftest/test_configuration.h | 1 + samples/c_test/test.cpp | 4 ++-- samples/tensorRTEp/tensorrt_cuda_allocator.cc | 2 +- .../tensorRTEp/tensorrt_execution_provider.cc | 8 +++---- 9 files changed, 43 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index d76c06079c44b..03431e378a6e4 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -105,11 +105,18 @@ SessionState::SessionState(Graph& graph, allocators_ = allocators_unique_ptr_.get(); // The allocator registration rule: // Each location (OrtDevice) will only have 1 allocator used for whole session. 
- // The EP which is registered first will have higher priority + // For plugin EP, let the builtin EP with the same OrtDevice value register first (think about the plugin TRT and builtin CUDA scenario, it may be changed back in the future) + // Once the allocator has been registered, it won't be overwritten by the allocator with the same key (OrtDevice) + std::string register_resource_after = ""; + IExecutionProvider* plugin_ep = nullptr; for (auto& ep : execution_providers_) { - auto allocators = ep->CreatePreferredAllocators(); - for (auto& alloc : allocators) { - allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key + if (register_resource_after == "") { + register_resource_after = ShouldPostPoneRegisterResourceFor(ep.get(), execution_providers_); + if (register_resource_after == "") InitializeAllocators(ep.get()); + else plugin_ep = ep.get(); + } else { + InitializeAllocators(ep.get()); + if (register_resource_after == ep->Type()) InitializeAllocators(plugin_ep); } } } @@ -1379,15 +1386,15 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringRegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_); - else out_tree_ep = ep.get(); + else plugin_ep = ep.get(); } else { ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_); - if (register_resource_after == ep->Type()) out_tree_ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_); + if (register_resource_after == ep->Type()) plugin_ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_); } } #endif diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index b1a7504b283c5..71fd3713da0e6 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -394,6 +394,13 @@ class SessionState { const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); + inline void InitializeAllocators(IExecutionProvider* ep) { + auto allocators = ep->CreatePreferredAllocators(); + for (auto& alloc : allocators) { + allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key + } + } + #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( gsl::span inputs, diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b7c99fa66a1ea..02fed26cfe463 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -205,7 +205,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzng"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -390,6 +390,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'n': test_config.run_config.exit_after_session_creation = true; break; + case 'g': + test_config.plugin = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ff782da35cbe6..772d0d462fa13 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -49,6 
+49,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device Ort::SessionOptions session_options; provider_name_ = performance_test_config.machine_config.provider_type_name; + plugin_ = performance_test_config.plugin; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options @@ -170,6 +171,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT const auto& api = Ort::GetApi(); + if (plugin_) { + Ort::ThrowOnError(api.RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + std::vector keys{"trt_engine_cache_enable", "trt_dump_ep_context_model", "trt_ep_context_embed_mode"}, values{"0", "0", "0"}; + Ort::ThrowOnError(api.SessionOptionsAppendPluginExecutionProvider(session_options, "tensorrtEp", env, keys.data(), values.data(), keys.size())); + } else { OrtTensorRTProviderOptionsV2* tensorrt_options; Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); std::unique_ptr rel_trt_options( @@ -213,9 +219,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); - + } OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = tensorrt_options->device_id; +// cuda_options.device_id = tensorrt_options->device_id; + cuda_options.device_id = 0; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index f1a4220ab325e..51fa154a14e1f 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -46,6 +46,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_str_; const int input_length_; std::string provider_name_; + bool plugin_; }; } // namespace perftest diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 70a6b12690d5d..8edb775b7ab32 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -70,6 +70,7 @@ struct PerformanceTestConfig { ModelInfo model_info; MachineConfig machine_config; RunConfig run_config; + bool plugin = false; }; } // namespace perftest diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index 22704953101bf..8f65b70d9551b 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -37,7 +37,7 @@ void TestTensorRTEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* so) { #else THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); #endif - std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; + std::vector keys{"device_id"}, values{"0"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); } @@ -47,7 +47,7 @@ void TestTensorRTAndCudaEp(const OrtApi* g_ort, OrtEnv* env, OrtSessionOptions* #else 
THROW_ON_ERROR(g_ort->RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); #endif - std::vector keys{"device_id", "str_property"}, values{"0", "strvalue"}; + std::vector keys{"device_id"}, values{"0"}; THROW_ON_ERROR(g_ort->SessionOptionsAppendPluginExecutionProvider(so, "tensorrtEp", env, keys.data(), values.data(), keys.size())); OrtCUDAProviderOptionsV2* cuda_options = nullptr; diff --git a/samples/tensorRTEp/tensorrt_cuda_allocator.cc b/samples/tensorRTEp/tensorrt_cuda_allocator.cc index 044bba043f5a4..89e62dae3f296 100644 --- a/samples/tensorRTEp/tensorrt_cuda_allocator.cc +++ b/samples/tensorRTEp/tensorrt_cuda_allocator.cc @@ -20,7 +20,7 @@ void CUDAAllocator::CheckDevice(bool throw_when_fail) const { CUDA_RETURN_IF_ERROR(cuda_err); } #else - ORT_UNUSED_PARAMETER(throw_when_fail); +// ORT_UNUSED_PARAMETER(throw_when_fail); #endif } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index ae44bb761a301..dcff52661fb7c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1676,7 +1676,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const }; info_ = TensorrtExecutionProviderInfo::FromProviderOptions(ep_info); - if (ep_info.size() > 0) info_.has_trt_options = true; + if (ep_info.size() > 0) info_.has_trt_options = true; device_id_ = info_.device_id; api_->CreateDevice(OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, OrtMemoryType::OrtMemoryType_Default, device_id_, &default_device); @@ -2535,7 +2535,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort std::string& engine_cache_path, std::string& engine_cache_relative_path_to_context_model_dir, const char* ep_context_node_name, - char* serialized_engine, + char* serialized_engine, size_t serialized_engine_size) { // if ep context model name is not given, create a model name based on original model name if (ctx_model_path_.empty()) { @@ -3146,7 +3146,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // dump ep context model if (this_->dump_ep_context_model_ && this_->ep_context_embed_mode_) { graph_api_->OrtGraph_CreateOrUpdateEpCtxGraph(nullptr, - fused_node_name.c_str(), + fused_node_name.c_str(), 1, // main_context this_->ep_context_embed_mode_, this_->ep_cache_context_attr_.c_str(), @@ -3365,7 +3365,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort // IncrementRegularRunCountBeforeGraphCapture(); // } // } - std::cout << "end of ComputeFunc in TRTEp's CreateNodeComputeInfoFromGraph()\n"; +// std::cout << "end of ComputeFunc in TRTEp's CreateNodeComputeInfoFromGraph()\n"; return nullptr; }; From c267ea52491407a4060c078c51f77e77462a1ca8 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:21:03 -0800 Subject: [PATCH 73/81] [TensorRT EP Plugin] Add cuda::Impl_Cast (#22908) TRT 8 doesn't support the INT64 and DOUBLE data types. TRT 10 doesn't support the DOUBLE data type. Therefore, TRT EP internally needs to convert INT64 to INT32, and DOUBLE to FLOAT, which requires the cuda::Impl_Cast function. The implementation is copied from the CUDA EP.
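For illustration, a minimal sketch of how a plugin EP could drive this cast helper from host code (the wrapper function and buffer names are hypothetical; it assumes both buffers already live in device memory and that the kernel is launched on the caller's CUDA stream):

    #include <cstdint>
    #include <cuda_runtime.h>
    #include "cuda/unary_elementwise_ops_impl.h"

    // Convert a device-side INT64 input buffer to INT32 before binding it to a
    // TensorRT engine, since TRT versions before 10 have no INT64 support.
    void CastInt64InputForTrt(cudaStream_t stream, const int64_t* src_dev,
                              int32_t* dst_dev, size_t elem_cnt) {
      // Resolves to Explicit_Impl_Cast(int64_t -> int32_t), which launches the
      // _UnaryElementWise kernel with an OP_Cast functor.
      onnxruntime::cuda::Impl_Cast<int64_t, int32_t>(stream, src_dev, dst_dev, elem_cnt);
    }

The DOUBLE-to-FLOAT direction works the same way via Impl_Cast with double/float arguments; both directions are what the re-enabled CASE_GET_CAST_INPUT_TENSOR and CASE_CAST_TENSOR macros in this patch expand to.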
--- samples/tensorRTEp/CMakeLists.txt | 2 +- .../cuda/cu_inc/unary_elementwise_impl.cuh | 78 ++++++++++++++++ .../cuda/unary_elementwise_ops_impl.cu | 91 +++++++++++++++++++ .../cuda/unary_elementwise_ops_impl.h | 54 +++++++++++ .../tensorRTEp/tensorrt_execution_provider.cc | 84 ++++++++--------- 5 files changed, 267 insertions(+), 42 deletions(-) create mode 100644 samples/tensorRTEp/cuda/cu_inc/unary_elementwise_impl.cuh create mode 100644 samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu create mode 100644 samples/tensorRTEp/cuda/unary_elementwise_ops_impl.h diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index ebf5448b80a93..b0982784e7c19 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -12,7 +12,7 @@ find_package(CUDAToolkit REQUIRED) add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) add_definitions(-DNV_TENSORRT_MAJOR=10) -file(GLOB tensorrt_src "./*.cc" "../utils/status.cc") +file(GLOB tensorrt_src "./*.cc" "../utils/status.cc" "./cuda/unary_elementwise_ops_impl.cu") add_library(TensorRTEp SHARED ${tensorrt_src}) target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime" "../utils" diff --git a/samples/tensorRTEp/cuda/cu_inc/unary_elementwise_impl.cuh b/samples/tensorRTEp/cuda/cu_inc/unary_elementwise_impl.cuh new file mode 100644 index 0000000000000..87cf7c832cd21 --- /dev/null +++ b/samples/tensorRTEp/cuda/cu_inc/unary_elementwise_impl.cuh @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +namespace cuda { + +// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer +// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. 
+#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +template +inline __host__ __device__ INT CeilDiv(INT a, INT2 b) // ceil(a/b) +{ + return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b); // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[]) +} + +struct GridDim { + enum : CUDA_LONG { + maxThreadsPerBlock = 256, // max threads per block + maxElementsPerThread = 4, // max element processed per thread + }; +}; + +template +__global__ void _UnaryElementWise( + const InT* input_data, + OutT* output_data, + const FuncT functor, + CUDA_LONG N) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + InT value[NumElementsPerThread]; + + CUDA_LONG id = start; + #pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + value[i] = input_data[id]; + id += NumThreadsPerBlock; + } + } + + id = start; + #pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output_data[id] = functor(value[i]); + id += NumThreadsPerBlock; + } + } +} + +template +void UnaryElementWiseImpl( + cudaStream_t stream, + const InT* input_data, + OutT* output_data, + const FuncT& func, + size_t count) { + if (count == 0) // special case where there's a dim value of 0 in the shape + return; + + int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + CUDA_LONG N = static_cast(count); + _UnaryElementWise + <<>>( + input_data, + output_data, + func, + N); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu new file mode 100644 index 0000000000000..da3fcfcc73bb4 --- /dev/null +++ b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "cu_inc/unary_elementwise_impl.cuh" + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 +#include "cuda_fp8.h" +#endif +#include + +namespace onnxruntime { + +namespace cuda { + +// the postfix of means the types supported by the op: +// B: uint8_t +// W: uint16_t +// U: uint32_t +// Z: uint64_t +// C: int8_t +// S: int16_t +// I: int32_t +// L: int64_t +// H: float16 +// F: float +// D: double +// O: bool +// X: BFloat16 + +// When casting, half needs to be converted via float type from most other types +template +struct ViaTypeMap { + typedef T ViaT; +}; + +template <> +struct ViaTypeMap { + typedef float ViaT; +}; + +template +struct OP_Cast { + __device__ __inline__ OutT operator()(const InT& a) const { + typedef typename ViaTypeMap::ViaT ViaT; + return (OutT)((ViaT)a); + } +}; + +#define IMPL_CAST_IMPL(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + } + +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + } + +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + //IMPL_CAST_IMPL(T, BFloat16) + +IMPL_CAST_IMPL_FROM(half) +IMPL_CAST_IMPL_FROM(float) +IMPL_CAST_IMPL_FROM(double) +IMPL_CAST_IMPL_FROM(int8_t) +IMPL_CAST_IMPL_FROM(int16_t) +IMPL_CAST_IMPL_FROM(int32_t) +IMPL_CAST_IMPL_FROM(int64_t) +IMPL_CAST_IMPL_FROM(uint8_t) +IMPL_CAST_IMPL_FROM(uint16_t) +IMPL_CAST_IMPL_FROM(uint32_t) +IMPL_CAST_IMPL_FROM(uint64_t) +IMPL_CAST_IMPL_FROM(bool) +//IMPL_CAST_IMPL_FROM(BFloat16) + +} // namespace cuda +} // namespace onnxruntime diff --git a/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.h b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.h new file mode 100644 index 0000000000000..392cf46f42ff5 --- /dev/null +++ b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +// Cast + +#define DECL_IMPL_CAST(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count); + +#define DECL_IMPL_CAST_FROM(T) \ + DECL_IMPL_CAST(T, half) \ + DECL_IMPL_CAST(T, float) \ + DECL_IMPL_CAST(T, double) \ + DECL_IMPL_CAST(T, int8_t) \ + DECL_IMPL_CAST(T, int16_t) \ + DECL_IMPL_CAST(T, int32_t) \ + DECL_IMPL_CAST(T, int64_t) \ + DECL_IMPL_CAST(T, uint8_t) \ + DECL_IMPL_CAST(T, uint16_t) \ + DECL_IMPL_CAST(T, uint32_t) \ + DECL_IMPL_CAST(T, uint64_t) \ + DECL_IMPL_CAST(T, bool) \ + //DECL_IMPL_CAST(T, BFloat16) + +DECL_IMPL_CAST_FROM(half) +DECL_IMPL_CAST_FROM(float) +DECL_IMPL_CAST_FROM(double) +DECL_IMPL_CAST_FROM(int8_t) +DECL_IMPL_CAST_FROM(int16_t) +DECL_IMPL_CAST_FROM(int32_t) +DECL_IMPL_CAST_FROM(int64_t) +DECL_IMPL_CAST_FROM(uint8_t) +DECL_IMPL_CAST_FROM(uint16_t) +DECL_IMPL_CAST_FROM(uint32_t) +DECL_IMPL_CAST_FROM(uint64_t) +DECL_IMPL_CAST_FROM(bool) +//DECL_IMPL_CAST_FROM(BFloat16) + +template +void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { + Explicit_Impl_Cast(stream, input_data, output_data, count); +} + +} // namespace cuda + +} // namespace onnxruntime diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index dcff52661fb7c..5ed548eee41e1 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -10,6 +10,7 @@ #include "tensorrt_cuda_allocator.h" #include "onnx_ctx_model_helper.h" #include "onnx/onnx_pb.h" +#include "cuda/unary_elementwise_ops_impl.h" #ifdef _WIN32 #include @@ -671,19 +672,19 @@ OrtStatusPtr ApplyProfileShapesFromInputTensorValue(std::vector(); \ -// if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ -// scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ -// data = scratch_buffers.back().get(); \ -// cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ -// } else { \ -// scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ -// data = scratch_buffers.back().get(); \ -// } \ -// break; \ -// } +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ @@ -721,14 +722,14 @@ OrtStatusPtr ApplyProfileShapesFromInputTensorValue(std::vector(); \ -// if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ -// cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ -// } \ -// break; \ -// } +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } OrtStatusPtr 
BindContextInput(Ort::KernelContext& ctx, nvinfer1::ICudaEngine* trt_engine, @@ -836,10 +837,10 @@ OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) #else // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64 -// CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) #endif // Cast double input to float because TensorRT doesn't support double -// CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return TensorrtExecutionProvider::api_->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); } @@ -3334,20 +3335,20 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort } } else { auto& output_tensor = output_tensors[i]; -//#if NV_TENSORRT_MAJOR < 10 -// if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { -// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); -// if (output_tensor_ptr != nullptr) { -// cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); -// } -// } -//#endif -// if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { -// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); -// if (output_tensor_ptr != nullptr) { -// cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); -// } -// } +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } } } @@ -3499,6 +3500,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; auto fused_node_name = trt_state->fused_node_name; + std::cout << fused_node_name << std::endl; auto& dds_output_allocator_map = this_->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); @@ -3664,10 +3666,10 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngi } #endif if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { -// auto output_tensor_ptr = output_tensor.GetTensorMutableData(); -// if (output_tensor_ptr != nullptr) { -// cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); -// } + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } } } } From 72afdc41ff562b8a827828bb44dc0779454a82a0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 22 Nov 2024 06:27:44 +0000 Subject: 
[PATCH 74/81] fix build/compiler error for nvcc 11.8

---
 samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu
index da3fcfcc73bb4..ad515a23a42a3 100644
--- a/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu
+++ b/samples/tensorRTEp/cuda/unary_elementwise_ops_impl.cu
@@ -42,7 +42,9 @@ struct ViaTypeMap {
 template <typename InT, typename OutT>
 struct OP_Cast {
   __device__ __inline__ OutT operator()(const InT& a) const {
-    typedef typename ViaTypeMap<OutT>::ViaT ViaT;
+    const bool any_float16 = std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
+    typedef typename std::conditional<any_float16, half, float>::type T;
+    typedef typename ViaTypeMap<T>::ViaT ViaT;
     return (OutT)((ViaT)a);
   }
 };

From 6822206352d0e1e2bd908f36397c1e849211a975 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Tue, 3 Dec 2024 00:40:32 +0000
Subject: [PATCH 75/81] Do not expose OrtGraph

---
 .../core/session/onnxruntime_c_api_ep.h       | 71 ++++++------------
 .../core/session/onnxruntime_c_api_ep.cc      | 74 +++++++------------
 onnxruntime/core/session/ort_apis_ep.h        | 18 ++---
 samples/openvino/ov_versions/capability.cc    |  2 +-
 .../tensorRTEp/tensorrt_execution_provider.cc | 22 ++----
 .../tensorRTEp/tensorrt_execution_provider.h  |  2 +-
 .../tensorrt_execution_provider_utils.h       | 15 ++--
 7 files changed, 71 insertions(+), 133 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
index f20607a893d20..db2d957fc3013 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
@@ -124,24 +124,7 @@ ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer*
  * \param[out] out True if the graph is a subgraph
  *
  */
-ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out);
-
-/** \brief Get the parent graph of the graph
- *
- * \param[in] graph The graph to query
- * \param[out] parent_graph The parent graph of the graph
- *
- */
-ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph);
-
-/** \brief Check if the graph is a subgraph
- * TODO(leca): maybe deprecate OrtGraph_IsSubgraph?
- *
- * \param[in] graph The graph to query
- * \param[out] out True if the graph is a subgraph
- *
- */
-ORT_API2_STATUS(OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out);
+ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* out);

 /** \brief Get the parent node of the graph
  *
@@ -159,14 +142,6 @@ ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ con
  */
 ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path);

-/** \brief Get the internal graph in the graph viewer
- *
- * \param[in] graph_viewer The graph viewer to query
- * \param[out] graph The internal graph in the graph viewer
- *
- */
-ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph);
-
 /** \brief Gets the Graph inputs with no matching initializers, in the same order as defined in the GraphProto.
  *
  * NOTE!!: The caller is responsible for releasing the char array using ReleaseCharArray.
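
The hunk above routes every cast that involves half through an explicitly chosen intermediate type, which is what older nvcc releases need to resolve the conversion. The selection logic can be exercised in isolation; here is a compile-time sketch with half stubbed out so it builds without CUDA headers (a simplified model of the fix, not the patch itself):

    #include <type_traits>

    struct half {};  // stand-in for the CUDA half type

    template <typename InT, typename OutT>
    struct ViaSelector {
      // Same shape as the fix: if either side is half, convert via half,
      // otherwise convert via float.
      static constexpr bool any_float16 =
          std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
      using T = typename std::conditional<any_float16, half, float>::type;
    };

    static_assert(std::is_same<ViaSelector<half, int>::T, half>::value, "half involved");
    static_assert(std::is_same<ViaSelector<int, long>::T, float>::value, "no half");

    int main() { return 0; }
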
@@ -333,39 +308,39 @@ ORT_API2_STATUS(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info);
  *
  */
 ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size);  // TODO(leca): review and discuss
-
+
 /** \brief Serialize the graph(model) to disk.
  *
  * \param[in] graph The graph to be serialized
  * \param[in] onnx_model_path The file path to save to
  *
  */
-ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path);
+ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* onnx_model_path);

-/** \brief Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise:
+/** \brief Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise:
  * 1. if the given node name can't be found in the graph, add a new "EP Context" node to the existing graph
  * 2. if a node with the given node name is found, update the node attributes only
  *
  * Please see https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for more details about EP Context design
  *
  * \param[in] graph The graph to create or add
- * \param[in] node_name The node to be added or updated
- * \param[in] main_context The attribute of EP Context op
- * \param[in] embed_mode The attribute of EP Context op
+ * \param[in] node_name The node to be added or updated
+ * \param[in] main_context The attribute of EP Context op
+ * \param[in] embed_mode The attribute of EP Context op
  * \param[in] cache_path The cache or binary file path. It's for setting the ep_cache_context attribute if embed_mode is 0
  * \param[in] cache_data The cache or binary data. It's for setting the ep_cache_context attribute if embed_mode is 1
  * \param[in] size The size of cache data.
  * \param[in] extra_attr_keys The other attribute names
  * \param[in] extra_attr_values The other attribute value in string
- * \param[in] extra_attr_num Number of other attributes
+ * \param[in] extra_attr_num Number of other attributes
  * \param[out] ep_context_graph The constructed or updated ep context graph
  *
- * \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraph.
+ * \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraphViewer.
  *
  */
 ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph,
                 const OrtGraphViewer* graph,
-                const char* node_name,
+                const char* node_name,
                 const int64_t main_context,
                 const int64_t embed_mode,
                 const char* cache_path,
@@ -374,7 +349,7 @@ ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph,
                 const char* const* extra_attr_keys,
                 const char* const* extra_attr_values,
                 size_t extra_attr_num,
-                _Outptr_ OrtGraph** ep_context_graph);
+                _Outptr_ OrtGraphViewer** ep_context_graph);

 /** \brief Construct a subgraph from the Graph with the given node indices.
  *
  * \param[in] graph The graph
  * \param[in] node_indices The indices of the nodes to include in the subgraph
  * \param[out] subgraph The constructed subgraph
  *
- * \remarks The caller is responsible for releasing the subgraph using OrtGraph_ReleaseGraph.
+ * \remarks The caller is responsible for releasing the subgraph using OrtGraph_ReleaseGraphViewer.
  *
  */
 ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph);  // TODO(yang): review and discuss
-
-/** \brief Release the graph instance.
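
For orientation, here is a hypothetical call site for the API as declared at this point in the series (a later patch in this set reverts the output type to OrtGraph*). Every value below is illustrative, including the node name, file names, and the engine buffer:

    // Sketch: embed a compiled engine into an EPContext node (embed_mode == 1).
    void DumpEpContextModel(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer,
                            char* engine_bytes, size_t engine_size) {
      const char* extra_keys[] = {"onnx_model_filename"};
      const char* extra_values[] = {"model.onnx"};
      OrtGraphViewer* ep_ctx_graph = nullptr;
      graph_api->OrtGraph_CreateOrUpdateEpCtxGraph(graph_viewer,
                                                   "TRTKernel_graph_main_0",  // node to add or update
                                                   /*main_context*/ 1,
                                                   /*embed_mode*/ 1,
                                                   /*cache_path*/ "",
                                                   /*cache_data*/ engine_bytes,
                                                   /*size*/ engine_size,
                                                   extra_keys, extra_values, 1,
                                                   &ep_ctx_graph);
      graph_api->OrtGraph_DumpOnnxModel(ep_ctx_graph, "model_ctx.onnx");
      graph_api->OrtGraph_ReleaseGraphViewer(ep_ctx_graph);
    }
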
- * - * NOTE!!: Invoke this function after the use of OrtGraph_CreateOrUpdateEpCtxGraph. As OrtGraph_CreateOrUpdateEpCtxGraph allocates model instead of - * graph, this API releases graph's owning_model explicitly which in turn will release the graph - * (because graph is hosted in an unique_ptr in Model class) - * - * \param[in] graph The graph to release - * - */ -ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph); /** \brief Release the graph viewer instance. * @@ -410,6 +374,15 @@ ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph); */ ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); +/** \brief Check are two graph actually pointing to the same graph. + * + * \param[in] graph1 The 1st graph + * \param[in] graph2 The 2nd graph + * \param[out] is_same Is graph1 and graph2 pointing to the same graph + * + */ +ORT_API2_STATUS(OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, _Out_ bool* is_same); + /** \brief Gets the name of the node * * \param[in] node The node to query @@ -634,7 +607,7 @@ ORT_API2_STATUS(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _ * \param[in] node The node to query * \param[in] key The attribute key * \param[out] out The string value of the attribute - * \param[out] size The length of the string + * \param[out] size The length of the string * */ ORT_API2_STATUS(OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size); diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index af884ef353d22..4963f97fef03e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -34,19 +34,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, cons return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out) { - const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - *out = graph_ptr->IsSubgraph(); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph) { - const ::onnxruntime::Graph* graph_ptr = reinterpret_cast(graph); - *parent_graph = reinterpret_cast(graph_ptr->ParentGraph()); - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* out) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); *out = graph_viewer->IsSubgraph(); return nullptr; @@ -64,12 +52,6 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetModelPath, const OrtGraphViewer* g return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph) { - const ::onnxruntime::GraphViewer* graph_viewer_ptr = reinterpret_cast(graph_viewer); - *graph = reinterpret_cast(&graph_viewer_ptr->GetGraph()); - return nullptr; -} - ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); const auto& inputs = graph_viewer->GetInputs(); @@ -484,16 +466,17 @@ static void SetAllGraphInputs(Graph& graph, std::unordered_map(graph); + const GraphViewer* 
graph_viewer = reinterpret_cast(graph); + const ::onnxruntime::Graph* internal_graph = &(graph_viewer->GetGraph()); auto model = &(internal_graph->GetModel()); // Two options to generate model proto: // 1. directly call model->ToProto() - // 2. new model ---> model->ToProto ---> update graph proto in model proto with GraphViewerToProto() + // 2. new model ---> model->ToProto ---> update graph proto in model proto with GraphViewerToProto() // - // TODO: (Chi) Need more thinking on which to choose + // TODO: (Chi) Need more thinking on which to choose // option 1 std::unique_ptr model_proto = std::make_unique(model->ToProto()); @@ -509,13 +492,13 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_DumpOnnxModel, return nullptr; } -/* Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise: +/* Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise: * 1. if the given node name can't be found in the graph, add an new "EP Context" node to the existing graph * 2. if the node is already existed, update the node attributes only */ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, const OrtGraphViewer* graph, - const char* node_name, + const char* node_name, const int64_t main_context, const int64_t embed_mode, const char* cache_path, @@ -524,7 +507,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, const char* const* extra_attr_keys, const char* const* extra_attr_values, size_t extra_attr_num, - _Outptr_ OrtGraph** ep_context_graph) { + _Outptr_ OrtGraphViewer** ep_context_graph) { const std::string EPCONTEXT_OP = "EPContext"; const std::string MAIN_CONTEXT = "main_context"; @@ -547,7 +530,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, attr_keys_values[extra_attr_keys[i]] = extra_attr_values[i]; } - // Create a new graph or use the existing one + // Create a new graph or use the existing one if (*ep_context_graph == nullptr) { Model* model_build = new Model (graph_viewer->Name(), true, ModelMetaData(), PathString(), #if !defined(ORT_MINIMAL_BUILD) @@ -557,9 +540,11 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, #endif // ORT_MINIMAL_BUILD std::vector(), graph_viewer->GetGraph().GetLogger()); graph_build = &(model_build->MainGraph()); - *ep_context_graph = reinterpret_cast(graph_build); + auto graph_build_viewer = std::make_unique(*graph_build); + *ep_context_graph = reinterpret_cast(graph_build_viewer.release()); } else { - graph_build = reinterpret_cast<::onnxruntime::Graph*>(*ep_context_graph); + ::onnxruntime::GraphViewer* content_graph_viewer = reinterpret_cast<::onnxruntime::GraphViewer*>(*ep_context_graph); + graph_build = const_cast<::onnxruntime::Graph*>(&(content_graph_viewer->GetGraph())); } // Get graph inputs and outputs @@ -593,15 +578,15 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, // Create or get EP context node attributes auto new_node_attributes = NodeAttributes(); // using NodeAttributes = std::unordered_map NodeAttributes* node_attributes; - if (node_existed) { + if (node_existed) { node_attributes = &graph_build->GetNode(node_idx)->GetMutableAttributes(); } else { new_node_attributes.reserve(3 + extra_attr_num); node_attributes = &new_node_attributes; } std::unique_ptr attr_0 = std::make_unique(); // main_context - std::unique_ptr attr_1 = std::make_unique(); // embed_mode - std::unique_ptr attr_2 = std::make_unique(); // ep_cache_context + std::unique_ptr attr_1 = std::make_unique(); // embed_mode + 
std::unique_ptr attr_2 = std::make_unique(); // ep_cache_context std::string cache_data_str = ""; std::string cache_path_str = cache_path; @@ -611,7 +596,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, attr_0->set_type(onnx::AttributeProto_AttributeType_INT); attr_0->set_i(main_context); - // embed_mode + // embed_mode attr_1->set_name(EMBED_MODE); attr_1->set_type(onnx::AttributeProto_AttributeType_INT); attr_1->set_i(embed_mode); @@ -638,7 +623,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, for (it = attr_keys_values.begin(); it != attr_keys_values.end(); ++it) { std::string key = it->first; std::string value = it->second; - if (key == ONNX_MODEL_FILENAME) value = std::filesystem::path(value).filename().string(); + if (key == ONNX_MODEL_FILENAME) value = std::filesystem::path(value).filename().string(); std::unique_ptr attr = std::make_unique(); attr->set_name(key); @@ -776,22 +761,22 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_graph) { - if (ort_graph) { - const ::onnxruntime::Graph* graph = reinterpret_cast(ort_graph); - delete &(graph->GetModel()); - } - return nullptr; -} - ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); delete &(graph_viewer->GetGraph()).GetModel(); + delete graph_viewer; } return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, bool* is_same) { + const ::onnxruntime::GraphViewer* graph_viewer1 = reinterpret_cast(graph1); + const ::onnxruntime::GraphViewer* graph_viewer2 = reinterpret_cast(graph2); + *is_same = (&(graph_viewer1->GetGraph()) == &(graph_viewer2->GetGraph())); + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out) { const ::onnxruntime::Node* n = reinterpret_cast(node); *out = n->Name().c_str(); @@ -997,11 +982,8 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_IsConstantInitializer, &OrtGraphApis::OrtGraph_GetNodesIndexInTopologicalOrder, &OrtGraphApis::OrtGraph_IsSubgraph, - &OrtGraphApis::OrtGraph_GetParentGraph, - &OrtGraphApis::OrtGraph_IsSubgraph2, &OrtGraphApis::OrtGraph_GetParenNode, &OrtGraphApis::OrtGraph_GetModelPath, - &OrtGraphApis::OrtGraph_GetOrtGraph, &OrtGraphApis::OrtGraph_GetRequiredInputs, &OrtGraphApis::OrtGraph_GetAllInputs, &OrtGraphApis::OrtGraph_GetAllInitializers, @@ -1022,8 +1004,8 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_DumpOnnxModel, &OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, &OrtGraphApis::OrtGraph_GetSubGraph, - &OrtGraphApis::OrtGraph_ReleaseGraph, &OrtGraphApis::OrtGraph_ReleaseGraphViewer, + &OrtGraphApis::OrtGraph_IsSameGraph, &OrtGraphApis::OrtNode_GetName, &OrtGraphApis::OrtNode_GetDescription, &OrtGraphApis::OrtNode_GetDomain, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index 945539e2db07b..dbedeaed3cbd0 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -9,18 +9,12 @@ ORT_API_STATUS_IMPL(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, ORT_API_STATUS_IMPL(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** 
nodes_index_in_topological_order, _Out_ size_t* num_nodes); -ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* out); - -ORT_API_STATUS_IMPL(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); - -ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph2, const OrtGraphViewer* graph, _Out_ bool* out); +ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* out); ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** model_path); -ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); - ORT_API_STATUS_IMPL(OrtGraph_GetRequiredInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); ORT_API_STATUS_IMPL(OrtGraph_GetAllInputs, const OrtGraphViewer* graph, _Outptr_ const char*** input_names, _Out_ size_t* input_len); @@ -55,11 +49,11 @@ ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); -ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path); +ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* onnx_model_path); ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph, const OrtGraphViewer* graph, - const char* node_name, + const char* node_name, const int64_t main_context, const int64_t embed_mode, const char* cache_path, @@ -68,14 +62,14 @@ ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph, const char* const* extra_attr_keys, const char* const* extra_attr_values, size_t extra_attr_num, - _Outptr_ OrtGraph** ep_context_graph); + _Outptr_ OrtGraphViewer** ep_context_graph); ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); -ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph); - ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, _Out_ bool* is_same); + ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out); ORT_API_STATUS_IMPL(OrtNode_GetDescription, const OrtNode* node, _Outptr_ const char** out); diff --git a/samples/openvino/ov_versions/capability.cc b/samples/openvino/ov_versions/capability.cc index b102e932b9f04..f0bfbdf18b5d9 100644 --- a/samples/openvino/ov_versions/capability.cc +++ b/samples/openvino/ov_versions/capability.cc @@ -54,7 +54,7 @@ GetCapability::GetCapability(const OrtGraphViewer* graph_viewer_param, size_t GetCapability::Execute(OrtIndexedSubGraph*** indexed_sub_graph) { // Check if it is a subgraph bool is_subgraph = false; - graph_api_->OrtGraph_IsSubgraph2(graph_viewer_, &is_subgraph); + graph_api_->OrtGraph_IsSubgraph(graph_viewer_, &is_subgraph); const char* graph_name = nullptr; graph_api_->OrtGraph_GetName(graph_viewer_, &graph_name); if (is_subgraph && !strcmp(graph_name, "tf2onnx")) return 0; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 5ed548eee41e1..04f621d5d76e2 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ 
b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1115,10 +1115,8 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Check the graph is the subgraph of control flow op bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer* graph) const { - const OrtGraph* cur_graph = nullptr; - graph_api_->OrtGraph_GetOrtGraph(graph, &cur_graph); bool is_subgraph = false; - graph_api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); if (is_subgraph) { const OrtNode* node = nullptr; graph_api_->OrtGraph_GetParenNode(graph, &node); @@ -1289,10 +1287,8 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr // Generate unique kernel name for TRT subgraph std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index); - const OrtGraph* cur_graph = nullptr; - graph_api_->OrtGraph_GetOrtGraph(graph, &cur_graph); bool is_subgraph = false; - graph_api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); + graph_api_->OrtGraph_IsSubgraph(graph, &is_subgraph); const std::string graph_type = is_subgraph ? "subgraph" : "graph"; const char* graph_name = nullptr; graph_api_->OrtGraph_GetName(graph, &graph_name); @@ -1465,12 +1461,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const const OrtGraphViewer** subgraphs = nullptr; size_t subgraph_count = 0; graph_api_->OrtNode_GetSubgraphs(parent_node, &subgraphs, &subgraph_count); - const OrtGraph* origin_graph = nullptr; - graph_api_->OrtGraph_GetOrtGraph(graph, &origin_graph); for (size_t i = 0; i < subgraph_count; i++) { - const OrtGraph* subgraph = nullptr; - graph_api_->OrtGraph_GetOrtGraph(subgraphs[i], &subgraph); - if (subgraph == origin_graph) { + bool same_graph = false; + graph_api_->OrtGraph_IsSameGraph(graph, subgraphs[i], &same_graph); + if (same_graph) { continue; } int number_of_ort_subgraph_nodes = 0; @@ -2700,7 +2694,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort if (dump_ep_context_model_) { create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, reinterpret_cast(serialized_engine->data()), serialized_engine->size()); graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraphViewer(ep_ctx_graph_); } } } @@ -2785,7 +2779,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, nullptr, 0); if (ep_context_embed_mode_ == 0) { graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraphViewer(ep_ctx_graph_); } } @@ -3158,7 +3152,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort this_->extra_attr_keys_.size(), &this_->ep_ctx_graph_); graph_api_->OrtGraph_DumpOnnxModel(this_->ep_ctx_graph_, this_->ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraph(this_->ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraphViewer(this_->ep_ctx_graph_); } context_update = true; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 65f049657b88f..362f75d8da8f3 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ 
b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -314,7 +314,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; - OrtGraph* ep_ctx_graph_ = nullptr; + OrtGraphViewer* ep_ctx_graph_ = nullptr; std::vector extra_attr_keys_; std::vector extra_attr_values_; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index 8e7f6f1fbd923..ec3f923fe5d49 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -272,17 +272,12 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { HashValue model_hash = 0; const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); - const OrtGraph* cur_graph = nullptr; - graph_api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph); - bool is_subgraph = false; - graph_api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); - while (is_subgraph) { - graph_api->OrtGraph_GetParentGraph(cur_graph, &cur_graph); - is_subgraph = false; - graph_api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph); - } + // TODO(leca): omit the logic to get the parent graph, as we don't want to expose both OrtGraph and OrtGraphViewer + // To add this logic back, either: + // 1. Change the ORT code to add new function (GetParentGraph) on GraphViewer, or + // 2. In Graph API, allocate space for GraphViewer object wrapping Graph object in the corresponding API function, but + // the GraphViewer object needs to be released in a separate Graph API - const OrtGraph* main_graph = cur_graph; uint32_t hash[4] = {0, 0, 0, 0}; auto hash_str = [&hash](const std::string& str) { From c8ddc7310e659148d643ce3f241f46869f7b1269 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 3 Dec 2024 22:54:45 +0000 Subject: [PATCH 76/81] initial commit for Graph C++ API --- .../core/session/onnxruntime_cxx_api_ep.h | 28 +++++++++++++++++++ .../core/session/onnxruntime_cxx_inline_ep.h | 28 +++++++++++++++++++ samples/outTreeEp/out_tree_ep.cc | 7 +++++ 3 files changed, 63 insertions(+) create mode 100644 include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h create mode 100644 include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h new file mode 100644 index 0000000000000..1188904f7be3c --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "onnxruntime_cxx_api.h" +#include "onnxruntime_c_api_ep.h" + +namespace Ort { +namespace PluginEP { + +struct Graph { +explicit Graph(const OrtGraphViewer*); +const char* GetName(); +private: +const OrtGraphViewer* graph_; +}; + +struct Node { +explicit Node(const OrtNode*); +const char* GetName(); +private: +const OrtNode* node_; +}; + +} +} + +#include "onnxruntime_cxx_inline_ep.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h new file mode 100644 index 0000000000000..99b80e2b56450 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
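
The two wrapper classes introduced by this patch only expose GetName so far; later patches in the series flesh them out. A minimal sketch of what plugin code gains from them, assuming only the headers added here (function name illustrative):

    #include <iostream>
    #include "core/session/onnxruntime_cxx_api_ep.h"

    // Sketch: non-owning views over the C handles; errors surface as Ort
    // exceptions via ThrowOnError instead of manual OrtStatus checks.
    void LogNames(const OrtGraphViewer* graph_viewer, const OrtNode* node) {
      Ort::PluginEP::Graph graph(graph_viewer);
      Ort::PluginEP::Node wrapped(node);
      std::cout << "graph: " << graph.GetName()
                << ", node: " << wrapped.GetName() << std::endl;
    }
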
+
+// Do not include this file directly. Please include "onnxruntime_cxx_api_ep.h" instead.
+
+namespace Ort{
+namespace PluginEP {
+
+static const OrtGraphApi* ort_graph_api = GetApi().GetGraphApi(ORT_API_VERSION);
+
+inline Graph::Graph(const OrtGraphViewer* graph) : graph_(graph) {}
+
+inline const char* Graph::GetName() {
+  const char* graph_name = nullptr;
+  ThrowOnError(ort_graph_api->OrtGraph_GetName(graph_, &graph_name));
+  return graph_name;
+}
+
+inline Node::Node(const OrtNode* node) : node_(node) {}
+
+inline const char* Node::GetName() {
+  const char* node_name = nullptr;
+  ThrowOnError(ort_graph_api->OrtNode_GetName(node_, &node_name));
+  return node_name;
+}
+
+}
+}
diff --git a/samples/outTreeEp/out_tree_ep.cc b/samples/outTreeEp/out_tree_ep.cc
index c77891bab9c72..b0e2000eacc4a 100644
--- a/samples/outTreeEp/out_tree_ep.cc
+++ b/samples/outTreeEp/out_tree_ep.cc
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include "core/session/onnxruntime_cxx_api_ep.h"

 namespace onnxruntime {
 OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExecutionProvider(), info(ep_info) {
@@ -9,6 +10,10 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe
   OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
     const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
     const OrtGraphApi* ort_graph_api = api->GetGraphApi(ORT_API_VERSION);
+
+    // Test Graph C++ API
+    Ort::PluginEP::Graph graph_cxx(graph);
+    std::cout<<"Test Graph C++ API Graph::GetName:"<<graph_cxx.GetName()<<std::endl;
     std::vector<OrtIndexedSubGraph*> cache;
     const size_t* nodes_index = nullptr;
     size_t nodes_count = 0;
@@ -16,6 +21,8 @@ OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : OrtExe
   for (size_t i = 0; i < nodes_count; i++) {
     const OrtNode* node = nullptr;
     ort_graph_api->OrtGraph_GetOrtNode(graph, nodes_index[i], &node);
+    Ort::PluginEP::Node node_cxx(node);
+    std::cout<<"Test Graph C++ API Node::GetName:"<<node_cxx.GetName()<<std::endl;
     const char* node_op_type = nullptr;
     ort_graph_api->OrtNode_GetOpType(node, &node_op_type);
     if (!strcmp(node_op_type, "Relu")) {

From e6be85e0a5096bc0adb672ccbd12e0bc0766d265 Mon Sep 17 00:00:00 2001
From: Lei Cao
Date: Wed, 4 Dec 2024 19:42:00 +0000
Subject: [PATCH 77/81] Fix Chi's comment and rollback the change on OrtGraph_CreateOrUpdateEpCtxGraph

---
 .../core/session/onnxruntime_c_api_ep.h       | 17 +++++++++++---
 .../core/session/onnxruntime_c_api_ep.cc      | 23 ++++++++++++-------
 onnxruntime/core/session/ort_apis_ep.h        |  6 +++--
 .../tensorRTEp/tensorrt_execution_provider.cc |  6 ++---
 .../tensorRTEp/tensorrt_execution_provider.h  |  2 +-
 5 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
index db2d957fc3013..d0ff83e38be2f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
@@ -315,7 +315,7 @@ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ vo
  * \param[in] onnx_model_path The file path to save to
  *
  */
-ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* onnx_model_path);
+ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path);

 /** \brief Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise:
  * 1.
if the given node name can't be found in the graph, add an new "EP Context" node to the existing graph @@ -335,7 +335,7 @@ ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* * \param[in] extra_attr_num Number of other attributes * \param[out] ep_context_graph The constructed or updated ep context graph * - * \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraphViewer. + * \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraph. * */ ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph, @@ -349,7 +349,7 @@ ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph, const char* const* extra_attr_keys, const char* const* extra_attr_values, size_t extra_attr_num, - _Outptr_ OrtGraphViewer** ep_context_graph); + _Outptr_ OrtGraph** ep_context_graph); /** \brief Construct a subgraph from the Graph with the given node indices. * @@ -363,6 +363,17 @@ ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph, */ ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss +/** \brief Release the graph instance. + * + * NOTE!!: Invoke this function after the use of OrtGraph_CreateOrUpdateEpCtxGraph. As OrtGraph_CreateOrUpdateEpCtxGraph allocates model instead of + * graph, this API releases graph's owning_model explicitly which in turn will release the graph + * (because graph is hosted in an unique_ptr in Model class) + * + * \param[in] graph The graph to release + * + */ +ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph); + /** \brief Release the graph viewer instance. * * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. 
As OrtGraph_GetSubGraph allocates model instead of diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index 4963f97fef03e..bcf5e6f1057aa 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -466,10 +466,9 @@ static void SetAllGraphInputs(Graph& graph, std::unordered_map(graph); - const ::onnxruntime::Graph* internal_graph = &(graph_viewer->GetGraph()); + const ::onnxruntime::Graph* internal_graph = reinterpret_cast(graph); auto model = &(internal_graph->GetModel()); // Two options to generate model proto: @@ -507,7 +506,7 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, const char* const* extra_attr_keys, const char* const* extra_attr_values, size_t extra_attr_num, - _Outptr_ OrtGraphViewer** ep_context_graph) { + _Outptr_ OrtGraph** ep_context_graph) { const std::string EPCONTEXT_OP = "EPContext"; const std::string MAIN_CONTEXT = "main_context"; @@ -540,11 +539,9 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, #endif // ORT_MINIMAL_BUILD std::vector(), graph_viewer->GetGraph().GetLogger()); graph_build = &(model_build->MainGraph()); - auto graph_build_viewer = std::make_unique(*graph_build); - *ep_context_graph = reinterpret_cast(graph_build_viewer.release()); + *ep_context_graph = reinterpret_cast(graph_build); } else { - ::onnxruntime::GraphViewer* content_graph_viewer = reinterpret_cast<::onnxruntime::GraphViewer*>(*ep_context_graph); - graph_build = const_cast<::onnxruntime::Graph*>(&(content_graph_viewer->GetGraph())); + graph_build = reinterpret_cast<::onnxruntime::Graph*>(*ep_context_graph); } // Get graph inputs and outputs @@ -756,11 +753,20 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr status = graph_build.Resolve(); if (status != Status::OK()) return onnxruntime::ToOrtStatus(status); + // TODO(leca): Maybe we should just return graph_build in the form of OrtGraph, so that we can reuse OrtGraph_ReleaseGraph auto sub_graph_viewer = std::make_unique(graph_build); *subgraph = reinterpret_cast(sub_graph_viewer.release()); return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_graph) { + if (ort_graph) { + const ::onnxruntime::Graph* graph = reinterpret_cast(ort_graph); + delete &(graph->GetModel()); + } + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); @@ -1004,6 +1010,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_DumpOnnxModel, &OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph, &OrtGraphApis::OrtGraph_GetSubGraph, + &OrtGraphApis::OrtGraph_ReleaseGraph, &OrtGraphApis::OrtGraph_ReleaseGraphViewer, &OrtGraphApis::OrtGraph_IsSameGraph, &OrtGraphApis::OrtNode_GetName, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index dbedeaed3cbd0..9b7ec33cc67ac 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -49,7 +49,7 @@ ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info); ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); -ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* onnx_model_path); +ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const 
OrtGraph* graph, const char* onnx_model_path); ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph, const OrtGraphViewer* graph, @@ -62,10 +62,12 @@ ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph, const char* const* extra_attr_keys, const char* const* extra_attr_values, size_t extra_attr_num, - _Outptr_ OrtGraphViewer** ep_context_graph); + _Outptr_ OrtGraph** ep_context_graph); ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph); + ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); ORT_API_STATUS_IMPL(OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, _Out_ bool* is_same); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 04f621d5d76e2..ad560c5a019f9 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -2694,7 +2694,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort if (dump_ep_context_model_) { create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, reinterpret_cast(serialized_engine->data()), serialized_engine->size()); graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraphViewer(ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); } } } @@ -2779,7 +2779,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort create_ep_context_model(graph_body_viewer, engine_cache_path, engine_cache_relative_path_to_context_model_dir, node_name, nullptr, 0); if (ep_context_embed_mode_ == 0) { graph_api_->OrtGraph_DumpOnnxModel(ep_ctx_graph_, ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraphViewer(ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraph(ep_ctx_graph_); } } @@ -3152,7 +3152,7 @@ OrtStatusPtr TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const Ort this_->extra_attr_keys_.size(), &this_->ep_ctx_graph_); graph_api_->OrtGraph_DumpOnnxModel(this_->ep_ctx_graph_, this_->ctx_model_path_.c_str()); - graph_api_->OrtGraph_ReleaseGraphViewer(this_->ep_ctx_graph_); + graph_api_->OrtGraph_ReleaseGraph(this_->ep_ctx_graph_); } context_update = true; diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index 362f75d8da8f3..65f049657b88f 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -314,7 +314,7 @@ struct TensorrtExecutionProvider : public OrtExecutionProvider { std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; - OrtGraphViewer* ep_ctx_graph_ = nullptr; + OrtGraph* ep_ctx_graph_ = nullptr; std::vector extra_attr_keys_; std::vector extra_attr_values_; From ce76175fa1bb7f8fb3e1ebe9c3e2b788a2750690 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:54:48 -0800 Subject: [PATCH 78/81] Add c++ wrapper for plugin ep api (#23045) --- .../core/session/onnxruntime_cxx_api_ep.h | 62 ++++ .../core/session/onnxruntime_cxx_inline_ep.h | 266 ++++++++++++++++++ 2 files changed, 328 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h 
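
Among the wrapper additions below, ValueInfoRef is the one type with real ownership: its destructor releases the underlying OrtValueInfoRef. A usage sketch, with the tensor name and predicate purely illustrative:

    #include "core/session/onnxruntime_cxx_api_ep.h"

    // Sketch: query a tensor's shape through the wrapper; the shared_ptr
    // returned by GetValueInfo releases the C struct when it goes out of scope.
    bool IsRank4Float(Ort::PluginEP::Graph& graph) {
      auto vi = graph.GetValueInfo("input_0");
      return vi->GetShape().size() == 4 &&
             vi->GetTensorElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    }
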
b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h index 1188904f7be3c..53298cbc605c1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h @@ -8,9 +8,43 @@ namespace Ort { namespace PluginEP { +struct ValueInfoRef { +explicit ValueInfoRef(OrtValueInfoRef*); +~ValueInfoRef(); +const std::vector GetShape(); +const ONNXTensorElementDataType GetTensorElementType(); +private: +OrtValueInfoRef* value_info_; +}; + struct Graph { explicit Graph(const OrtGraphViewer*); +const OrtGraphViewer* GetGraph() { return graph_; } const char* GetName(); +bool IsConstantInitializer(const char* name, bool check_outer_scope); +const std::vector GetNodesIndexInTopologicalOrder(int execution_order); +bool IsSubgraph(); +std::shared_ptr GetParenNode(); +std::filesystem::path GetModelPath(); +// std::vector> GetRequiredInputs(); +// std::vector> GetAllInputs(); +// std::vector> GetAllInitializers(); +std::shared_ptr GetOrtNode(size_t node_index); +// std::vector> GetNodesConsumingInput(const char* input_name); +std::shared_ptr GetNodeProducingOutput(const char* output_name); +int NumberOfNodes(); +int MaxNodeIndex(); +size_t GetOutputSize(); +std::string GetIthOutputName(size_t i); +int32_t GetIthOutputElemType(size_t i); +// std::shared_ptr GetInitializerTensor(const char* initializer_name); +std::shared_ptr GetValueInfo(const char* name); +// void SerializeToArray(void** data, size_t* data_size); +// void DumpOnnxModel(const std::filesystem::path& onnx_model_path); +// CreateOrUpdateEpCtxGraph(); +std::shared_ptr GetSubGraph(std::vector node_indices); +// bool IsSameGraph(const Graph& other); + private: const OrtGraphViewer* graph_; }; @@ -18,6 +52,34 @@ const OrtGraphViewer* graph_; struct Node { explicit Node(const OrtNode*); const char* GetName(); +const std::string GetDescription(); +const std::string GetDomain(); +int SinceVersion(); +const std::string GetExecutionProviderType(); +const std::string GetOpType(); +size_t GetImplicitInputSize(); +const std::string GetIthImplicitInputName(size_t i); +size_t GetNumInputs(); +const std::string GetIthInputName(size_t i); +size_t GetNumOutputs(); +const std::string GetIthOutputName(size_t i); +size_t GetIndex(); +// const std::vector GetAttributeNames(); +size_t GetAttributeSize(); +int GetAttributeType(std::string attribute_name); +size_t GetAttributeKeyCount(std::string attribute_name); +int GetAttributeIntSize(std::string attribute_name); +int GetAttributeFloatSize(std::string attribute_name); +int GetAttributeStringSize(std::string attribute_name); +int64_t GetAttributeIthInt(std::string attribute_name, size_t i); +float GetAttributeIthFloat(std::string attribute_name, size_t i); +const std::string GetAttributeIthStr(std::string attribute_name, size_t i); +// GetAttributeIthStrWithSize +const std::string GetAttributeStr(std::string attribute_name); +// GetAttributeStrWithSize +int64_t GetAttributeInt(std::string attribute_name); +float GetAttributeFloat(std::string attribute_name); +// GetSubgraphs private: const OrtNode* node_; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h index 99b80e2b56450..6ea31f9189a19 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h @@ -8,6 +8,21 @@ namespace PluginEP { static const OrtGraphApi* ort_graph_api = GetApi().GetGraphApi(ORT_API_VERSION); +inline 
ValueInfoRef::ValueInfoRef(OrtValueInfoRef* value_info) : value_info_(value_info) {} + +inline ValueInfoRef::~ValueInfoRef() { + ort_graph_api->OrtGraph_ReleaseValueInfo(value_info_); +} + +inline const std::vector ValueInfoRef::GetShape() { + std::vector shape(value_info_->shape, value_info_->shape + value_info_->shape_len); + return shape; +} + +inline const ONNXTensorElementDataType ValueInfoRef::GetTensorElementType() { + return value_info_->data_type; +} + inline Graph::Graph(const OrtGraphViewer* graph) : graph_(graph) {} inline const char* Graph::GetName() { @@ -16,6 +31,102 @@ inline const char* Graph::GetName() { return graph_name; } +inline bool Graph::IsConstantInitializer(const char* name, bool check_outer_scope) { + bool is_initializer = false; + ThrowOnError(ort_graph_api->OrtGraph_IsConstantInitializer(graph_, name, check_outer_scope, &is_initializer)); + return is_initializer; +} + +inline const std::vector Graph::GetNodesIndexInTopologicalOrder(int execution_order) { + const size_t* nodes_index = nullptr; + size_t nodes_count = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, nodes_index, nodes_count)); + return std::vector(nodes_index, nodes_index + nodes_count); +} + +inline bool Graph::IsSubgraph() { + bool is_subgraph = false; + ThrowOnError(ort_graph_api->OrtGraph_IsSubgraph(graph_, &is_subgraph)); + return is_subgraph; +} + +inline std::shared_ptr Graph::GetParenNode() { + const OrtNode* parent_node = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetParenNode(graph_, &parent_node)); + return std::make_shared(parent_node); +} + +inline std::filesystem::path Graph::GetModelPath() { + const void* model_path = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetModelPath(graph_, &model_path)); + return *reinterpret_cast(model_path); +} + +inline std::shared_ptr Graph::GetOrtNode(size_t node_index) { + const OrtNode* node = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetOrtNode(graph_, node_index, &node)); + return std::make_shared(node); +} + +inline std::shared_ptr Graph::GetNodeProducingOutput(const char* output_name) { + const OrtNode* node = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetNodeProducingOutput(graph_, output_name, &node)); + return std::make_shared(node); +} + +inline int Graph::NumberOfNodes() { + int num_nodes = 0; + ThrowOnError(ort_graph_api->OrtGraph_NumberOfNodes(graph_, &num_nodes)); + return num_nodes; +} + +inline int Graph::MaxNodeIndex() { + int max_node_index = 0; + ThrowOnError(ort_graph_api->OrtGraph_MaxNodeIndex(graph_, &max_node_index)); + return max_node_index; +} + +inline size_t Graph::GetOutputSize() { + size_t output_size = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetOutputSize(graph_, &output_size)); + return output_size; +} + +inline std::string Graph::GetIthOutputName(size_t i) { + const char* output_name = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputName(graph_, i, &output_name)); + return std::string(output_name); +} + +inline int32_t Graph::GetIthOutputElemType(size_t i) { + int32_t elem_type = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputElemType(graph_, i, &elem_type)); + return elem_type; +} + +inline std::shared_ptr Graph::GetValueInfo(const char* name) { + OrtValueInfoRef* value_info = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetValueInfo(graph_, name, &value_info)); + return std::make_shared(value_info); +} + +// inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) { +// 
ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_->GetGraph(), onnx_model_path.c_str())); +// } + +inline std::shared_ptr Graph::GetSubGraph(std::vector node_indices) { + const OrtGraphViewer* subgraph = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetSubGraph(graph_, node_indices.size(), node_indices.data(), &subgraph)); + // TODO:yang if should release subgraph in the decstructor of Graph? + return std::make_shared(subgraph); +} + +// inline bool Graph::IsSameGraph(const Graph& other) { +// bool is_same = false; +// ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraph(), &is_same)); +// return is_same; +// } + inline Node::Node(const OrtNode* node) : node_(node) {} inline const char* Node::GetName() { @@ -24,5 +135,160 @@ inline const char* Node::GetName() { return node_name; } +inline const std::string Node::GetDescription() { + const char* node_description = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetDescription(node_, &node_description)); + return std::string(node_description); +} + +inline const std::string Node::GetDomain() { + const char* node_domain = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetDomain(node_, &node_domain)); + return std::string(node_domain); } + +inline int Node::SinceVersion() { + int since_version = 0; + ThrowOnError(ort_graph_api->OrtNode_SinceVersion(node_, &since_version)); + return since_version; } + +inline const std::string Node::GetExecutionProviderType() { + const char* execution_provider_type = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetExecutionProviderType(node_, &execution_provider_type)); + return std::string(execution_provider_type); +} + +inline const std::string Node::GetOpType() { + const char* op_type = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetOpType(node_, &op_type)); + return std::string(op_type); +} + +inline size_t Node::GetImplicitInputSize() { + size_t implicit_input_size = 0; + ThrowOnError(ort_graph_api->OrtNode_GetImplicitInputSize(node_, &implicit_input_size)); + return implicit_input_size; +} + +inline const std::string Node::GetIthImplicitInputName(size_t i) { + const char* implicit_input_name = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetIthImplicitInputName(node_, i, &implicit_input_name)); + return std::string(implicit_input_name); +} + +inline size_t Node::GetNumInputs() { + size_t num_inputs = 0; + ThrowOnError(ort_graph_api->OrtNode_GetNumInputs(node_, &num_inputs)); + return num_inputs; +} + +inline const std::string Node::GetIthInputName(size_t i) { + const char* input_name = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetIthInputName(node_, i, &input_name)); + return std::string(input_name); +} + +inline size_t Node::GetNumOutputs() { + size_t num_outputs = 0; + ThrowOnError(ort_graph_api->OrtNode_GetNumOutputs(node_, &num_outputs)); + return num_outputs; +} + +inline const std::string Node::GetIthOutputName(size_t i) { + const char* output_name = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetIthOutputName(node_, i, &output_name)); + return std::string(output_name); +} + +inline size_t Node::GetIndex() { + size_t node_index = 0; + ThrowOnError(ort_graph_api->OrtNode_GetIndex(node_, &node_index)); + return node_index; +} + +// inline const std::vector Node::GetAttributeNames() { +// const ::onnxruntime::Node* n = reinterpret_cast(node_); +// const auto& attribute = n->GetAttributes(); +// std::vector attribute_names; +// for (const auto& attr : attribute) { +// attribute_names.push_back(attr.first); +// } +// return attribute_names; +// } + +inline 
size_t Node::GetAttributeSize() { + size_t attribute_size = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeSize(node_, &attribute_size)); + return attribute_size; +} + +inline int Node::GetAttributeType(std::string attribute_name) { + int attribute_type = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeType(node_, attribute_name.c_str(), &attribute_type)); + return attribute_type; +} + +inline size_t Node::GetAttributeKeyCount(std::string attribute_name) { + size_t attribute_key_count = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeKeyCount(node_, attribute_name.c_str(), &attribute_key_count)); + return attribute_key_count; +} + +inline int Node::GetAttributeIntSize(std::string attribute_name) { + int attribute_int_size = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeIntSize(node_, attribute_name.c_str(), &attribute_int_size)); + return attribute_int_size; +} + +inline int Node::GetAttributeFloatSize(std::string attribute_name) { + int attribute_float_size = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeFloatSize(node_, attribute_name.c_str(), &attribute_float_size)); + return attribute_float_size; +} + +inline int Node::GetAttributeStringSize(std::string attribute_name) { + int attribute_string_size = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeStringSize(node_, attribute_name.c_str(), &attribute_string_size)); + return attribute_string_size; +} + +inline int64_t Node::GetAttributeIthInt(std::string attribute_name, size_t i) { + int64_t attribute_ith_int = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeIthInt(node_, attribute_name.c_str(), i, &attribute_ith_int)); + return attribute_ith_int; +} + +inline float Node::GetAttributeIthFloat(std::string attribute_name, size_t i) { + float attribute_ith_float = 0.0f; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeIthFloat(node_, attribute_name.c_str(), i, &attribute_ith_float)); + return attribute_ith_float; +} + +inline const std::string Node::GetAttributeIthStr(std::string attribute_name, size_t i) { + const char* attribute_ith_string = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeIthStr(node_, attribute_name.c_str(), i, &attribute_ith_string)); + return std::string(attribute_ith_string); +} + +inline const std::string Node::GetAttributeStr(std::string attribute_name) { + const char* attribute_str = nullptr; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeStr(node_, attribute_name.c_str(), &attribute_str)); + return std::string(attribute_str); +} + +inline int64_t Node::GetAttributeInt(std::string attribute_name) { + int64_t attribute_int = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeInt(node_, attribute_name.c_str(), &attribute_int)); + return attribute_int; +} + +inline float Node::GetAttributeFloat(std::string attribute_name) { + float attribute_float = 0.0f; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeFloat(node_, attribute_name.c_str(), &attribute_float)); + return attribute_float; +} + + +} // namespace Ort +} // namespace PluginEP From fefbe278437f34e90a0140285066b301f2769e96 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 6 Dec 2024 23:36:44 -0800 Subject: [PATCH 79/81] refine ep plugin c++ wrapper (#23050) --- .../core/session/onnxruntime_c_api_ep.h | 14 +++ .../core/session/onnxruntime_cxx_api_ep.h | 14 +-- .../core/session/onnxruntime_cxx_inline_ep.h | 93 ++++++++++++++++--- .../core/session/onnxruntime_c_api_ep.cc | 9 ++ onnxruntime/core/session/ort_apis_ep.h | 2 + 
samples/openvino/ov_versions/utils.cc | 4 +- .../tensorRTEp/tensorrt_execution_provider.cc | 1 + 7 files changed, 115 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index d0ff83e38be2f..77da384ac6671 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -195,6 +195,8 @@ ORT_API2_STATUS(ReleaseCharArray, const char** char_array); ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); /** \brief Get the consumer nodes of a node arg with the given name + * + * NOTE!!: The caller is responsible for releasing the OrtNode arrays using ReleaseOrtNodeArray. * * \param[in] graph The graph to query * \param[in] input_name The name of the node arg @@ -204,6 +206,16 @@ ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_in */ ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); // TODO(leca): ValueConsumers::comprehensive ? +/** \brief Release the OrtNode arrays + * + * NOTE!!: Invoke this function after the use of OrtGraph_GetNodesConsumingInput. + * + * \param[in] nodes The OrtNode arrays to release + * \param[in] num_nodes The number of OrtNode arrays + * + */ +ORT_API2_STATUS(ReleaseOrtNodeArray, const OrtNode** nodes); + /** \brief Get the producer node of a node arg with the given name * * \param[in] graph The graph to query @@ -502,6 +514,8 @@ ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Outptr ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* out); /** \brief Gets attribute names of the node. + * + * NOTE!!: The caller is responsible for releasing the char array using ReleaseCharArray. 
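
The NOTE above establishes the ownership contract for node arrays: the graph keeps owning the OrtNode objects, and the caller frees only the array that holds the pointers. A C-level sketch, with the input name illustrative:

    // Sketch: enumerate consumers of a graph input, then release only the array.
    void ListConsumers(const OrtGraphApi* graph_api, const OrtGraphViewer* graph_viewer) {
      const OrtNode** consumers = nullptr;
      size_t num_consumers = 0;
      graph_api->OrtGraph_GetNodesConsumingInput(graph_viewer, "input_0",
                                                 &consumers, &num_consumers);
      for (size_t i = 0; i < num_consumers; ++i) {
        const char* name = nullptr;
        graph_api->OrtNode_GetName(consumers[i], &name);  // nodes remain graph-owned
      }
      graph_api->ReleaseOrtNodeArray(consumers);  // frees the array, not the nodes
    }
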
* * \param[in] node The node to query * \param[out] names The attribute names of the node diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h index 53298cbc605c1..335eb88a060b4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h @@ -26,12 +26,12 @@ const std::vector GetNodesIndexInTopologicalOrder(int execution_order); bool IsSubgraph(); std::shared_ptr GetParenNode(); std::filesystem::path GetModelPath(); -// std::vector> GetRequiredInputs(); -// std::vector> GetAllInputs(); -// std::vector> GetAllInitializers(); -std::shared_ptr GetOrtNode(size_t node_index); -// std::vector> GetNodesConsumingInput(const char* input_name); -std::shared_ptr GetNodeProducingOutput(const char* output_name); +std::vector GetRequiredInputs(); +std::vector GetAllInputs(); +std::vector GetAllInitializers(); +Node GetOrtNode(size_t node_index); +std::vector GetNodesConsumingInput(const char* input_name); +Node GetNodeProducingOutput(const char* output_name); int NumberOfNodes(); int MaxNodeIndex(); size_t GetOutputSize(); @@ -64,7 +64,7 @@ const std::string GetIthInputName(size_t i); size_t GetNumOutputs(); const std::string GetIthOutputName(size_t i); size_t GetIndex(); -// const std::vector GetAttributeNames(); +std::vector GetAttributeNames(); size_t GetAttributeSize(); int GetAttributeType(std::string attribute_name); size_t GetAttributeKeyCount(std::string attribute_name); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h index 6ea31f9189a19..21f2834fb20d6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h @@ -62,16 +62,77 @@ inline std::filesystem::path Graph::GetModelPath() { return *reinterpret_cast(model_path); } -inline std::shared_ptr Graph::GetOrtNode(size_t node_index) { +inline std::vector Graph::GetRequiredInputs() { + const char** required_inputs = nullptr; + size_t required_inputs_count = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetRequiredInputs(graph_, &required_inputs, &required_inputs_count)); + auto release_fn = [](const char** strs) { + ThrowOnError(ort_graph_api->ReleaseCharArray(strs)); + }; + std::unique_ptr guard(required_inputs, release_fn); + std::vector ret; + ret.reserve(required_inputs_count); + for (size_t i = 0; i < required_inputs_count; i++) { + ret.emplace_back(required_inputs[i]); + } + return ret; +} + +inline std::vector Graph::GetAllInputs() { + const char** all_inputs = nullptr; + size_t all_inputs_count = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetAllInputs(graph_, &all_inputs, &all_inputs_count)); + auto release_fn = [](const char** strs) { + ThrowOnError(ort_graph_api->ReleaseCharArray(strs)); + }; + std::unique_ptr guard(all_inputs, release_fn); + std::vector ret; + ret.reserve(all_inputs_count); + for (size_t i = 0; i < all_inputs_count; i++) { + ret.emplace_back(all_inputs[i]); + } + return ret; +} + +inline std::vector Graph::GetAllInitializers() { + const char** all_initializers = nullptr; + size_t all_initializers_count = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetAllInitializers(graph_, &all_initializers, &all_initializers_count)); + auto release_fn = [](const char** strs) { + ThrowOnError(ort_graph_api->ReleaseCharArray(strs)); + }; + std::unique_ptr guard(all_initializers, release_fn); + std::vector ret; + 
ret.reserve(all_initializers_count); + for (size_t i = 0; i < all_initializers_count; i++) { + ret.emplace_back(all_initializers[i]); + } + return ret; +} + +inline Ort::PluginEP::Node Graph::GetOrtNode(size_t node_index) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetOrtNode(graph_, node_index, &node)); - return std::make_shared(node); + return Ort::PluginEP::Node(node); } -inline std::shared_ptr Graph::GetNodeProducingOutput(const char* output_name) { +inline std::vector Graph::GetNodesConsumingInput(const char* input_name) { + const OrtNode** consumers = nullptr; + size_t consumer_count = 0; + ThrowOnError(ort_graph_api->OrtGraph_GetNodesConsumingInput(graph_, input_name, &consumers, &consumer_count)); + std::vector ret; + ret.reserve(consumer_count); + for (size_t i = 0; i < consumer_count; i++) { + ret.emplace_back(consumers[i]); + } + ort_graph_api->ReleaseOrtNodeArray(consumers); + return ret; +} + +inline Ort::PluginEP::Node Graph::GetNodeProducingOutput(const char* output_name) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetNodeProducingOutput(graph_, output_name, &node)); - return std::make_shared(node); + return Ort::PluginEP::Node(node); } inline int Graph::NumberOfNodes() { @@ -207,15 +268,21 @@ inline size_t Node::GetIndex() { return node_index; } -// inline const std::vector Node::GetAttributeNames() { -// const ::onnxruntime::Node* n = reinterpret_cast(node_); -// const auto& attribute = n->GetAttributes(); -// std::vector attribute_names; -// for (const auto& attr : attribute) { -// attribute_names.push_back(attr.first); -// } -// return attribute_names; -// } +inline std::vector Node::GetAttributeNames() { + const char** attribute_names = nullptr; + size_t attribute_names_count = 0; + ThrowOnError(ort_graph_api->OrtNode_GetAttributeNames(node_, &attribute_names, &attribute_names_count)); + auto release_fn = [](const char** strs) { + ThrowOnError(ort_graph_api->ReleaseCharArray(strs)); + }; + std::unique_ptr guard(attribute_names, release_fn); + std::vector ret; + ret.reserve(attribute_names_count); + for (size_t i = 0; i < attribute_names_count; i++) { + ret.emplace_back(attribute_names[i]); + } + return ret; +} inline size_t Node::GetAttributeSize() { size_t attribute_size = 0; diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index bcf5e6f1057aa..b52c1cae9046e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -104,6 +104,14 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodesConsumingInput, const OrtGrap return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::ReleaseOrtNodeArray, const OrtNode** nodes) { + if (!nodes) { + return nullptr; + } + delete[] nodes; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); *node = reinterpret_cast(graph_viewer->GetProducerNode(output_name)); @@ -996,6 +1004,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::ReleaseCharArray, &OrtGraphApis::OrtGraph_GetOrtNode, &OrtGraphApis::OrtGraph_GetNodesConsumingInput, + &OrtGraphApis::ReleaseOrtNodeArray, &OrtGraphApis::OrtGraph_GetNodeProducingOutput, &OrtGraphApis::OrtGraph_NumberOfNodes, &OrtGraphApis::OrtGraph_MaxNodeIndex, diff --git a/onnxruntime/core/session/ort_apis_ep.h 
b/onnxruntime/core/session/ort_apis_ep.h index 9b7ec33cc67ac..dd6ae746db564 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -27,6 +27,8 @@ ORT_API_STATUS_IMPL(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t nod ORT_API_STATUS_IMPL(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers, _Out_ size_t* num_consumers); +ORT_API_STATUS_IMPL(ReleaseOrtNodeArray, const OrtNode** nodes); + ORT_API_STATUS_IMPL(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** node); ORT_API_STATUS_IMPL(OrtGraph_NumberOfNodes, const OrtGraphViewer* graph, _Out_ int* num_nodes); diff --git a/samples/openvino/ov_versions/utils.cc b/samples/openvino/ov_versions/utils.cc index 877f7c10a4a33..212d34bded983 100644 --- a/samples/openvino/ov_versions/utils.cc +++ b/samples/openvino/ov_versions/utils.cc @@ -179,7 +179,7 @@ void IdentifyConnectedNodes(const OrtGraphApi* graph_api, graph_api->OrtNode_GetIndex(consumer_nodes[j], &consumer_index); IdentifyConnectedNodes(graph_api, graph_viewer, consumer_index, cluster, sub_cluster); } - // TODO(leca): release consumer_nodes + graph_api->ReleaseOrtNodeArray(consumer_nodes); } } @@ -263,7 +263,7 @@ void GetInputsOutputsOfCluster(const OrtGraphApi* graph_api, } } } - // TODO(leca): release consumer_nodes + graph_api->ReleaseOrtNodeArray(consumer_nodes); } } diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index ad560c5a019f9..2715eef5c217c 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1273,6 +1273,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr break; } } + graph_api_->ReleaseOrtNodeArray(consumers, consumer_count); } } From ce6630c12af97b7a594c83110d59898cd0737e78 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 12 Dec 2024 23:32:35 -0800 Subject: [PATCH 80/81] [TRT EP Plugin] Fix issues of building on Windows (#23099) - Modify CMakeLists.txt for TRT EP plugin - Add "-l" for specifying EP plugin lib path for onnxruntime_perf_test --- .../test/perftest/command_args_parser.cc | 5 +- onnxruntime/test/perftest/ort_test_session.cc | 6 +- onnxruntime/test/perftest/ort_test_session.h | 1 + .../test/perftest/test_configuration.h | 1 + samples/tensorRTEp/CMakeLists.txt | 102 ++++++++++++++---- samples/tensorRTEp/onnx_ctx_model_helper.cc | 1 + .../tensorRTEp/tensorrt_execution_provider.cc | 6 +- .../tensorrt_execution_provider_utils.h | 3 +- samples/utils/helper.cc | 59 ++++++++++ samples/utils/path_string.h | 70 ++++++++++++ 10 files changed, 228 insertions(+), 26 deletions(-) create mode 100644 samples/utils/helper.cc create mode 100644 samples/utils/path_string.h diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 02fed26cfe463..5072b3192869d 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -205,7 +205,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzng"))) != -1) { + while ((ch = getopt(argc, argv, 
ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:l:f:F:S:T:C:AMPIDZvhsqzng"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -393,6 +393,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'g': test_config.plugin = true; break; + case 'l': + test_config.plugin_lib_path = optarg; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 772d0d462fa13..ecb8b505a5413 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -50,6 +50,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device provider_name_ = performance_test_config.machine_config.provider_type_name; plugin_ = performance_test_config.plugin; + plugin_lib_path_ = performance_test_config.plugin_lib_path; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options @@ -172,7 +173,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #ifdef USE_TENSORRT const auto& api = Ort::GetApi(); if (plugin_) { - Ort::ThrowOnError(api.RegisterPluginExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/tensorRTEp/build/libTensorRTEp.so", env, "tensorrtEp")); + if (plugin_lib_path_.empty()) { + ORT_THROW("Please specify EP plugin library path, e.g. -l /path/to/ep_plugin_lib"); + } + Ort::ThrowOnError(api.RegisterPluginExecutionProviderLibrary(plugin_lib_path_.c_str(), env, "tensorrtEp")); std::vector keys{"trt_engine_cache_enable", "trt_dump_ep_context_model", "trt_ep_context_embed_mode"}, values{"0", "0", "0"}; Ort::ThrowOnError(api.SessionOptionsAppendPluginExecutionProvider(session_options, "tensorrtEp", env, keys.data(), values.data(), keys.size())); } else { diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index 51fa154a14e1f..95e96cf0d3267 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -47,6 +47,7 @@ class OnnxRuntimeTestSession : public TestSession { const int input_length_; std::string provider_name_; bool plugin_; + std::basic_string plugin_lib_path_; }; } // namespace perftest diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 8edb775b7ab32..8cc841123e3f1 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -71,6 +71,7 @@ struct PerformanceTestConfig { MachineConfig machine_config; RunConfig run_config; bool plugin = false; + std::basic_string plugin_lib_path; }; } // namespace perftest diff --git a/samples/tensorRTEp/CMakeLists.txt b/samples/tensorRTEp/CMakeLists.txt index b0982784e7c19..ab26f33fb2ddd 100644 --- a/samples/tensorRTEp/CMakeLists.txt +++ b/samples/tensorRTEp/CMakeLists.txt @@ -1,7 +1,7 @@ # usage: # cd build/ -# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/leca/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") -# cmake --build ./ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DORT_HOME=/home/lochi/repos/ort -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake 
--build ./ --config Debug
 cmake_minimum_required(VERSION 3.26)
 project(TensorRTEp VERSION 1.0)
 set(CMAKE_CXX_STANDARD 17)
@@ -12,28 +12,90 @@ find_package(CUDAToolkit REQUIRED)
 add_definitions(-DONNX_NAMESPACE=onnx)
 add_definitions(-DONNX_ML)
 add_definitions(-DNV_TENSORRT_MAJOR=10)
-file(GLOB tensorrt_src "./*.cc" "../utils/status.cc" "./cuda/unary_elementwise_ops_impl.cu")
+add_definitions(-DNOMINMAX)
+file(GLOB tensorrt_src "./*.cc" "../utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu")
 add_library(TensorRTEp SHARED ${tensorrt_src})
+
+if (NOT ORT_HOME)
+  message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/")
+endif()
+
+if (NOT TENSORRT_HOME)
+  message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. -DTENSORRT_HOME=/path/to/trt/")
+endif()
+
+# Default to Release if no build type is specified
+if (NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+# There is a known issue when running a "Debug build" TRT EP plugin with a "Release build" TRT built-in parser on Windows.
+if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+  MESSAGE(FATAL_ERROR "[Note] There is a known issue when running a \"Debug build\" TRT EP plugin with a \"Release build\" TRT built-in parser on Windows. Please use Release mode to build the TRT EP plugin.")
+endif()
+
+if (WIN32)
+  set(PLATFORM "Windows")
+  set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib")
+  set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps")
+  set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib"
+               "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib"
+               "${TENSORRT_HOME}/lib/nvonnxparser_10.lib")
+  set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib"
+                "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib"
+                "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib")
+
+  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(DEPS_LIBS ${DEPS_LIBS}
+        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib"
+        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib")
+  else()
+    set(DEPS_LIBS ${DEPS_LIBS}
+        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib"
+        "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib")
+  endif()
+else()
+  set(PLATFORM "Linux")
+  set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/libonnxruntime.so")
+  set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps")
+  set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so"
+               "${TENSORRT_HOME}/lib/libnvinfer_plugin.so"
+               "${TENSORRT_HOME}/lib/libnvonnxparser.so")
+  set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a"
+                "${DEPS_PATH}/onnx-build/libonnx.a"
+                "${DEPS_PATH}/onnx-build/libonnx_proto.a")
+
+  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(DEPS_LIBS ${DEPS_LIBS}
+        "${DEPS_PATH}/protobuf-build/libprotobufd.a"
+        "${DEPS_PATH}/protobuf-build/libprotocd.a")
+  else()
+    set(DEPS_LIBS ${DEPS_LIBS}
+        "${DEPS_PATH}/protobuf-build/libprotobuf.a"
+        "${DEPS_PATH}/protobuf-build/libprotoc.a")
+  endif()
+endif()
+
+MESSAGE(STATUS "Looking for the following dependencies ...")
+MESSAGE(STATUS "Platform : ${PLATFORM}")
+MESSAGE(STATUS "ORT home : ${ORT_HOME}")
+MESSAGE(STATUS "ORT lib  : ${ORT_LIB}")
+MESSAGE(STATUS "Deps path: ${DEPS_PATH}")
+MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}")
+MESSAGE(STATUS "TRT libs : ${TRT_LIBS}")
+
 target_include_directories(TensorRTEp PUBLIC "../../include/onnxruntime"
                                              "../utils"
                                              "/usr/local/cuda/include"
                                              ${TENSORRT_HOME}/include
-                                             "../../build/tensorrt/Debug/_deps/flatbuffers-src/include"
-
"../../build/tensorrt/Debug/_deps/gsl-src/include" - "../../build/tensorrt/Debug/_deps/onnx-src" - "../../build/tensorrt/Debug/_deps/onnx-build" - "../../build/tensorrt/Debug/_deps/protobuf-src/src" + "${DEPS_PATH}/flatbuffers-src/include" + "${DEPS_PATH}/gsl-src/include" + "${DEPS_PATH}/onnx-src" + "${DEPS_PATH}/onnx-build" + "${DEPS_PATH}/protobuf-src/src" ) -## looks we need libonnxruntime.so in Win as in Windows you cannot build shared library with undefined symbol -target_link_libraries(TensorRTEp PUBLIC "/home/leca/code/onnxruntime/build/tensorrt/Debug/libonnxruntime.so" - ${TENSORRT_HOME}/lib/libnvinfer.so - ${TENSORRT_HOME}/lib/libnvinfer_plugin.so - ${TENSORRT_HOME}/lib/libnvonnxparser.so - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/flatbuffers-build/libflatbuffers.a" - CUDA::cudart - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx.a" - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/onnx-build/libonnx_proto.a" - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotobufd.a" - "/home/leca/code/onnxruntime/build/tensorrt/Debug/_deps/protobuf-build/libprotocd.a" - ) +target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB} + ${TRT_LIBS} + CUDA::cudart + ${DEPS_LIBS}) diff --git a/samples/tensorRTEp/onnx_ctx_model_helper.cc b/samples/tensorRTEp/onnx_ctx_model_helper.cc index 56b52395ad04f..1b29f626f77ed 100644 --- a/samples/tensorRTEp/onnx_ctx_model_helper.cc +++ b/samples/tensorRTEp/onnx_ctx_model_helper.cc @@ -3,6 +3,7 @@ #include #include "onnx_ctx_model_helper.h" #include "tensorrt_execution_provider.h" +#include "path_string.h" namespace onnxruntime { diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 2715eef5c217c..76bcd5350b0d1 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1273,7 +1273,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGr break; } } - graph_api_->ReleaseOrtNodeArray(consumers, consumer_count); + graph_api_->ReleaseOrtNodeArray(consumers); } } @@ -1336,9 +1336,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const graph_api_->OrtGraph_GetModelPath(graph, reinterpret_cast(&model_path)); const auto& path_string = model_path->string(); #ifdef _WIN32 - std::strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); + strncpy_s(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); #else - std::strncpy(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); + strncpy(p->model_path_, path_string.c_str(), sizeof(p->model_path_) - 1); #endif p->model_path_[sizeof(p->model_path_) - 1] = '\0'; diff --git a/samples/tensorRTEp/tensorrt_execution_provider_utils.h b/samples/tensorRTEp/tensorrt_execution_provider_utils.h index ec3f923fe5d49..75508225deca8 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider_utils.h +++ b/samples/tensorRTEp/tensorrt_execution_provider_utils.h @@ -8,6 +8,7 @@ #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" #include "murmurhash3.h" +#include "path_string.h" namespace fs = std::filesystem; @@ -289,7 +290,7 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) { // Use the model's file name instead of the entire path to avoid cache regeneration if path changes if (model_path->has_filename()) { - std::string model_name = model_path->filename(); + std::string model_name = PathToUTF8String(model_path->filename()); // 
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name;
 
     // Ensure enough characters are hashed in case model names are too short
diff --git a/samples/utils/helper.cc b/samples/utils/helper.cc
new file mode 100644
index 0000000000000..7a889c30baec2
--- /dev/null
+++ b/samples/utils/helper.cc
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cassert>
+#include <iostream>
+#include "common.h"
+
+#ifdef _WIN32
+#include <limits>
+#include <Windows.h>
+#endif
+
+namespace onnxruntime {
+#ifdef _WIN32
+std::string ToUTF8String(const std::wstring& s) {
+  if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
+    ORT_THROW("length overflow");
+
+  const int src_len = static_cast<int>(s.size() + 1);
+  const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr);
+  assert(len > 0);
+  std::string ret(static_cast<size_t>(len) - 1, '\0');
+#pragma warning(disable : 4189)
+  const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr);
+  assert(len == r);
+#pragma warning(default : 4189)
+  return ret;
+}
+
+std::wstring ToWideString(const std::string& s) {
+  if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
+    ORT_THROW("length overflow");
+
+  const int src_len = static_cast<int>(s.size() + 1);
+  const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0);
+  assert(len > 0);
+  std::wstring ret(static_cast<size_t>(len) - 1, '\0');
+#pragma warning(disable : 4189)
+  const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len);
+  assert(len == r);
+#pragma warning(default : 4189)
+  return ret;
+}
+#endif  // #ifdef _WIN32
+
+#ifdef ORT_NO_EXCEPTIONS
+void PrintFinalMessage(const char* msg) {
+#if defined(__ANDROID__)
+  __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg);
+#else
+  // TODO, consider changing the output of the error message from std::cerr to logging when the
+  // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output
+  // might not be easily accessible on some systems such as mobile
+  // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS
+  std::cerr << msg << std::endl;
+#endif
+}
+#endif  // #ifdef ORT_NO_EXCEPTIONS
+
+}  // namespace onnxruntime
diff --git a/samples/utils/path_string.h b/samples/utils/path_string.h
new file mode 100644
index 0000000000000..fd638aa5f39e5
--- /dev/null
+++ b/samples/utils/path_string.h
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include <type_traits>
+
+// for std::tolower or std::towlower
+#ifdef _WIN32
+#include <cwctype>
+#else
+#include <cctype>
+#endif
+
+// for converting / printing ORT_TSTR path strings to std::string
+#ifdef _WIN32
+#include <codecvt>
+#include <locale>
+#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) std::wstring_convert<std::codecvt_utf8<wchar_t>>().to_bytes(X)
+#define ORT_TSTR_CONVERT_FROM_STRING(X) std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(X);
+#else
+#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) X
+#define ORT_TSTR_CONVERT_FROM_STRING(X) X
+#endif
+
+//#include "core/common/common.h"
+//#include "core/session/onnxruntime_c_api.h"
+
+//#include "common.h"
+
+namespace onnxruntime {
+// char type for filesystem paths
+using PathChar = ORTCHAR_T;
+// string type for filesystem paths
+using PathString = std::basic_string<PathChar>;
+
+inline PathString ToPathString(const PathString& s) {
+  return s;
+}
+
+#ifdef _WIN32
+
+static_assert(std::is_same<PathChar, wchar_t>::value, "PathString is not std::wstring!");
+
+inline PathString ToPathString(const std::string& s) {
+  return ToWideString(s);
+}
+
+inline PathChar ToLowerPathChar(PathChar c) {
+  return std::towlower(c);
+}
+
+inline std::string PathToUTF8String(const PathString& s) {
+  return ToUTF8String(s);
+}
+
+#else
+
+static_assert(std::is_same<PathChar, char>::value, "PathString is not std::string!");
+
+inline PathChar ToLowerPathChar(PathChar c) {
+  return std::tolower(c);
+}
+
+inline std::string PathToUTF8String(const PathString& s) {
+  return s;
+}
+
+#endif
+
+}  // namespace onnxruntime

From dc6674b1b2fb368364f1c94d6e1420d086897465 Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Tue, 17 Dec 2024 00:10:57 -0800
Subject: [PATCH 81/81] refine ep plugin c++ wrapper (#23131)

---
 .../core/session/onnxruntime_c_api_ep.h       |  13 +-
 .../core/session/onnxruntime_cxx_api_ep.h     |  50 +++++--
 .../core/session/onnxruntime_cxx_inline_ep.h  | 131 +++++++++++++-----
 .../core/session/onnxruntime_c_api_ep.cc      |  15 +-
 onnxruntime/core/session/ort_apis_ep.h        |   4 +-
 .../tensorRTEp/tensorrt_execution_provider.cc |   4 +-
 6 files changed, 167 insertions(+), 50 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
index 77da384ac6671..8a59eaebfd210 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h
@@ -389,13 +389,27 @@ ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph);
 /** \brief Release the graph viewer instance.
  *
  * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocates model instead of
- * graph, this API releases graph's owning_model explicitly which in turn will release the graph
- * (because graph is hosted in an unique_ptr in Model class)
+ * graph, set release_model to true so that this API releases the graph's owning_model explicitly, which in turn
+ * releases the graph (because the graph is hosted in a unique_ptr in the Model class)
  *
  * \param[in] graph The graph to release
+ * \param[in] release_model If true, release the model as well; otherwise only release the graph viewer instance
+ *
+ */
+ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model);
+
+/** \brief Release the graph viewer array.
+ *
+ * NOTE!!: Invoke this function after the use of OrtNode_GetSubgraphs.
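+ *
+ * Illustrative pairing (graph_api stands for a pointer to the OrtGraphApi; OrtNode_GetSubgraphs' exact signature is assumed):
+ *   const OrtGraphViewer** subgraphs = nullptr;
+ *   size_t subgraph_count = 0;
+ *   graph_api->OrtNode_GetSubgraphs(node, &subgraphs, &subgraph_count);
+ *   // ... inspect subgraphs[0..subgraph_count) ...
+ *   graph_api->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count);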
  *
  */
-ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph);
+ORT_API2_STATUS(OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_array, size_t num_graphs);
 
 /** \brief Check whether two OrtGraphViewer instances actually point to the same underlying graph.
  *
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h
index 53298cbc605c1..7d206d16dbb3d 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h
@@ -8,6 +8,19 @@ namespace Ort {
 namespace PluginEP {
 
+using VoidPtr = std::unique_ptr<void, std::function<void(void*)>>;
+
+struct TensorRef {
+explicit TensorRef(OrtTensorRef*);
+~TensorRef();
+const std::vector<int64_t> GetShape();
+const ONNXTensorElementDataType GetTensorElementType();
+const char* GetData();
+size_t GetDataLen();
+private:
+OrtTensorRef* tensor_;
+};
+
 struct ValueInfoRef {
 explicit ValueInfoRef(OrtValueInfoRef*);
 ~ValueInfoRef();
@@ -18,8 +31,17 @@ OrtValueInfoRef* value_info_;
 };
 
 struct Graph {
-explicit Graph(const OrtGraphViewer*);
-const OrtGraphViewer* GetGraph() { return graph_; }
+explicit Graph(const OrtGraph*);
+const OrtGraph* GetGraph() { return graph_; }
+void DumpOnnxModel(const std::filesystem::path& onnx_model_path);
+private:
+const OrtGraph* graph_;
+};
+using GraphPtr = std::unique_ptr<Graph, std::function<void(Graph*)>>;
+
+struct GraphViewer;
+using GraphViewerPtr = std::unique_ptr<GraphViewer, std::function<void(GraphViewer*)>>;
+
+struct GraphViewer {
+explicit GraphViewer(const OrtGraphViewer*);
+const OrtGraphViewer* GetGraphViewer() { return graph_; }
 const char* GetName();
 bool IsConstantInitializer(const char* name, bool check_outer_scope);
 const std::vector<size_t> GetNodesIndexInTopologicalOrder(int execution_order);
@@ -37,17 +59,25 @@ int MaxNodeIndex();
 size_t GetOutputSize();
 std::string GetIthOutputName(size_t i);
 int32_t GetIthOutputElemType(size_t i);
-// std::shared_ptr<TensorRef> GetInitializerTensor(const char* initializer_name);
+std::shared_ptr<TensorRef> GetInitializerTensor(const char* initializer_name);
 std::shared_ptr<ValueInfoRef> GetValueInfo(const char* name);
-// void SerializeToArray(void** data, size_t* data_size);
-// void DumpOnnxModel(const std::filesystem::path& onnx_model_path);
-// CreateOrUpdateEpCtxGraph();
-std::shared_ptr<Graph> GetSubGraph(std::vector<size_t> node_indices);
-// bool IsSameGraph(const Graph& other);
+std::pair<VoidPtr, size_t> SerializeToArray();
+GraphPtr CreateOrUpdateEpCtxGraph(const char* node_name,
+                                  const int64_t main_context,
+                                  const int64_t embed_mode,
+                                  const char* cache_path,
+                                  char* cache_data,
+                                  size_t size,
+                                  const char* const* extra_attr_keys,
+                                  const char* const* extra_attr_values,
+                                  size_t extra_attr_num);
+GraphViewerPtr GetSubGraph(std::vector<size_t> node_indices);
+bool IsSameGraph(GraphViewer& other);
 private:
 const OrtGraphViewer* graph_;
 };
 
 struct Node {
 explicit Node(const OrtNode*);
@@ -74,12 +104,10 @@ int GetAttributeStringSize(std::string attribute_name);
 int64_t GetAttributeIthInt(std::string attribute_name, size_t i);
 float GetAttributeIthFloat(std::string attribute_name, size_t i);
 const std::string GetAttributeIthStr(std::string attribute_name, size_t i);
-// GetAttributeIthStrWithSize
 const std::string GetAttributeStr(std::string attribute_name);
-// GetAttributeStrWithSize
 int64_t GetAttributeInt(std::string attribute_name);
 float GetAttributeFloat(std::string attribute_name);
-// GetSubgraphs
+// TODO: add GetSubgraphs wrapper here
 private:
 const OrtNode* node_;
 };
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h
index 21f2834fb20d6..8c5e9b7978ca9
100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h @@ -8,6 +8,29 @@ namespace PluginEP { static const OrtGraphApi* ort_graph_api = GetApi().GetGraphApi(ORT_API_VERSION); +inline TensorRef::TensorRef(OrtTensorRef* tensor) : tensor_(tensor) {} + +inline TensorRef::~TensorRef() { + ort_graph_api->OrtGraph_ReleaseInitializerTensor(tensor_); +} + +inline const std::vector TensorRef::GetShape() { + std::vector shape(tensor_->shape, tensor_->shape + tensor_->shape_len); + return shape; +} + +inline const ONNXTensorElementDataType TensorRef::GetTensorElementType() { + return tensor_->data_type; +} + +inline const char* TensorRef::GetData() { + return tensor_->data; +} + +inline size_t TensorRef::GetDataLen() { + return tensor_->data_len; +} + inline ValueInfoRef::ValueInfoRef(OrtValueInfoRef* value_info) : value_info_(value_info) {} inline ValueInfoRef::~ValueInfoRef() { @@ -23,46 +46,52 @@ inline const ONNXTensorElementDataType ValueInfoRef::GetTensorElementType() { return value_info_->data_type; } -inline Graph::Graph(const OrtGraphViewer* graph) : graph_(graph) {} +inline Graph::Graph(const OrtGraph* graph) : graph_(graph) {} + +inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) { + ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_, onnx_model_path.c_str())); +} -inline const char* Graph::GetName() { +inline GraphViewer::GraphViewer(const OrtGraphViewer* graph) : graph_(graph) {} + +inline const char* GraphViewer::GetName() { const char* graph_name = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetName(graph_, &graph_name)); return graph_name; } -inline bool Graph::IsConstantInitializer(const char* name, bool check_outer_scope) { +inline bool GraphViewer::IsConstantInitializer(const char* name, bool check_outer_scope) { bool is_initializer = false; ThrowOnError(ort_graph_api->OrtGraph_IsConstantInitializer(graph_, name, check_outer_scope, &is_initializer)); return is_initializer; } -inline const std::vector Graph::GetNodesIndexInTopologicalOrder(int execution_order) { +inline const std::vector GraphViewer::GetNodesIndexInTopologicalOrder(int execution_order) { const size_t* nodes_index = nullptr; size_t nodes_count = 0; - ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, nodes_index, nodes_count)); + ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, &nodes_index, &nodes_count)); return std::vector(nodes_index, nodes_index + nodes_count); } -inline bool Graph::IsSubgraph() { +inline bool GraphViewer::IsSubgraph() { bool is_subgraph = false; ThrowOnError(ort_graph_api->OrtGraph_IsSubgraph(graph_, &is_subgraph)); return is_subgraph; } -inline std::shared_ptr Graph::GetParenNode() { +inline std::shared_ptr GraphViewer::GetParenNode() { const OrtNode* parent_node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetParenNode(graph_, &parent_node)); - return std::make_shared(parent_node); + return std::make_shared(parent_node); } -inline std::filesystem::path Graph::GetModelPath() { +inline std::filesystem::path GraphViewer::GetModelPath() { const void* model_path = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetModelPath(graph_, &model_path)); return *reinterpret_cast(model_path); } -inline std::vector Graph::GetRequiredInputs() { +inline std::vector GraphViewer::GetRequiredInputs() { const char** required_inputs = nullptr; size_t required_inputs_count = 0; 
ThrowOnError(ort_graph_api->OrtGraph_GetRequiredInputs(graph_, &required_inputs, &required_inputs_count)); @@ -78,7 +107,7 @@ inline std::vector Graph::GetRequiredInputs() { return ret; } -inline std::vector Graph::GetAllInputs() { +inline std::vector GraphViewer::GetAllInputs() { const char** all_inputs = nullptr; size_t all_inputs_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetAllInputs(graph_, &all_inputs, &all_inputs_count)); @@ -94,7 +123,7 @@ inline std::vector Graph::GetAllInputs() { return ret; } -inline std::vector Graph::GetAllInitializers() { +inline std::vector GraphViewer::GetAllInitializers() { const char** all_initializers = nullptr; size_t all_initializers_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetAllInitializers(graph_, &all_initializers, &all_initializers_count)); @@ -110,13 +139,13 @@ inline std::vector Graph::GetAllInitializers() { return ret; } -inline Ort::PluginEP::Node Graph::GetOrtNode(size_t node_index) { +inline Ort::PluginEP::Node GraphViewer::GetOrtNode(size_t node_index) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetOrtNode(graph_, node_index, &node)); return Ort::PluginEP::Node(node); } -inline std::vector Graph::GetNodesConsumingInput(const char* input_name) { +inline std::vector GraphViewer::GetNodesConsumingInput(const char* input_name) { const OrtNode** consumers = nullptr; size_t consumer_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetNodesConsumingInput(graph_, input_name, &consumers, &consumer_count)); @@ -129,64 +158,102 @@ inline std::vector Graph::GetNodesConsumingInput(const char return ret; } -inline Ort::PluginEP::Node Graph::GetNodeProducingOutput(const char* output_name) { +inline Ort::PluginEP::Node GraphViewer::GetNodeProducingOutput(const char* output_name) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetNodeProducingOutput(graph_, output_name, &node)); return Ort::PluginEP::Node(node); } -inline int Graph::NumberOfNodes() { +inline int GraphViewer::NumberOfNodes() { int num_nodes = 0; ThrowOnError(ort_graph_api->OrtGraph_NumberOfNodes(graph_, &num_nodes)); return num_nodes; } -inline int Graph::MaxNodeIndex() { +inline int GraphViewer::MaxNodeIndex() { int max_node_index = 0; ThrowOnError(ort_graph_api->OrtGraph_MaxNodeIndex(graph_, &max_node_index)); return max_node_index; } -inline size_t Graph::GetOutputSize() { +inline size_t GraphViewer::GetOutputSize() { size_t output_size = 0; ThrowOnError(ort_graph_api->OrtGraph_GetOutputSize(graph_, &output_size)); return output_size; } -inline std::string Graph::GetIthOutputName(size_t i) { +inline std::string GraphViewer::GetIthOutputName(size_t i) { const char* output_name = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputName(graph_, i, &output_name)); return std::string(output_name); } -inline int32_t Graph::GetIthOutputElemType(size_t i) { +inline int32_t GraphViewer::GetIthOutputElemType(size_t i) { int32_t elem_type = 0; ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputElemType(graph_, i, &elem_type)); return elem_type; } -inline std::shared_ptr Graph::GetValueInfo(const char* name) { +inline std::shared_ptr GraphViewer::GetInitializerTensor(const char* initializer_name) { + OrtTensorRef* tensor = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetInitializerTensor(graph_, initializer_name, &tensor)); + return std::make_shared(tensor); +} + +inline std::shared_ptr GraphViewer::GetValueInfo(const char* name) { OrtValueInfoRef* value_info = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetValueInfo(graph_, name, 
&value_info)); return std::make_shared(value_info); } -// inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) { -// ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_->GetGraph(), onnx_model_path.c_str())); -// } +inline std::pair GraphViewer::SerializeToArray() { + void* serialized_data = nullptr; + size_t serialized_data_len = 0; + ThrowOnError(ort_graph_api->OrtGraph_SerializeToArray(graph_, &serialized_data, &serialized_data_len)); + return std::make_pair(VoidPtr(serialized_data, [](void* ptr) { ort_graph_api->OrtFreeMem(ptr); }), serialized_data_len); +} + +inline GraphPtr GraphViewer::CreateOrUpdateEpCtxGraph(const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num) { + OrtGraph* graph = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_CreateOrUpdateEpCtxGraph(graph_, + node_name, + main_context, + embed_mode, + cache_path, + cache_data, + size, + extra_attr_keys, + extra_attr_values, + extra_attr_num, + &graph)); + auto release_fn = [](Graph* graph) { + ThrowOnError(ort_graph_api->OrtGraph_ReleaseGraph(graph->GetGraph())); + }; + return std::unique_ptr(new Graph(graph), release_fn); +} -inline std::shared_ptr Graph::GetSubGraph(std::vector node_indices) { +inline GraphViewerPtr GraphViewer::GetSubGraph(std::vector node_indices) { const OrtGraphViewer* subgraph = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetSubGraph(graph_, node_indices.size(), node_indices.data(), &subgraph)); - // TODO:yang if should release subgraph in the decstructor of Graph? - return std::make_shared(subgraph); + auto release_fn = [](GraphViewer* graph) { + ThrowOnError(ort_graph_api->OrtGraph_ReleaseGraphViewer(graph->GetGraphViewer(), true)); + }; + return std::unique_ptr(new GraphViewer(subgraph), release_fn); } -// inline bool Graph::IsSameGraph(const Graph& other) { -// bool is_same = false; -// ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraph(), &is_same)); -// return is_same; -// } +inline bool GraphViewer::IsSameGraph(GraphViewer& other) { + bool is_same = false; + ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraphViewer(), &is_same)); + return is_same; +} inline Node::Node(const OrtNode* node) : node_(node) {} diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index b52c1cae9046e..96dbd78438f56 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -775,15 +775,25 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_gra return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - delete &(graph_viewer->GetGraph()).GetModel(); + if (release_model) { + delete &(graph_viewer->GetGraph()).GetModel(); + } delete graph_viewer; } return nullptr; } +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_viewers, size_t num_graphs) { + for (size_t i = 0; i < num_graphs; i++) { + OrtGraph_ReleaseGraphViewer(graph_viewers[i], false); + } + delete[] graph_viewers; + return nullptr; +} + 
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, bool* is_same) { const ::onnxruntime::GraphViewer* graph_viewer1 = reinterpret_cast(graph1); const ::onnxruntime::GraphViewer* graph_viewer2 = reinterpret_cast(graph2); @@ -1021,6 +1031,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetSubGraph, &OrtGraphApis::OrtGraph_ReleaseGraph, &OrtGraphApis::OrtGraph_ReleaseGraphViewer, + &OrtGraphApis::OrtGraph_ReleaseGraphViewerArray, &OrtGraphApis::OrtGraph_IsSameGraph, &OrtGraphApis::OrtNode_GetName, &OrtGraphApis::OrtNode_GetDescription, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index dd6ae746db564..33c940242cc61 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -70,7 +70,9 @@ ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph); -ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model); + +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_viewers, size_t num_graphs); ORT_API_STATUS_IMPL(OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, _Out_ bool* is_same); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 76bcd5350b0d1..365ee37f22fcd 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1399,6 +1399,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); if (!all_subgraphs_are_supported) { // if not all its subgraphs are supported, we need to exclude this control flow op continue; @@ -1499,6 +1500,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); } if (all_subgraphs_are_supported) { @@ -3750,7 +3752,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } nodes_list_output.push_back(next_nodes_list[i]); } - graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer); + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); } } }
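For reference, a minimal sketch of how a plugin EP might drive the refined C++ wrapper from PATCH 81. This is illustrative only and not part of the patch series: InspectGraph and its call site are hypothetical, the include path is assumed from the samples' include directories, and only wrapper calls declared in the headers above are used.

// Minimal sketch (hypothetical driver code): walk a graph with the
// Ort::PluginEP wrapper and let the smart-pointer deleters handle the
// release calls that previously had to be made manually.
#include <string>
#include <vector>
#include "core/session/onnxruntime_cxx_api_ep.h"  // assumed include path

void InspectGraph(const OrtGraphViewer* ort_graph) {
  Ort::PluginEP::GraphViewer viewer(ort_graph);

  // Node indices come back as a plain vector; 0 selects the default execution order.
  const std::vector<size_t> order = viewer.GetNodesIndexInTopologicalOrder(0);

  for (size_t index : order) {
    Ort::PluginEP::Node node = viewer.GetOrtNode(index);
    // GetAttributeNames copies the names out; the wrapper releases the
    // underlying char array via ReleaseCharArray internally.
    for (const std::string& attr : node.GetAttributeNames()) {
      (void)attr;  // inspect attributes here
    }
  }

  // GetSubGraph returns a GraphViewerPtr whose deleter invokes
  // OrtGraph_ReleaseGraphViewer(..., true), so no manual release is needed.
  Ort::PluginEP::GraphViewerPtr sub = viewer.GetSubGraph(order);
}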