Skip to content

Commit

Permalink
add SessionOptionsAppendExecutionProvider_VitisAI
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenze Wang committed Jan 26, 2024
1 parent ccb708c commit 22ba64b
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 14 deletions.
12 changes: 12 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4558,6 +4558,18 @@ struct OrtApi {
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);

/** \brief Append VitisAI provider to session options
*
* If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure.
*
* \param[in] options
* \param[in] vitisai_options
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options, _In_ const void* vitisai_options);
};

/*
Expand Down
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/provider_factory_creators.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@
#include "core/providers/tvm/tvm_provider_factory_creator.h"
#endif

#if !defined(ORT_MINIMAL_BUILD)
#include "core/providers/vitisai/vitisai_provider_factory_creator.h"
#endif

#if defined(USE_XNNPACK)
#include "core/providers/xnnpack/xnnpack_provider_factory_creator.h"
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@

namespace onnxruntime {
struct VitisAIProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options);
static std::shared_ptr<IExecutionProviderFactory> Create(const void* provider_options);
};
} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2724,6 +2724,7 @@ static constexpr OrtApi ort_api_1_to_17 = {
&OrtApis::SetDeterministicCompute,
&OrtApis::KernelContext_ParallelFor,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,4 +509,7 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);

ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options, _In_ const void* vitisai_options);
} // namespace OrtApis
15 changes: 13 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1692,8 +1692,8 @@ std::shared_ptr<IExecutionProviderFactory> DnnlProviderFactoryCreator::Create(co
return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options);
}

std::shared_ptr<IExecutionProviderFactory> VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) {
return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options);
std::shared_ptr<IExecutionProviderFactory> VitisAIProviderFactoryCreator::Create(const void* provider_options) {
return s_library_vitisai.Get().CreateExecutionProviderFactory(provider_options);
}

ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() {
Expand Down Expand Up @@ -2578,3 +2578,14 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid
ORT_UNUSED_PARAMETER(ptr);
#endif
}

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, _In_ const void* provider_options) {
API_IMPL_BEGIN
auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library");
}
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
13 changes: 8 additions & 5 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
#if !defined(ORT_MINIMAL_BUILD)
options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, &provider_options);
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
Expand Down Expand Up @@ -499,4 +495,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString,
ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) {
ORT_UNUSED_PARAMETER(ptr);
}

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options, _In_ const void* provider_options) {
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(provider_options);
return CreateNotEnabledStatus("VitisAI");
}
#endif
5 changes: 4 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::MIGraphXProviderFactoryCreator::Create(0),
#endif
#ifdef USE_VITISAI
onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}),
[]() {
ProviderOptions provider_options_map;
return onnxruntime::VitisAIProviderFactoryCreator::Create(&provider_options_map);
}(),
#endif
#ifdef USE_ACL
onnxruntime::ACLProviderFactoryCreator::Create(0),
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider";
}
const auto& vitis_option_map = it->second;
return onnxruntime::VitisAIProviderFactoryCreator::Create(vitis_option_map)
return onnxruntime::VitisAIProviderFactoryCreator::Create(&vitis_option_map)
->CreateProvider();
#endif
} else if (type == kAclExecutionProvider) {
Expand Down

0 comments on commit 22ba64b

Please sign in to comment.