From ebdbbb7531f6be3a4df7901ff5482a8174b51bd7 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Fri, 20 Dec 2024 22:03:27 -0800 Subject: [PATCH] [VitisAI] Int4 support (#22850) ### Description 1. Add support for throwing error when hardware is not supported for VitisAI. 2. Add support for unloading VitisAI EP. 3. Add API for Win25. ### Motivation and Context This is requirement for Win25 --- .../shared_library/provider_interfaces.h | 1 + .../shared_library/provider_wrappedtypes.h | 3 ++ .../core/providers/vitisai/imp/global_api.cc | 52 ++++++++++++++++--- .../providers/vitisai/imp/tensor_proto.cc | 13 +++++ .../core/providers/vitisai/imp/tensor_proto.h | 4 ++ .../vitisai/include/vaip/global_api.h | 1 + .../providers/vitisai/include/vaip/my_ort.h | 1 + .../vitisai/include/vaip/vaip_ort_api.h | 10 +++- .../vitisai/vitisai_provider_factory.cc | 2 +- .../core/session/provider_bridge_ort.cc | 1 + 10 files changed, 80 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 8bd4067e59492..5a179ec622f8c 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -589,6 +589,7 @@ struct ProviderHost { virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0; // OrtSessionOptions virtual const std::unordered_map& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) = 0; + virtual bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) = 0; // ComputeCapability virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index d8516d5858a2f..76b6d8063fd66 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1476,5 +1476,8 @@ struct OrtSessionOptions final { const std::unordered_map& GetConfigOptions() const { return onnxruntime::g_host->SessionOptions__GetConfigOptionsMap(this); } + bool GetEnableProfiling() const { + return onnxruntime::g_host->SessionOptions__GetEnableProfiling(this); + } PROVIDER_DISALLOW_ALL(OrtSessionOptions) }; diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index cccaa65de45f2..8111ee3c1fe61 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -47,6 +47,8 @@ struct OrtVitisAIEpAPI { void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + std::vector>* (*compile_onnx_model_vitisai_ep_with_error_handling)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func); uint32_t (*vaip_get_version)(); void (*create_ep_context_nodes)( const std::vector>& eps, @@ -77,10 +79,11 @@ struct OrtVitisAIEpAPI { ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); #endif ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); - auto status = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); - if (!status.IsOK()) { - ::onnxruntime::LogRuntimeError(0, status, __FILE__, static_cast(__FUNCTION__), __LINE__); - ORT_THROW(status); + auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling); + auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); + if ((!status1.IsOK()) && (!status2.IsOK())) { + ::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status2); } std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); @@ -89,6 +92,14 @@ struct OrtVitisAIEpAPI { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options)); } + void Clear() { + if (handle_) { + auto& env = Provider_GetHost()->Env__Default(); + auto status = env.UnloadDynamicLibrary(handle_); + vai_assert(status.IsOK(), status.ErrorMessage()); + handle_ = nullptr; + } + } private: void* handle_{}; @@ -109,10 +120,25 @@ void profiler_collect( } } +void change_status_with_error(void* status_ptr, int error_code, const char* error_msg) { + auto status = reinterpret_cast(status_ptr); + *status = Status(onnxruntime::common::ONNXRUNTIME, error_code, error_msg); +} + vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { + const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) { auto model_path = graph_viewer.ModelPath().string(); - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); + if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) { + Status status = Status::OK(); + auto status_ptr = reinterpret_cast(&status); + auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error)); + if (!status.IsOK()) { + ORT_THROW(status); + } + return ret; + } else { + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); + } } std::optional> create_ep_context_nodes( @@ -396,10 +422,12 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; + the_global_api.tensor_proto_new_i4 = vaip::tensor_proto_new_i4; the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; the_global_api.tensor_proto_new_i16 = vaip::tensor_proto_new_i16; the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; + the_global_api.tensor_proto_new_u4 = vaip::tensor_proto_new_u4; the_global_api.tensor_proto_new_u8 = vaip::tensor_proto_new_u8; the_global_api.tensor_proto_new_u16 = vaip::tensor_proto_new_u16; the_global_api.tensor_proto_new_u32 = vaip::tensor_proto_new_u32; @@ -468,9 +496,21 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(local_str)); }; + the_global_api.is_profiling_enabled = [](void* session_options) { + auto options = reinterpret_cast(session_options); + return options->GetEnableProfiling(); + }; + the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) { + graph.RemoveInitializedTensor(tensor_name); + }; if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { return &the_global_api; } } + +void deinitialize_vitisai_ep() { + s_library_vitisaiep.Clear(); + s_kernel_registry_vitisaiep.reset(); +} diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 872d022e85264..bb942c69003a1 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -87,6 +87,12 @@ static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, co return tensor_proto.release(); } +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT4, + reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); +} + ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, const std::vector& data) { return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, @@ -108,6 +114,13 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); } + +ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT4, + reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); +} + ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector& shape, const std::vector& data) { return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT8, diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 618d9c4728e2f..73015d3411a54 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -9,6 +9,10 @@ namespace vaip { gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor); vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector& shape, + const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector& shape, diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 704b156dff57f..7791ea430054a 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -11,6 +11,7 @@ #include "vaip/custom_op.h" #include void initialize_vitisai_ep(); +void deinitialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 7628e45d2b933..85a1262d8489b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -122,4 +122,5 @@ using InitializedTensorSet = std::unordered_map; using ModelMetaData = std::unordered_map; +using error_report_func = void (*)(void*, int, const char*); } // namespace vaip_core diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 9425c08dceebc..6a51ef862280b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (12u) +#define VAIP_ORT_API_MAJOR (13u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -235,6 +235,14 @@ struct OrtApiForVaip { DllSafe (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96] void (*model_proto_delete)(ModelProto* p); // [97] DllSafe (*attr_proto_release_string)(AttributeProto* attr); // [98] + bool (*is_profiling_enabled)(void* session_options); // [99] // [98] + TensorProto* (*tensor_proto_new_i4)(const std::string& name, + const std::vector& shape, + const std::vector& data); // [100] + TensorProto* (*tensor_proto_new_u4)(const std::string& name, + const std::vector& shape, + const std::vector& data); // [101] + void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 453db30e1320f..99d9845302d9a 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -50,7 +50,7 @@ struct VitisAI_Provider : Provider { // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded void Initialize() override { initialize_vitisai_ep(); } // Called right before unloading the shared library - void Shutdown() override {} + void Shutdown() override { deinitialize_vitisai_ep(); } } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index a40fabd6a607c..af39edae2074d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -720,6 +720,7 @@ struct ProviderHostImpl : ProviderHost { // OrtSessionOptions (wrapped) const std::unordered_map& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) override { return p->value.config_options.configurations; } + bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) override { return p->value.enable_profiling; }; // ComputeCapability (wrapped) std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }