From 22ba64bedc7664525026eb2d219d164c13d44f93 Mon Sep 17 00:00:00 2001 From: Zhenze Wang Date: Fri, 26 Jan 2024 03:36:59 +0000 Subject: [PATCH] add SessionOptionsAppendExecutionProvider_VitisAI --- .../onnxruntime/core/session/onnxruntime_c_api.h | 12 ++++++++++++ .../core/providers/provider_factory_creators.h | 4 ---- .../vitisai/vitisai_provider_factory_creator.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 3 +++ onnxruntime/core/session/provider_bridge_ort.cc | 15 +++++++++++++-- onnxruntime/core/session/provider_registration.cc | 13 ++++++++----- onnxruntime/python/onnxruntime_pybind_schema.cc | 5 ++++- onnxruntime/python/onnxruntime_pybind_state.cc | 2 +- 9 files changed, 43 insertions(+), 14 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b321b2b2bac27..2d83616429b50 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4558,6 +4558,18 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] vitisai_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_ const void* vitisai_options); }; /* diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index cc4b92da186e1..6a4ab6a3d2113 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tvm/tvm_provider_factory_creator.h" #endif -#if !defined(ORT_MINIMAL_BUILD) -#include "core/providers/vitisai/vitisai_provider_factory_creator.h" -#endif - #if defined(USE_XNNPACK) #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h index 9bb7cfa062a0f..bf294dd622e89 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h @@ -10,6 +10,6 @@ namespace onnxruntime { struct VitisAIProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const void* provider_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d77c188f832a7..947bb30d42f70 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2724,6 +2724,7 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::SetDeterministicCompute, &OrtApis::KernelContext_ParallelFor, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, + &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c1caafa4dcad3..22897ba4fe145 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -509,4 +509,7 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_ const void* vitisai_options); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 92777adcd55aa..636640d992a3d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1692,8 +1692,8 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(co return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); } -std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { - return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +std::shared_ptr VitisAIProviderFactoryCreator::Create(const void* provider_options) { + return s_library_vitisai.Get().CreateExecutionProviderFactory(provider_options); } ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { @@ -2578,3 +2578,14 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid ORT_UNUSED_PARAMETER(ptr); #endif } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, _In_ const void* provider_options) { + API_IMPL_BEGIN + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); + } + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 7fc0f70b88c89..7f743fc7feb87 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -150,11 +150,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "VitisAI") == 0) { -#if !defined(ORT_MINIMAL_BUILD) - options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options)); -#else - status = create_not_supported_status(); -#endif + OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, &provider_options); } else { ORT_UNUSED_PARAMETER(options); status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, @@ -499,4 +495,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_ const void* provider_options) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options); + return CreateNotEnabledStatus("VitisAI"); +} #endif diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 3a977772873f3..2ab796f618929 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -50,7 +50,10 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::MIGraphXProviderFactoryCreator::Create(0), #endif #ifdef USE_VITISAI - onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}), + []() { + ProviderOptions provider_options_map; + return onnxruntime::VitisAIProviderFactoryCreator::Create(&provider_options_map); + }(), #endif #ifdef USE_ACL onnxruntime::ACLProviderFactoryCreator::Create(0), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d2cd6140b838e..dac98f35b61c6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -989,7 +989,7 @@ std::unique_ptr CreateExecutionProviderInstance( LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider"; } const auto& vitis_option_map = it->second; - return onnxruntime::VitisAIProviderFactoryCreator::Create(vitis_option_map) + return onnxruntime::VitisAIProviderFactoryCreator::Create(&vitis_option_map) ->CreateProvider(); #endif } else if (type == kAclExecutionProvider) {