Skip to content

Commit

Permalink
[VitisAI] Int4 support (#22850)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
1. Add support for throwing error when hardware is not supported for
VitisAI.
2. Add support for unloading VitisAI EP.
3. Add API for Win25.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This is requirement for Win25
  • Loading branch information
BoarQing authored Dec 21, 2024
1 parent 6806174 commit ebdbbb7
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ struct ProviderHost {
virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0;
// OrtSessionOptions
virtual const std::unordered_map<std::string, std::string>& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) = 0;
virtual bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) = 0;
// ComputeCapability
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1476,5 +1476,8 @@ struct OrtSessionOptions final {
const std::unordered_map<std::string, std::string>& GetConfigOptions() const {
return onnxruntime::g_host->SessionOptions__GetConfigOptionsMap(this);
}
bool GetEnableProfiling() const {
return onnxruntime::g_host->SessionOptions__GetEnableProfiling(this);
}
PROVIDER_DISALLOW_ALL(OrtSessionOptions)
};
52 changes: 46 additions & 6 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ struct OrtVitisAIEpAPI {
void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector<OrtCustomOpDomain*>& ret_domain);
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_with_options)(
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options);
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_vitisai_ep_with_error_handling)(
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func);
uint32_t (*vaip_get_version)();
void (*create_ep_context_nodes)(
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
Expand Down Expand Up @@ -77,10 +79,11 @@ struct OrtVitisAIEpAPI {
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_));
#endif
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep));
auto status = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
if (!status.IsOK()) {
::onnxruntime::LogRuntimeError(0, status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
ORT_THROW(status);
auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling);
auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
if ((!status1.IsOK()) && (!status2.IsOK())) {
::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
ORT_THROW(status2);
}
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
(void**)&vaip_get_version);
Expand All @@ -89,6 +92,14 @@ struct OrtVitisAIEpAPI {
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start));
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options));
}
void Clear() {
if (handle_) {
auto& env = Provider_GetHost()->Env__Default();
auto status = env.UnloadDynamicLibrary(handle_);
vai_assert(status.IsOK(), status.ErrorMessage());
handle_ = nullptr;
}
}

private:
void* handle_{};
Expand All @@ -109,10 +120,25 @@ void profiler_collect(
}
}

void change_status_with_error(void* status_ptr, int error_code, const char* error_msg) {
auto status = reinterpret_cast<Status*>(status_ptr);
*status = Status(onnxruntime::common::ONNXRUNTIME, error_code, error_msg);
}

vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(
const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) {
const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) {
auto model_path = graph_viewer.ModelPath().string();
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) {
Status status = Status::OK();
auto status_ptr = reinterpret_cast<void*>(&status);
auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error));
if (!status.IsOK()) {
ORT_THROW(status);
}
return ret;
} else {
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
}
}

std::optional<std::vector<Node*>> create_ep_context_nodes(
Expand Down Expand Up @@ -396,10 +422,12 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape;
the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); };
the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; };
the_global_api.tensor_proto_new_i4 = vaip::tensor_proto_new_i4;
the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8;
the_global_api.tensor_proto_new_i16 = vaip::tensor_proto_new_i16;
the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32;
the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64;
the_global_api.tensor_proto_new_u4 = vaip::tensor_proto_new_u4;
the_global_api.tensor_proto_new_u8 = vaip::tensor_proto_new_u8;
the_global_api.tensor_proto_new_u16 = vaip::tensor_proto_new_u16;
the_global_api.tensor_proto_new_u32 = vaip::tensor_proto_new_u32;
Expand Down Expand Up @@ -468,9 +496,21 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
return vaip_core::DllSafe<std::string>(std::move(local_str));
};

the_global_api.is_profiling_enabled = [](void* session_options) {
auto options = reinterpret_cast<OrtSessionOptions*>(session_options);
return options->GetEnableProfiling();
};
the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) {
graph.RemoveInitializedTensor(tensor_name);
};
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
return &the_global_api;
}
}

void deinitialize_vitisai_ep() {
s_library_vitisaiep.Clear();
s_kernel_registry_vitisaiep.reset();
}
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, co
return tensor_proto.release();
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT4,
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8,
Expand All @@ -108,6 +114,13 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64,
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT4,
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT8,
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/tensor_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ namespace vaip {
gsl::span<const char> tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor);
vaip_core::DllSafe<std::vector<int64_t>> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor);
const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor);
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int8_t>& data);
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data);
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int8_t>& data);
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector<int64_t>& shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "vaip/custom_op.h"
#include <optional>
void initialize_vitisai_ep();
void deinitialize_vitisai_ep();
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options);
std::shared_ptr<onnxruntime::KernelRegistry> get_kernel_registry_vitisaiep();
const std::vector<OrtCustomOpDomain*>& get_domains_vitisaiep();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,5 @@ using InitializedTensorSet =
std::unordered_map<std::string, const TensorProto*>;

using ModelMetaData = std::unordered_map<std::string, std::string>;
using error_report_func = void (*)(void*, int, const char*);
} // namespace vaip_core
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (12u)
#define VAIP_ORT_API_MAJOR (13u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
Expand Down Expand Up @@ -235,6 +235,14 @@ struct OrtApiForVaip {
DllSafe<std::string> (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96]
void (*model_proto_delete)(ModelProto* p); // [97]
DllSafe<std::string> (*attr_proto_release_string)(AttributeProto* attr); // [98]
bool (*is_profiling_enabled)(void* session_options); // [99] // [98]
TensorProto* (*tensor_proto_new_i4)(const std::string& name,
const std::vector<int64_t>& shape,
const std::vector<int8_t>& data); // [100]
TensorProto* (*tensor_proto_new_u4)(const std::string& name,
const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data); // [101]
void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102]
};

#ifndef USE_VITISAI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct VitisAI_Provider : Provider {
// Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded
void Initialize() override { initialize_vitisai_ep(); }
// Called right before unloading the shared library
void Shutdown() override {}
void Shutdown() override { deinitialize_vitisai_ep(); }
} g_provider;

} // namespace onnxruntime
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ struct ProviderHostImpl : ProviderHost {

// OrtSessionOptions (wrapped)
const std::unordered_map<std::string, std::string>& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) override { return p->value.config_options.configurations; }
bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) override { return p->value.enable_profiling; };
// ComputeCapability (wrapped)
std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
Expand Down

0 comments on commit ebdbbb7

Please sign in to comment.