diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c900f4d4b09a5..2ead13e554197 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,7 +189,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} - ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 0951c2d02664d..183a3e196af42 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -14,14 +14,19 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) + onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) + target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) + if(MSVC) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY 
LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") + else(MSVC) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") + endif(MSVC) target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) if(MSVC) @@ -30,17 +35,18 @@ target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") # for unused formal parameter target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") + # for type name first seen using 'class' now seen using 'struct' + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) endif(MSVC) set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 2e3594f256f65..456344aa34d95 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,7 +170,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE 
onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_TVM} - ${PROVIDERS_VITISAI} ${PROVIDERS_NNAPI} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} @@ -852,6 +851,16 @@ if (onnxruntime_USE_DNNL) ) endif() +if (onnxruntime_USE_VITISAI) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_TENSORRT) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fa395802d95ff..18a6955165bf3 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -591,7 +591,6 @@ set(ONNXRUNTIME_TEST_LIBS # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} ${PROVIDERS_JS} - ${PROVIDERS_VITISAI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b321b2b2bac27..6c8b5841284a1 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4558,6 +4558,23 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. 
+ * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 7a553f9f94006..ae4c4bef90c64 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -901,6 +901,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 957e849cf5d4d..23246adff254a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -885,6 +885,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& 
entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs) { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 42a58097e1635..6a4ab6a3d2113 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tvm/tvm_provider_factory_creator.h" #endif -#if defined(USE_VITISAI) -#include "core/providers/vitisai/vitisai_provider_factory_creator.h" -#endif - #if defined(USE_XNNPACK) #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 53ba4874c643c..23404fe8d2ad1 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -95,12 +95,15 @@ enum OperatorStatus : int { }; // onnx Protobuf types (All of these are direct mappings to the onnx types except for the Repeated*Field ones which map to a Repeated*Field type) -struct int64s; // RepeatedField +struct int64s; // RepeatedField +struct float32s; // RepeatedField struct AttributeProto; struct GraphProto; struct ModelProto; struct NodeProto; struct SparseTensorProto; +struct StringStringEntryProto; +struct StringStringEntryProtos; // RepeatedPtrField struct TensorProto; struct TensorProtos; // RepeatedPtrField struct TensorShapeProto_Dimension; @@ -113,6 +116,9 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // 
RepeatedPtrField +struct InferenceContext; +class GraphInferencer; +using InferenceFunction = std::function; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -248,6 +254,7 @@ constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; +constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index a3155fe6b86cf..0bc7eb4512e27 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -496,6 +496,10 @@ template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) { + return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); +} } // namespace utils diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h 
b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 21c14ce784a38..d34ee0bb3b747 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -91,6 +91,7 @@ using HashValue = uint64_t; using NodeIndex = size_t; // We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types) // using NodeAttributes = std::unordered_map; +using ModelMetaData = std::unordered_map; using InitializedTensorSet = std::unordered_map; @@ -201,6 +202,8 @@ struct ProviderHost { virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) = 0; + virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) = 0; virtual uint16_t math__floatToHalf(float f) = 0; virtual float math__halfToFloat(uint16_t h) = 0; @@ -263,12 +266,32 @@ struct ProviderHost { virtual void logging__Capture__operator_delete(logging::Capture* p) noexcept = 0; virtual std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept = 0; + // Env + virtual Env& Env__Default() = 0; + // Utils::DataTypeUtils virtual const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) = 0; // int64s virtual int int64s__size(const ONNX_NAMESPACE::int64s* p) = 0; virtual const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) = 0; + virtual void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) = 
0; + virtual const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) = 0; + + // float32s + virtual void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) = 0; + virtual const float* float32s__data(const ONNX_NAMESPACE::float32s* p) = 0; + virtual int float32s__size(const ONNX_NAMESPACE::float32s* p) = 0; + + // StringStringEntryProto + virtual std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + virtual std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + + // StringStringEntryProtos + virtual void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional @@ -285,6 +308,7 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) = 0; #if !defined(DISABLE_SPARSE_TENSORS) // TypeProto_SparseTensor @@ -329,9 +353,17 @@ struct ProviderHost { virtual float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ::std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual 
const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t size) = 0; + virtual void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float size) = 0; + virtual void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& size) = 0; virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; + virtual void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) = 0; virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0; virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; @@ -354,6 +386,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0; // ModelProto @@ -369,6 +402,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::GraphProto* 
ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) = 0; virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; // NodeProto virtual std::unique_ptr NodeProto__construct() = 0; @@ -383,19 +417,33 @@ struct ProviderHost { virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0; virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) = 0; + virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) = 0; virtual bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) = 0; virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void 
TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto_DataType_IsValid(int value) = 0; // TensorProtos virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0; // TensorShapeProto_Dimension virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; @@ -405,6 +453,8 @@ struct ProviderHost { virtual bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual void TensorShapeProto_Dimension__clear_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; + virtual const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const = 0; + virtual void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) = 0; // TensorShapeProto_Dimensions virtual std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; @@ -428,6 +478,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + virtual void RegisterSchema(const std::string& domain, const 
OrtCustomOp* op, int type) = 0; + // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; @@ -653,6 +705,7 @@ struct ProviderHost { virtual void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) = 0; virtual const NodeAttributes& Node__GetAttributes(const Node* p) noexcept = 0; + virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) = 0; virtual size_t Node__GetInputEdgesCount(const Node* p) noexcept = 0; virtual size_t Node__GetOutputEdgesCount(const Node* p) noexcept = 0; @@ -662,10 +715,13 @@ struct ProviderHost { virtual std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept = 0; virtual void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) = 0; + virtual int Node__NodeType(const Node* p) const noexcept = 0; virtual const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0; virtual std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0; @@ -676,6 +732,7 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept = 0; virtual bool NodeArg__Exists(const NodeArg* p) const noexcept = 0; virtual const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept = 0; + virtual Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) = 0; // 
NodeAttributes virtual std::unique_ptr NodeAttributes__construct() = 0; @@ -693,12 +750,18 @@ struct ProviderHost { virtual std::unique_ptr NodeAttributes__find(const NodeAttributes* p, const std::string& key) = 0; virtual void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) = 0; virtual void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; + virtual void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0; // Model + virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, + const PathString& model_path, const logging::Logger& logger) = 0; virtual void Model__operator_delete(Model* p) = 0; virtual Graph& Model__MainGraph(Model* p) = 0; virtual std::unique_ptr Model__ToProto(Model* p) = 0; + virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) = 0; + virtual const ModelMetaData& Model__MetaData(const Model* p) const noexcept = 0; + virtual Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) = 0; // Graph virtual std::unique_ptr Graph__CreateGraphViewer(const Graph* p) = 0; @@ -716,6 +779,7 @@ struct ProviderHost { virtual void Graph__SetOutputs(Graph* p, gsl::span outputs) = 0; virtual const std::vector& Graph__GetInputs(const Graph* p) noexcept = 0; + virtual std::vector Graph__Nodes(const Graph* p) = 0; virtual bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) = 0; virtual const Node* Graph__ParentNode(const Graph* p) const = 0; @@ -725,6 +789,26 @@ struct ProviderHost { virtual const Path& Graph__ModelPath(const Graph* p) const = 0; virtual const std::vector& 
Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; virtual bool Graph__IsSubgraph(const Graph* p) = 0; + virtual const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const = 0; + virtual const Model& Graph__GetModel(const Graph* p) = 0; + virtual void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const = 0; + virtual Graph& Graph__SetGraphResolveNeeded(Graph* p) = 0; + virtual void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) = 0; + + virtual std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const = 0; + virtual void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveNode(Graph* p, NodeIndex index) = 0; + virtual Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) = 0; + virtual void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) = 0; + virtual const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const = 0; + virtual const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) = 0; virtual int Graph__MaxNodeIndex(const Graph* p) const noexcept = 0; virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0; virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0; @@ -759,11 +843,14 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; virtual void GraphViewer__ToProto(const 
GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; // Path virtual PathString Path__ToPathString(const Path* p) noexcept = 0; virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; virtual bool Path__IsEmpty(const Path* p) noexcept = 0; + virtual std::unique_ptr Path__construct() = 0; + virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index eaf8ef459cf00..5f8930b64fe90 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -52,11 +52,34 @@ namespace ONNX_NAMESPACE { struct int64s final { int size() const { return g_host->int64s__size(this); } const int64_t& Get(int index) const { return g_host->int64s__Get(this, index); } + const int64_t* data() const { return g_host->int64s__data(this); } const int64_t& operator[](int index) const { return Get(index); } - + void Reserve(int size) { g_host->int64s__Reserve(this, size); } PROVIDER_DISALLOW_ALL(int64s) }; +struct float32s final { + void Reserve(int size) { g_host->float32s__Reserve(this, size); } + const float* data() const { return g_host->float32s__data(this); } + int size() const { return g_host->float32s__size(this); } + PROVIDER_DISALLOW_ALL(float32s) +}; + +struct StringStringEntryProto final { + std::string* mutable_key() { return g_host->StringStringEntryProto__mutable_key(this); } + std::string* mutable_value() { return g_host->StringStringEntryProto__mutable_value(this); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProto) +}; + +struct StringStringEntryProtos final { 
+ void Clear() { g_host->StringStringEntryProtos__Clear(this); } + StringStringEntryProto* Add() { return g_host->StringStringEntryProtos__Add(this); } + int size() { return g_host->StringStringEntryProtos__size(this); } + StringStringEntryProto& at(int index) { return g_host->StringStringEntryProtos__at(this, index); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProtos) +}; struct AttributeProto final { static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } @@ -71,9 +94,18 @@ struct AttributeProto final { float floats(int i) const { return g_host->AttributeProto__floats(this, i); } const std::string& strings(int i) const { return g_host->AttributeProto__strings(this, i); } const int64s& ints() const { return g_host->AttributeProto__ints(this); } + const float32s& floats() const { return g_host->AttributeProto__floats(this); } + int64s* mutable_ints() { return g_host->AttributeProto__mutable_ints(this); } + float32s* mutable_floats() { return g_host->AttributeProto__mutable_floats(this); } + void add_ints(int64_t value) { g_host->AttributeProto__add_ints(this, value); } + void add_floats(float value) { g_host->AttributeProto__add_floats(this, value); } + void add_strings(const ::std::string& value) { g_host->AttributeProto__add_strings(this, value); } + int64_t i() const { return g_host->AttributeProto__i(this); } float f() const { return g_host->AttributeProto__f(this); } + const ONNX_NAMESPACE::TensorProto& t() const { return g_host->AttributeProto__t(this); } void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); } + void set_f(const float& value) { return g_host->AttributeProto__set_f(this, value); } void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); } const ::std::string& s() const { return g_host->AttributeProto__s(this); } void set_name(const ::std::string& value) { return 
g_host->AttributeProto__set_name(this, value); } @@ -121,6 +153,8 @@ struct GraphProto final { NodeProto* add_node() { return g_host->GraphProto__add_node(this); } NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); } + std::string* mutable_name() { return g_host->GraphProto__mutable_name(this); } + GraphProto() = delete; GraphProto(const GraphProto&) = delete; }; @@ -133,7 +167,7 @@ struct ModelProto final { bool SerializeToOstream(std::ostream& output) const { return g_host->ModelProto__SerializeToOstream(this, output); } bool ParseFromString(const std::string& data) { return g_host->ModelProto__ParseFromString(this, data); } std::string SerializeAsString() const { return g_host->ModelProto__SerializeAsString(this); } - + StringStringEntryProtos* mutable_metadata_props() { return g_host->ModelProto__mutable_metadata_props(this); }; const GraphProto& graph() const { return g_host->ModelProto__graph(this); } GraphProto* mutable_graph() { return g_host->ModelProto__mutable_graph(this); } @@ -162,17 +196,22 @@ struct TensorProto final { void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); } bool has_name() const { return g_host->TensorProto__has_name(this); } + void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); } + const ::std::string& name() const { return g_host->TensorProto__name(this); } int dims_size() const { return g_host->TensorProto__dims_size(this); } const int64s& dims() const { return g_host->TensorProto__dims(this); } + void add_dims(int64_t value) { g_host->TensorProto__add_dims(this, value); } bool has_data_location() const { return g_host->TensorProto__has_data_location(this); } TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); } bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); } const std::string& raw_data() const { return 
g_host->TensorProto__raw_data(this); } + std::string* mutable_raw_data() { return g_host->TensorProto__mutable_raw_data(this); } int32_t data_type() const { return g_host->TensorProto__data_type(this); } + void set_data_type(int32_t type) { return g_host->TensorProto__set_data_type(this, type); } typedef TensorProto_DataType DataType; static constexpr DataType UNDEFINED = TensorProto_DataType_UNDEFINED; @@ -180,6 +219,13 @@ struct TensorProto final { static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); } void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); } + StringStringEntryProtos* mutable_external_data() { return g_host->TensorProto__mutable_external_data(this); }; + void clear_float_data() { return g_host->TensorProto__clear_float_data(this); } + void clear_int32_data() { return g_host->TensorProto__clear_int32_data(this); } + void clear_string_data() { return g_host->TensorProto__clear_string_data(this); } + void clear_int64_data() { return g_host->TensorProto__clear_int64_data(this); } + void clear_double_data() { return g_host->TensorProto__clear_double_data(this); } + void clear_uint64_data() { return g_host->TensorProto__clear_uint64_data(this); } TensorProto() = delete; TensorProto(const TensorProto&) = delete; @@ -187,6 +233,8 @@ struct TensorProto final { struct TensorProtos final { TensorProto* Add() { return g_host->TensorProtos__Add(this); } + int size() { return g_host->TensorProtos__size(this); } + TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); } PROVIDER_DISALLOW_ALL(TensorProtos) }; @@ -205,6 +253,8 @@ struct TensorShapeProto_Dimension final { bool has_dim_value() const { return g_host->TensorShapeProto_Dimension__has_dim_value(this); } bool has_dim_param() const { return g_host->TensorShapeProto_Dimension__has_dim_param(this); } void clear_dim_value() { return g_host->TensorShapeProto_Dimension__clear_dim_value(this); } + const 
std::string& denotation() const { return g_host->TensorShapeProto_Dimension__denotation(this); } + void set_denotation(const std::string& value) { g_host->TensorShapeProto_Dimension__set_denotation(this, value); } PROVIDER_DISALLOW_ALL(TensorShapeProto_Dimension) }; @@ -232,6 +282,7 @@ struct TypeProto_Tensor final { const TensorShapeProto& shape() const { return g_host->TypeProto_Tensor__shape(this); } TensorShapeProto* mutable_shape() { return g_host->TypeProto_Tensor__mutable_shape(this); } int32_t elem_type() const { return g_host->TypeProto_Tensor__elem_type(this); } + void set_elem_type(int32_t value) { g_host->TypeProto_Tensor__set_elem_type(this, value); } PROVIDER_DISALLOW_ALL(TypeProto_Tensor) }; @@ -315,7 +366,6 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; - } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -603,6 +653,10 @@ struct Function final { }; struct Node final { + enum class Type { + Primitive = 0, + Fused = 1, + }; const std::string& Name() const noexcept { return g_host->Node__Name(this); } const std::string& Description() const noexcept { return g_host->Node__Description(this); } const std::string& Domain() const noexcept { return g_host->Node__Domain(this); } @@ -626,6 +680,10 @@ struct Node final { void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const { return g_host->Node__ToProto(this, proto, update_subgraphs); } const NodeAttributes& GetAttributes() const noexcept { return g_host->Node__GetAttributes(this); } + void AddAttribute(const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) { + g_host->Node__AddAttribute(this, attr_name, value); + } + size_t GetInputEdgesCount() const noexcept { return g_host->Node__GetInputEdgesCount(this); } size_t GetOutputEdgesCount() const noexcept { return g_host->Node__GetOutputEdgesCount(this); } @@ -661,12 +719,15 @@ struct Node final { std::unique_ptr impl_; }; + EdgeConstIterator InputEdgesBegin() const noexcept { 
return g_host->Node__InputEdgesBegin(this); } + EdgeConstIterator InputEdgesEnd() const noexcept { return g_host->Node__InputEdgesEnd(this); } EdgeConstIterator OutputEdgesBegin() const noexcept { return g_host->Node__OutputEdgesBegin(this); } EdgeConstIterator OutputEdgesEnd() const noexcept { return g_host->Node__OutputEdgesEnd(this); } void ForEachDef(std::function func, bool include_missing_optional_defs = false) const { g_host->Node__ForEachDef(this, func, std::move(include_missing_optional_defs)); } const std::unordered_map>& GetAttributeNameToMutableSubgraphMap() { return g_host->Node__GetAttributeNameToMutableSubgraphMap(this); } std::unordered_map> GetAttributeNameToSubgraphMap() const { return g_host->Node__GetAttributeNameToSubgraphMap(this); } + Type NodeType() const noexcept { return Type(g_host->Node__NodeType(this)); } PROVIDER_DISALLOW_ALL(Node) }; @@ -678,6 +739,7 @@ struct NodeArg final { const NodeArgInfo& ToProto() const noexcept { return g_host->NodeArg__ToProto(this); } bool Exists() const noexcept { return g_host->NodeArg__Exists(this); } const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept { return g_host->NodeArg__TypeAsProto(this); } + Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) { return g_host->NodeArg__OverrideTypesHelper(this, input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); } PROVIDER_DISALLOW_ALL(NodeArg) }; @@ -698,6 +760,8 @@ struct NodeAttributes final { IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); } void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); } + void insert_or_assign(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { 
g_host->NodeAttributes__insert_or_assign(this, k, v); } + void reserve(size_t size) { g_host->NodeAttributes__reserve(this, size); } NodeAttributes() = delete; @@ -705,11 +769,18 @@ struct NodeAttributes final { }; struct Model final { + static std::unique_ptr Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) { + return g_host->Model__construct(std::move(model_proto), model_path, logger); + } static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast(p)); } + static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); } Graph& MainGraph() { return g_host->Model__MainGraph(this); } std::unique_ptr ToProto() { return g_host->Model__ToProto(this); } + std::unique_ptr ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } + const ModelMetaData& MetaData() const noexcept { return g_host->Model__MetaData(this); } Model() = delete; Model(const Model&) = delete; @@ -732,6 +803,7 @@ struct Graph final { void SetOutputs(gsl::span outputs) { return g_host->Graph__SetOutputs(this, outputs); } const std::vector& GetInputs() const noexcept { return g_host->Graph__GetInputs(this); } + std::vector Nodes() const noexcept { return g_host->Graph__Nodes(this); } bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { return g_host->Graph__GetInitializedTensor(this, tensor_name, value); } @@ -742,6 +814,37 @@ struct Graph final { const Path& ModelPath() const { return g_host->Graph__ModelPath(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } bool 
IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->Graph__GetProducerNode(this, node_arg_name); } + const Model& GetModel() const { return g_host->Graph__GetModel(this); } + void ReverseDFSFrom(gsl::span from, const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const { + g_host->Graph__ReverseDFSFrom(this, from, enter, leave, comp, stop); + } + Graph& SetGraphResolveNeeded() { return g_host->Graph__SetGraphResolveNeeded(this); } + void RemoveInitializedTensor(const std::string& tensor_name) { g_host->Graph__RemoveInitializedTensor(this, tensor_name); } + + std::vector GetConsumerNodes(const std::string& node_arg_name) const { + return g_host->Graph__GetConsumerNodes(this, node_arg_name); + } + void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__AddEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__RemoveEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveNode(NodeIndex index) { g_host->Graph__RemoveNode(this, index); } + Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { + return g_host->Graph__FuseSubGraph(this, sub_graph, fused_node_name); + } + void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) { + g_host->Graph__UpdateProducerNode(this, node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const { + return g_host->Graph__GetConstantInitializer(this, name, check_outer_scope); + } + const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return 
g_host->Graph__GetAllInitializedTensors(this); } int MaxNodeIndex() const noexcept { return g_host->Graph__MaxNodeIndex(this); } const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); } Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); } @@ -782,6 +885,7 @@ struct GraphViewer final { const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; @@ -789,11 +893,16 @@ struct GraphViewer final { }; struct Path final { + static std::unique_ptr Create() { return g_host->Path__construct(); } + static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } + PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); } const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } - PROVIDER_DISALLOW_ALL(Path) + Path() = delete; + Path(const Path&) = delete; + void operator=(const Path&) = delete; }; struct OpKernelContext final { diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc index 29bc886fb5ed4..1392ecef1b72d 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc @@ -2,126 +2,106 @@ // Licensed under the MIT License. 
#include "./attr_proto.h" -#include "./vai_assert.h" - #include #include #include #include -namespace vaip { +#include "core/providers/shared_library/provider_api.h" -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value) { - auto ret = new onnx::AttributeProto(); +#include "./vai_assert.h" + +namespace vaip { +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); ret->set_i(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOAT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); ret->set_f(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string( - const std::string& name, const std::string& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRING); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); ret->set_s(value); - return ret; + return ret.release(); } ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( const std::string& name, const ONNX_NAMESPACE::TensorProto& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - 
ret->set_type(onnx::AttributeProto_AttributeType_TENSOR); - *ret->mutable_t() = value; - return ret; + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); + *ret->add_tensors() = value; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INTS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS); ret->mutable_ints()->Reserve((int)value.size()); for (auto v : value) { ret->add_ints(v); } - return ret; + return ret.release(); } - ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOATS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS); ret->mutable_floats()->Reserve((int)value.size()); for (auto v : value) { ret->add_floats(v); } - return ret; + return ret.release(); } - -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRINGS); - ret->mutable_strings()->Reserve((int)value.size()); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS); for (auto& v : value) { ret->add_strings(v); } - return ret; + return ret.release(); } - -int64_t attr_proto_get_int(const onnx::AttributeProto& attr) { - 
vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INT, attr.DebugString()); +int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT, attr.name()); return attr.i(); } - -float attr_proto_get_float(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOAT, attr.DebugString()); +float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT, attr.name()); return attr.f(); } - -const std::string& attr_proto_get_string(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRING, attr.DebugString()); +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr.name()); return attr.s(); } - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_TENSOR, attr.DebugString()); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR, attr.name()); return attr.t(); } - -gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INTS, attr.DebugString()); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INTS, attr.name()); return gsl::span(attr.ints()); } - -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOATS, attr.DebugString()); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& 
attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS, attr.name()); return gsl::span(attr.floats()); } - -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRINGS, attr.DebugString()); - return std::vector(attr.strings().begin(), attr.strings().end()); -} - -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, - int64_t value) { - ONNX_NAMESPACE::AttributeProto ret; - ret.set_name(name); - ret.set_i(value); +std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS, attr.name()); + std::vector ret; + ret.reserve(attr.strings_size()); + for (int i = 0; i < attr.strings_size(); i++) { + ret.push_back(attr.strings(i)); + } return ret; } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index 32ba8fa672d74..f4d56dd618a8c 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -2,46 +2,26 @@ // Licensed under the MIT License. 
#pragma once #include - +#include "vaip/my_ort.h" #include "core/common/gsl.h" -#include "onnx/onnx_pb.h" namespace vaip { -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, - const std::string& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( - const std::string& name, const ONNX_NAMESPACE::TensorProto& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor(const std::string& name, const ONNX_NAMESPACE::TensorProto& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value); /// attr_proto getters int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr); float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr); -const std::string& attr_proto_get_string( - const ONNX_NAMESPACE::AttributeProto& attr); - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr); -gsl::span 
attr_proto_get_ints(const onnx::AttributeProto& attr); -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr); -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr); - -/// attr_proto makers -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, - int64_t); - -/// -using attr_proto_func_t = std::function; +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); +std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc index a55180bd2ee5e..58522a45a151e 100644 --- a/onnxruntime/core/providers/vitisai/imp/capability.cc +++ b/onnxruntime/core/providers/vitisai/imp/capability.cc @@ -3,15 +3,10 @@ #include "vaip/capability.h" #include "./vai_assert.h" -#include "core/graph/basic_types.h" - -#include "./attr_proto.h" - namespace vaip { using namespace ::onnxruntime; -static std::vector node_names_to_nodes(const GraphViewer& graph, - const std::vector& node_names) { +static std::vector node_names_to_nodes(const GraphViewer& graph, const std::vector& node_names) { auto ret = std::vector(); ret.reserve(node_names.size()); for (auto& onnx_node_name : node_names) { @@ -24,53 +19,45 @@ static std::vector node_names_to_nodes(const GraphViewer& graph, } std::unique_ptr XirSubgraphToComputeCapability1(const onnxruntime::GraphViewer& graph, vaip_core::ExecutionProvider* ep, size_t index) { - auto meta_def = std::make_unique(); - meta_def->constant_initializers = *ep->get_meta_def_constant_initializer(); - meta_def->inputs = *ep->get_meta_def_inputs(); - meta_def->outputs = 
*ep->get_meta_def_outputs(); - auto indexed_subgraph = std::make_unique(); - auto indexed_subgraph_ptr = indexed_subgraph.get(); - indexed_subgraph_ptr->nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->constant_initializers() = *ep->get_meta_def_constant_initializer(); + meta_def->inputs() = *ep->get_meta_def_inputs(); + meta_def->outputs() = *ep->get_meta_def_outputs(); + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); static auto g_counter = 1; - meta_def->name = std::string("vitis_ai_ep_") + std::to_string(g_counter++); - meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto index_proto = std::unique_ptr(vaip::attr_proto_new_int("index", (int64_t)index)); - meta_def->attributes["index"] = *index_proto; + meta_def->name() = std::string("vitis_ai_ep_") + std::to_string(g_counter++); + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + auto index_proto = ONNX_NAMESPACE::AttributeProto::Create(); + index_proto->set_name("index"); + index_proto->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + index_proto->set_i(index); + meta_def->attributes()["index"] = *index_proto; indexed_subgraph->SetMetaDef(std::move(meta_def)); - return std::make_unique(std::move(indexed_subgraph)); + return ComputeCapability::Create(std::move(indexed_subgraph)); } std::vector> GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, vaip_core::DllSafe>>* eps, - const std::set& all_not_support_optypes) { - std::set all_compute_capability_nodes; + const std::set& all_support_optypes_by_eps) { + std::set all_nodes_included_eps; for (auto& ep : **eps) { - auto nodes = *ep->get_meta_def_nodes(); - for (auto n : nodes) - all_compute_capability_nodes.insert(n); + auto nodes = 
node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + all_nodes_included_eps.insert(nodes.begin(), nodes.end()); } + + std::vector node_indexs = graph.GetNodesInTopologicalOrder(); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); + std::vector> result; - for (auto& n : graph.Nodes()) { - if ((!all_compute_capability_nodes.count(n.Name())) && all_not_support_optypes.count(n.OpType())) { - auto meta_def = std::make_unique(); - meta_def->name = n.OpType(); - meta_def->domain = n.Domain(); - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto indexed_subgraph = std::make_unique(); - indexed_subgraph->nodes.push_back(n.Index()); - for (auto i : n.InputDefs()) { - meta_def->inputs.push_back(i->Name()); - } - for (auto i : n.OutputDefs()) { - meta_def->outputs.push_back(i->Name()); - } - indexed_subgraph->SetMetaDef(std::move(meta_def)); - result.emplace_back(std::make_unique(std::move(indexed_subgraph))); - } + for (auto& n : node_indexs) { + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = {n}; + result.emplace_back(ComputeCapability::Create(std::move(indexed_subgraph))); } return result; } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index b629c8eff9097..f609d40f459b7 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -1,20 +1,18 @@ - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
+ #include "vaip/global_api.h" #include +#include +#include #include #include "./vai_assert.h" -#include "core/common/exceptions.h" -#include "core/common/logging/logging.h" +#include "core/common/exceptions.h" #include "core/framework/error_code_helper.h" - -#include "core/graph/model.h" -#include "core/session/ort_env.h" -#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/shared/common.h" #include @@ -55,16 +53,14 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); void Ensure() { - if (handle_) return; - auto full_path = Env::Default().GetRuntimePath() + - PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); - ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); - ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( - handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); - auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", - reinterpret_cast(&compile_onnx_model_with_options)); - auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", - reinterpret_cast(&compile_onnx_model_3)); + if (handle_) + return; + auto& env = Provider_GetHost()->Env__Default(); + auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); + auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); + auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", 
(void**)&compile_onnx_model_3); if (!status1.IsOK() && !status2.IsOK()) { ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); ORT_THROW(status1); @@ -76,6 +72,12 @@ struct OrtVitisAIEpAPI { }; static OrtVitisAIEpAPI s_library_vitisaiep; +static std::shared_ptr s_kernel_registry_vitisaiep; +static std::vector s_domains_vitisaiep; +static vaip_core::OrtApiForVaip the_global_api; +std::shared_ptr get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; } +const std::vector& get_domains_vitisaiep() { return s_domains_vitisaiep; } + static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { auto iter = config.find("config_file"); if (iter == config.end()) { @@ -105,121 +107,142 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config return ""; } } -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + +vaip_core::DllSafe>> compile_onnx_model( + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { +#ifndef _WIN32 + auto model_path = graph_viewer.ModelPath().ToPathString(); +#else + using convert_t = std::codecvt_utf8; + std::wstring_convert strconverter; + auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); +#endif if (s_library_vitisaiep.compile_onnx_model_with_options) { - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); } else { auto json_str = config_to_json_str(options); - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph_viewer.GetGraph(), 
json_str.c_str())); } } -std::vector initialize_vitisai_ep() { - s_library_vitisaiep.Ensure(); - Status status = Status::OK(); - try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, - "onnxruntime-vitisai-ep"}; - std::ignore = OrtEnv::GetInstance(lm_info, status); - } catch (onnxruntime::OnnxRuntimeException& /*e*/) { +struct MyCustomOpKernel : OpKernel { + MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { + op_kernel_ = + op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); } - auto domains = std::vector(); - domains.reserve(100); - s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { - vaip::register_xir_ops(domains); + + ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } + + Status Compute(OpKernelContext* ctx) const override { + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); + return Status::OK(); } - return domains; + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); + + const OrtCustomOp& op_; + void* op_kernel_; +}; + +void create_kernel_registry(std::vector domains) { + s_kernel_registry_vitisaiep = KernelRegistry::Create(); + for (const auto& domain : domains) { + for (const auto* op : domain->custom_ops_) { + auto def_builder = KernelDefBuilder::Create(); + def_builder->SetName(op->GetName(op)); + def_builder->SetDomain(domain->domain_.c_str()); + def_builder->SinceVersion(1); + if (op->version > 12) { + auto input_count = op->GetInputTypeCount(op); + for (auto i = 0u; i < input_count; i++) { + def_builder->InputMemoryType(op->GetInputMemoryType(op, i), i); + } + } + def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); + KernelCreateFn kernel_create_fn = + 
[op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + // out = std::make_unique(info, *op); + return Status::OK(); + }; + std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn)); + } + } +} +void initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); + s_domains_vitisaiep.reserve(100); + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), s_domains_vitisaiep); + vaip::register_xir_ops(s_domains_vitisaiep); + create_kernel_registry(s_domains_vitisaiep); } -static vaip_core::OrtApiForVaip the_global_api; vaip_core::OrtApiForVaip* create_org_api_hook() { + InitProviderOrtApi(); + the_global_api.host_ = Provider_GetHost(); assert(Ort::Global::api_ != nullptr); the_global_api.ort_api_ = Ort::Global::api_; the_global_api.model_load = [](const std::string& filename) -> Model* { - ONNX_NAMESPACE::ModelProto model_proto; + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); auto file_path = ToPathString(filename); - auto status = Model::Load(file_path, model_proto); + auto status = Model::Load(file_path, *model_proto); vai_assert(status.IsOK(), "load model proto error"); - auto model = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto model = Model::Create(std::move(*model_proto), file_path, logger); return model.release(); }; the_global_api.model_delete = [](Model* model) { delete model; }; - the_global_api.model_clone = [](const Model& model) -> Model* { + + the_global_api.model_clone = [](const Model& const_model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = const_cast(model).ToProto(); - auto file_path = model.ModelPath().ToPathString(); - auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto& model = const_cast(const_model); + auto model_proto = model.ToProto(); + auto file_path = 
model.MainGraph().ModelPath().ToPathString(); + auto ret = Model::Create(std::move(*model_proto), file_path, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) { const_cast(model.MetaData())[key] = value; }; - the_global_api.model_get_meta_data = [](const Model& model, - const std::string& key) -> vaip_core::DllSafe { - auto& m = model.MetaData(); - auto it = m.find(key); - auto ret = std::string(); - if (it != m.end()) { - ret = it->second; + the_global_api.model_get_meta_data = + [](const Model& model, const std::string& key) -> vaip_core::DllSafe { + if (model.MetaData().count(key)) { + return vaip_core::DllSafe(model.MetaData().at(key)); } - return vaip_core::DllSafe(ret); + return vaip_core::DllSafe(std::string()); }; - the_global_api.model_has_meta_data = [](const Model& model, const std::string& key) -> int { - auto& m = model.MetaData(); - return m.find(key) != m.end() ? 
1 : 0; + return int(model.MetaData().count(key)); }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; - the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto ret = std::vector(); - auto inputs = graph.GetInputs(); - for (auto input : inputs) { - vai_assert(input->Exists(), input->Name()); - ret.push_back(input); - } - return vaip_core::DllSafe(std::move(ret)); + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> auto { + return vaip_core::DllSafe(graph.GetInputs()); }; - the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.GetOutputs()); }; - - the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { - return graph.SetOutputs(outputs); + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) { + graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - - the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; - + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { + return graph.GetNode(index); + }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, const std::vector& input_args, - const std::vector& 
output_args, - vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { - return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), domain); - }; - + the_global_api.graph_add_node = vaip::graph_add_node; the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; - the_global_api.graph_resolve = [](Graph& graph, bool force) { if (force) { graph.SetGraphResolveNeeded(); @@ -227,129 +250,57 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto status = graph.Resolve(); return status.Code(); }; - - the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { + the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto& node_refererence = graph.Nodes(); - std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); - return vaip_core::DllSafe(std::move(nodes)); - }; + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); }; the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { + const auto& enter, const auto& leave, const auto& stop) { graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe 
= vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; - the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return node.Index(); }; the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - - the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast(node.GetMutableAttributes()); - }; - - the_global_api.node_type_is_fused = [](const Node& node) { - return node.NodeType() == onnxruntime::Node::Type::Fused; + the_global_api.node_get_attributes = [](Node& node) -> NodeAttributes& { + return const_cast(node.GetAttributes()); }; - the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == Node::Type::Fused; }; + the_global_api.node_get_function_body = [](const Node& node) -> const auto& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { - return node_arg.Name(); - }; + the_global_api.node_arg_get_name_unsafe = + [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; the_global_api.node_arg_new = vaip::node_arg_new; - the_global_api.node_arg_is_exists = vaip::node_arg_is_exists; + the_global_api.node_arg_is_exists = [](const NodeArg& node_arg) { return node_arg.Exists(); }; the_global_api.node_arg_is_constant = vaip::node_arg_is_constant; 
the_global_api.node_arg_get_shape_i64_unsafe = vaip::node_arg_get_shape_i64; the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; + the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; - the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { - auto data_type = ONNX_NAMESPACE::TensorProto::UNDEFINED; - switch (type) { - case 1: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT; - break; - case 2: - data_type = ONNX_NAMESPACE::TensorProto::UINT8; - break; - case 3: - data_type = ONNX_NAMESPACE::TensorProto::INT8; - break; - - case 4: - data_type = ONNX_NAMESPACE::TensorProto::UINT16; - break; - case 5: - data_type = ONNX_NAMESPACE::TensorProto::INT16; - break; - case 6: - data_type = ONNX_NAMESPACE::TensorProto::INT32; - break; - case 7: - data_type = ONNX_NAMESPACE::TensorProto::INT64; - break; - case 8: - data_type = ONNX_NAMESPACE::TensorProto::STRING; - break; - case 9: - data_type = ONNX_NAMESPACE::TensorProto::BOOL; - break; - case 10: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT16; - break; - case 11: - data_type = ONNX_NAMESPACE::TensorProto::DOUBLE; - break; - case 12: - data_type = ONNX_NAMESPACE::TensorProto::UINT32; - break; - case 13: - data_type = ONNX_NAMESPACE::TensorProto::UINT64; - break; - case 14: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX64; - break; - case 15: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX128; - break; - case 16: - data_type = ONNX_NAMESPACE::TensorProto::BFLOAT16; - break; - default: - vai_assert(false, "TensorProto::DataType not supoort"); - } - return vaip::node_arg_set_element_type(node_arg, data_type); - }; + the_global_api.node_arg_set_element_type = vaip::node_arg_set_element_type; /// attr proto - 
the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { - return new onnx::AttributeProto(v); - }; - the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { - return attr_proto.name(); - }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { - attr_proto->set_name(name); + the_global_api.attr_proto_delete = [](ONNX_NAMESPACE::AttributeProto* v) { delete v; }; + the_global_api.attr_proto_clone = [](const ONNX_NAMESPACE::AttributeProto& v) -> ONNX_NAMESPACE::AttributeProto* { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); + *ret = v; + return ret.release(); }; + the_global_api.attr_proto_get_name = [](const auto& attr_proto) -> const std::string& { return attr_proto.name(); }; + the_global_api.attr_proto_set_name = [](auto* attr_proto, const auto& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; the_global_api.attr_proto_new_float = vaip::attr_proto_new_float; the_global_api.attr_proto_new_string = vaip::attr_proto_new_string; @@ -364,31 +315,24 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const ONNX_NAMESPACE::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes - the_global_api.node_attributes_new = []() { - return reinterpret_cast(new NodeAttributes()); - }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), 
std::move(attr)); + the_global_api.node_attributes_new = []() { return NodeAttributes::Create().release(); }; + the_global_api.node_attributes_add = [](NodeAttributes& p, ONNX_NAMESPACE::AttributeProto&& attr) { + p.insert_or_assign(attr.name(), std::move(attr)); }; - the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { - delete reinterpret_cast(p); - }; - the_global_api.node_attributes_get = [](vaip_core::NodeAttributes& p, - const std::string& name) -> ONNX_NAMESPACE::AttributeProto* { - auto& attr = reinterpret_cast(p); - auto it = attr.find(name); - if (it == attr.end()) { - return nullptr; + + the_global_api.node_attributes_delete = [](NodeAttributes* p) { delete p; }; + the_global_api.node_attributes_get = + [](const NodeAttributes& attr, const std::string& name) -> const ONNX_NAMESPACE::AttributeProto* { + if (attr.count(name)) { + return &attr.at(name); } - return &it->second; + return nullptr; }; - the_global_api.node_attributes_get_keys = - [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = [](NodeAttributes& attr) -> vaip_core::DllSafe> { auto ret = std::vector(); - auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); for (auto& it : attr) { ret.push_back(it.first); @@ -396,35 +340,16 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = - [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { - return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); - }; - - the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; - - the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - - the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, 
shape, data)}; - }; - the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; - }; - the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; - + the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; + the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; + the_global_api.tensor_proto_new_floats = vaip::tensor_proto_new_floats; + the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; + the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; + the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; + the_global_api.tensor_proto_raw_data_size = [](const auto& tensor) { return tensor.raw_data().size(); }; the_global_api.tensor_proto_as_raw = vaip::tensor_proto_as_raw; - the_global_api.tensor_proto_get_name = vaip::tensor_proto_get_name; + the_global_api.tensor_proto_get_name = [](const auto& tensor) -> const std::string& { return tensor.name(); }; the_global_api.get_lib_name = []() -> vaip_core::DllSafe { return vaip_core::DllSafe(std::string("onnxruntime.") + std::string(ORT_VERSION)); diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index cca680baf7dc0..061bc414fcec7 100644 
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -2,27 +2,15 @@ // Licensed under the MIT License. #include "vaip/graph.h" -#include - -#include "./vai_assert.h" #include #include #include #include #include #include -#include "onnx/onnx-ml.pb.h" -#ifdef _MSC_VER -#pragma warning(push) -// 'type' : forcing value to bool 'true' or 'false' (performance warning) -#pragma warning(disable : 4800) -#endif -#include -#ifdef _MSC_VER -#pragma warning(pop) -#endif -using convert_t = std::codecvt_utf8; -std::wstring_convert strconverter; + +#include "core/providers/shared_library/provider_api.h" +#include "./vai_assert.h" #include "vaip/node.h" #include "vaip/node_arg.h" @@ -38,23 +26,14 @@ struct NodeEdgeT { static void graph_remove_node(Graph& graph, const Node& node) { auto remove_edges = std::vector(); - auto begin = node.InputEdgesBegin(); - auto end = node.InputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } - begin = node.OutputEdgesBegin(); - end = node.OutputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } for (auto it : remove_edges) { - graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, - it.dst_arg_index); + graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, it.dst_arg_index); } graph.RemoveNode(node.Index()); } @@ 
-68,13 +47,9 @@ static std::vector node_get_implicit_input_node_args(const Node& } return ret; } - -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain) { +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain) { std::vector inputs; inputs.reserve(input_args.size()); for (auto i : input_args) { @@ -85,8 +60,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto i : output_args) { outputs.push_back(const_cast(i)); } - auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, - &attributes, domain); + auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, &attributes, domain); auto src_arg_index = 0; for (auto& o : outputs) { auto consumers = graph.GetConsumerNodes(o->Name()); @@ -96,8 +70,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto ni : *tmp_inputs) { auto name1 = ni.node_arg->Name(); if (name1 == o->Name()) { - graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, - dst_arg_index); + graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); } dst_arg_index = dst_arg_index + 1; } @@ -105,8 +78,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto implicit_node_arg : node_get_implicit_input_node_args(*consumer)) { auto name1 = implicit_node_arg->Name(); if (name1 == o->Name()) { - graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, - dst_arg_index); + graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); } dst_arg_index = dst_arg_index + 1; } @@ -132,44 +104,39 @@ void graph_remove_node(Graph& graph, const NodeInput& 
node_input) { void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) { auto& model = const_cast(graph.GetModel()); - auto model_proto = ONNX_NAMESPACE::ModelProto(); + std::unique_ptr model_proto; if (initializer_size_threshold == std::numeric_limits::max()) { model_proto = model.ToProto(); } else { - model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, - ToPathString(filename), - initializer_size_threshold); + model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold); } auto& metadata = model.MetaData(); if (!metadata.empty()) { - model_proto.mutable_metadata_props()->Clear(); + auto metadata_props = model_proto->mutable_metadata_props(); + metadata_props->Clear(); for (auto& m : metadata) { - auto prop = model_proto.mutable_metadata_props()->Add(); + auto prop = metadata_props->Add(); *prop->mutable_key() = m.first; *prop->mutable_value() = m.second; } } // use relative path as data storage. 
- auto graph_proto = model_proto.mutable_graph(); - *graph_proto = graph.ToGraphProto(); - for (auto i = 0; i < graph_proto->initializer_size(); ++i) { - auto initializer = graph_proto->mutable_initializer(i); - for (auto j = 0; j < initializer->external_data_size(); ++j) { - auto external_data = initializer->mutable_external_data(j); - if (external_data->key() == "location") { - *external_data->mutable_value() = std::filesystem::path(external_data->value()).filename().u8string(); - } + auto graph_proto = model_proto->mutable_graph(); + *graph_proto = *graph.ToGraphProto(); + for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) { + auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data(); + for (int j = 0; j < mutable_external_data->size(); j++) { + auto& external_data = mutable_external_data->at(j); + if (*external_data.mutable_key() == "location") + *external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string(); } } - int fd = -1; - Status status = Env::Default().FileOpenWr(filename, fd); - vai_assert(status.IsOK(), status.ErrorMessage()); - google::protobuf::io::FileOutputStream output(fd); - const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); - vai_assert(result, "model serialize to zero cipy stream error"); - status = Env::Default().FileClose(fd); - vai_assert(status.IsOK(), status.ErrorMessage()); + + std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary); + bool result = model_proto->SerializeToOstream(output); + output << std::flush; + vai_assert(result, "model serialize to ostream error"); } Node& graph_fuse(Graph& graph, const std::string& name, @@ -178,25 +145,25 @@ Node& graph_fuse(Graph& graph, const std::string& name, const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers) { - auto meta_def = std::make_unique(); - auto indexed_subgraph = 
std::make_unique(); - indexed_subgraph->nodes = nodes; - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->constant_initializers = constant_initializers; - meta_def->name = "super_layer"; - meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->inputs() = inputs; + meta_def->outputs() = outputs; + meta_def->constant_initializers() = constant_initializers; + meta_def->name() = "super_layer"; + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = nodes; indexed_subgraph->SetMetaDef(std::move(meta_def)); + auto& fused_node = graph.FuseSubGraph(*indexed_subgraph, name); auto function_body = fused_node.GetFunctionBody(); if (function_body) { - auto& mygraph = function_body->Body(); - // auto proto = graph.ToGraphProtoWithExternal("exteranl.dat", 128); - auto proto = mygraph.ToGraphProto(); - *proto.mutable_name() = name; - fused_node.AddAttribute("body", proto); + auto proto = function_body->Body().ToGraphProto(); + *proto->mutable_name() = name; + fused_node.AddAttribute("body", *proto); } for (auto&& o : fused_node.OutputDefs()) { graph.UpdateProducerNode(o->Name(), fused_node.Index()); diff --git a/onnxruntime/core/providers/vitisai/imp/node.cc b/onnxruntime/core/providers/vitisai/imp/node.cc index 6d65ad4e8c408..0565171fb7f40 100644 --- a/onnxruntime/core/providers/vitisai/imp/node.cc +++ b/onnxruntime/core/providers/vitisai/imp/node.cc @@ -4,9 +4,8 @@ #include "./vai_assert.h" #include "attr_proto.h" -#include "core/graph/graph_utils.h" -#include "core/graph/node_arg.h" #include "vaip/node_arg.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { @@ -29,7 +28,6 @@ vaip_core::DllSafe> node_get_inputs(const Node& node) { } return 
vaip_core::DllSafe(ret); } - vaip_core::DllSafe> node_get_output_node_args(const Node& node) { auto outputs = node.OutputDefs(); auto size = outputs.size(); @@ -42,11 +40,4 @@ vaip_core::DllSafe> node_get_output_node_args(const } return vaip_core::DllSafe(ret); } - -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index) { - auto outputs = node.OutputDefs(); - assert((size_t)index < outputs.size()); - return node_arg_get_shape_i64(*outputs[index]); -} - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_arg.cc b/onnxruntime/core/providers/vitisai/imp/node_arg.cc index 3bdeb09698d49..a54cbef91c398 100644 --- a/onnxruntime/core/providers/vitisai/imp/node_arg.cc +++ b/onnxruntime/core/providers/vitisai/imp/node_arg.cc @@ -2,25 +2,16 @@ // Licensed under the MIT License. #include "vaip/node_arg.h" #include "./vai_assert.h" - -#include +#include "core/providers/shared_library/provider_api.h" #include "./tensor_proto.h" -#include "core/graph/node_arg.h" namespace vaip { - -bool node_arg_is_exists(const NodeArg& node_arg) { - return node_arg.Exists(); -} bool node_arg_is_constant(const Graph& graph, const NodeArg& node_arg) { assert(node_arg.Exists()); assert(!node_arg.Name().empty()); - auto constant_tensor_proto = - graph.GetConstantInitializer(node_arg.Name(), true); - return constant_tensor_proto != nullptr; + return graph.GetConstantInitializer(node_arg.Name(), true) != nullptr; } - vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& node_arg) { auto shape = node_arg.Shape(); if (nullptr == shape) return vaip_core::DllSafe>(); @@ -32,104 +23,42 @@ vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& n } return vaip_core::DllSafe(shape_vector); } - -static void LayoutTransformRule_set_shape(onnx::TensorShapeProto& shape_proto, - const std::vector& shape) { - assert(shape.size() == static_cast(shape_proto.dim_size())); - auto rank = shape_proto.dim_size(); +void node_arg_set_shape_i64(const NodeArg& node_arg, const 
std::vector& shape) { + auto shape_proto = const_cast(node_arg.Shape()); + assert(shape_proto != nullptr); + assert(shape.size() == static_cast(shape_proto->dim_size())); + auto rank = shape_proto->dim_size(); for (auto i = 0; i < rank; ++i) { - shape_proto.mutable_dim(i)->set_dim_value(shape[i]); + shape_proto->mutable_dim(i)->set_dim_value(shape[i]); } } - -static void LayoutTransformRule_set_shape(onnx::TypeProto& type_proto, - const std::vector& shape) { - assert(type_proto.value_case() == onnx::TypeProto::kTensorType); - //<< type_proto.DebugString(); - auto& tensor_type = *type_proto.mutable_tensor_type(); - auto& shape_prot = *tensor_type.mutable_shape(); - return LayoutTransformRule_set_shape(shape_prot, shape); -} - -static void LayoutTransformRule_set_shape(NodeArg* node_arg, - const std::vector& shape) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - assert(type_proto != nullptr); - return LayoutTransformRule_set_shape( - *const_cast(type_proto), shape); -} - -void node_arg_set_shape_i64(const NodeArg& node_arg, - const std::vector& shape) { - LayoutTransformRule_set_shape(const_cast(&node_arg), shape); -} - -static std::vector LayoutTransformRule_get_denotation( - const onnx::TensorShapeProto& shape) { +vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (shape == nullptr) { + return vaip_core::DllSafe>(); + } auto ret = std::vector(); - auto rank = shape.dim_size(); - ret.reserve(rank); + auto rank = shape->dim_size(); for (auto i = 0; i < rank; ++i) { - auto& d = shape.dim(i).denotation(); - ret.push_back(d); + ret.push_back(shape->dim(i).denotation()); } - return ret; + return vaip_core::DllSafe>(ret); } - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const onnx::TypeProto& type_proto) { - vai_assert(type_proto.value_case() == onnx::TypeProto::kTensorType, type_proto.DebugString()); - auto& tensor_type = type_proto.tensor_type(); - if 
(!tensor_type.has_shape()) { - return vaip_core::DllSafe>(); - } - auto& shape = tensor_type.shape(); - auto denotation = LayoutTransformRule_get_denotation(shape); - return vaip_core::DllSafe>(denotation); -} - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const NodeArg* node_arg) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - assert(type_proto != nullptr); - return LayoutTransformRule_get_denotation(*type_proto); -} - -vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { - return LayoutTransformRule_get_denotation(&node_arg); -} - -static onnx::TensorShapeProto* node_arg_get_tensor_mutable_shape( - NodeArg* node_arg) { - assert(node_arg != nullptr); - auto type_proto = const_cast(node_arg->TypeAsProto()); - assert(type_proto != nullptr); - vai_assert(type_proto->value_case() == onnx::TypeProto::kTensorType, - type_proto->DebugString()); - return type_proto->mutable_tensor_type()->mutable_shape(); -} - -static void LayoutTransformRule_set_denotation( - onnx::TensorShapeProto& shape, const std::vector& denotation) { - assert(denotation.size() == static_cast(shape.dim_size())); - auto rank = shape.dim_size(); +void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation) { + auto shape_proto = const_cast(node_arg.Shape()); + assert(shape_proto != nullptr); + assert(denotation.size() == static_cast(shape_proto->dim_size())); + auto rank = shape_proto->dim_size(); for (auto i = 0; i < rank; ++i) { - shape.mutable_dim(i)->set_denotation(denotation[i]); + shape_proto->mutable_dim(i)->set_denotation(denotation[i]); } } -void node_arg_set_denotation(const NodeArg& node_arg, - const std::vector& denotation) { - auto mutable_shape = - node_arg_get_tensor_mutable_shape(const_cast(&node_arg)); - - return LayoutTransformRule_set_denotation(*mutable_shape, denotation); -} - -void node_arg_set_element_type(NodeArg& node_arg, - onnx::TensorProto::DataType data_type) { - auto type_proto 
= const_cast(node_arg.TypeAsProto()); +void node_arg_set_element_type(NodeArg& node_arg, int type) { + if (type < 0 || type > 16) { + vai_assert(false, "TensorProto::DataType not supoort"); + } + auto data_type = static_cast(type); + auto type_proto = const_cast(node_arg.TypeAsProto()); assert(type_proto != nullptr); auto current_elem_type = type_proto->mutable_tensor_type()->elem_type(); auto input_elem_type = data_type; @@ -138,24 +67,12 @@ void node_arg_set_element_type(NodeArg& node_arg, current_elem_type, true); vai_assert(status.IsOK(), status.ErrorMessage()); } -void node_arg_set_shape(NodeArg& node_arg, std::vector shape) { - auto type_proto = const_cast(node_arg.TypeAsProto()); - assert(type_proto != nullptr); - for (auto i = 0u; i < shape.size(); i++) { - type_proto->mutable_tensor_type() - ->mutable_shape() - ->mutable_dim(i) - ->set_dim_value(shape[i]); - } -} - const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor( const Graph& graph, const NodeArg& node_arg) { auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true); assert(tensor_proto != nullptr); return *tensor_proto; } - int node_arg_get_element_type(const NodeArg& node_arg) { auto type_proto = node_arg.TypeAsProto(); assert(type_proto != nullptr); @@ -164,9 +81,7 @@ int node_arg_get_element_type(const NodeArg& node_arg) { } return type_proto->tensor_type().elem_type(); } - -NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, - const std::string& name) { +NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, const std::string& name) { vai_assert(name != node_arg.Name(), "node arg must have a new unique name"); vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. 
")); auto type_proto = node_arg.TypeAsProto(); @@ -174,12 +89,10 @@ NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, auto& ret = graph.GetOrCreateNodeArg(name, type_proto); return ret; } - -NodeArg& node_arg_new(Graph& graph, - const std::string& name, const std::vector* shape, int element_type) { +NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector* shape, int element_type) { vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. ")); - auto type_proto = onnx::TypeProto(); - auto tensor_type = type_proto.mutable_tensor_type(); + auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); + auto tensor_type = type_proto->mutable_tensor_type(); tensor_type->set_elem_type(element_type); if (shape != nullptr) { auto shape_proto = tensor_type->mutable_shape(); @@ -189,8 +102,6 @@ NodeArg& node_arg_new(Graph& graph, } else { assert(tensor_type->has_shape() == false); } - auto& ret = graph.GetOrCreateNodeArg(name, &type_proto); - return ret; + return graph.GetOrCreateNodeArg(name, type_proto.release()); } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc b/onnxruntime/core/providers/vitisai/imp/node_attrs.cc deleted file mode 100644 index e438266e2a4c0..0000000000000 --- a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
-#include "vaip/node_attrs.h" -#include "./vai_assert.h" - -namespace vaip { -static onnx::AttributeProto make_attribute(const std::string& name, - int64_t value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INT); - ret.set_i(value); - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INTS); - for (auto v : value) { - ret.add_ints(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::string& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRING); - ret.set_s(value); - return ret; -} -static onnx::AttributeProto make_attribute( - const std::string& name, const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRINGS); - for (auto v : value) { - ret.add_strings(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::FLOATS); - for (auto v : value) { - ret.add_floats(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const onnx::TensorProto& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::TENSOR); - *(ret.mutable_t()) = std::move(value); - return ret; -} // namespace vaip - -NodeAttr::NodeAttr(const std::string& name, int64_t value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::string& value) - : 
attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, - const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const onnx::TensorProto& value) - : attribute_proto_{make_attribute(name, value)} {} - -onnx::AttributeProto& NodeAttr::get() { return attribute_proto_; } - -NodeAttributesBuiler::NodeAttributesBuiler(size_t capacity) : attrs_{} { - attrs_.reserve(capacity); -} - -NodeAttributes NodeAttributesBuiler::build() { - auto ret = NodeAttributes(); - ret.reserve(attrs_.size()); - for (auto& node_attr : attrs_) { - onnx::AttributeProto& attr_proto = node_attr.get(); - auto name = attr_proto.name(); - ret.insert(std::make_pair(name, std::move(attr_proto))); - } - attrs_.clear(); - return ret; -} - -void NodeAttributesBuiler::merge_into(Node& node) { - merge_into(node.GetMutableAttributes()); -} - -void NodeAttributesBuiler::merge_into(NodeAttributes& attrs) { - for (auto& attr : attrs_) { - vai_assert(attr.get().has_name(), std::string("attr must has name " + attr.get().DebugString())); - auto name = attr.get().name(); - attrs.insert_or_assign(std::move(name), std::move(attr.get())); - } -} -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index ee8dfc6d03d12..97ed2d3b4b8a1 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -1,130 +1,25 @@ - - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
+ #include "./register_xir_ops.h" #include "./vai_assert.h" - -#include "core/common/logging/logging.h" -#include "core/common/status.h" - -#include "core/framework/customregistry.h" - +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" -#include "onnx/defs/schema.h" -#include "onnx/defs/shape_inference.h" using namespace onnxruntime; -namespace vaip { - -static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); - if (data_type->s() == "float32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); - } else if (data_type->s() == "int8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); - } else if (data_type->s() == "uint8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); - } else if (data_type->s() == "int32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); - } else if (data_type->s() == "int64") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); - } else if (data_type->s() == "int1") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - } else if (data_type->s() == "bfloat16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); - } else if (data_type->s() == "float16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); - } else { - vai_assert(false, ", not supported data_type: " + data_type->s()); - } - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::appendDim(ONNX_NAMESPACE::getOutputShape(ctx, 0), shape->ints(i)); - } - } else { - // set scalar type. 
- auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, 0); - output_shape->clear_dim(); - } - return; -} - -static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); -} - -static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - auto num_inputs = ctx.getNumInputs(); - - // Run inferencing on the subgraph - ONNX_NAMESPACE::GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body"); - if (!graphInferencer) { - fail_type_inference("body is missing."); - } - - std::vector input_data; - std::vector subgraph_input_types; - for (size_t i = 0; i < num_inputs; ++i) { - input_data.push_back(ctx.getInputData(i)); - subgraph_input_types.push_back(ctx.getInputType(i)); - } - std::vector output_types; - output_types = - graphInferencer->doInferencing(subgraph_input_types, input_data); - - auto num_outputs = ctx.getNumOutputs(); - auto num_of_the_subgraph_outputs = output_types.size(); - if (num_outputs != num_of_the_subgraph_outputs) { - fail_type_inference("super layer has ", num_outputs, - " but subgraphs produce ", num_of_the_subgraph_outputs); - } - for (size_t i = 0, end = output_types.size(); i < end; ++i) { - auto subgraph_output = output_types[i]; - auto* super_layer_output = ctx.getOutputType(i); - *super_layer_output = *subgraph_output; - } -} +namespace vaip { void register_xir_ops(const std::vector& domains) { - std::shared_ptr custom_registry; - auto status = CreateCustomRegistry(gsl::span(domains), custom_registry); - vai_assert(status.IsOK(), status.ErrorMessage()); for (auto domain : domains) { for (auto op : domain->custom_ops_) { auto name = op->GetName(op); - auto schema1 = custom_registry->GetOpschemaRegistry()->GetSchema(name, ORT_API_VERSION, domain->domain_); - auto schema2 = ::ONNX_NAMESPACE::OpSchema(); - schema2.SetName(schema1->Name()); - 
schema2.SetDomain(schema1->domain()); - auto n = 0; - for (auto input : schema1->inputs()) { - schema2.Input(n, input.GetName(), input.GetDescription(), std::string("T") + std::to_string(n), input.GetOption(), false, input.GetMinArity(), input.GetDifferentiationCategory()); - schema2.TypeConstraint(std::string("T") + std::to_string(n), DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - auto m = n; - n = 0; - for (auto output : schema1->outputs()) { - auto type_str = std::string("T") + std::to_string(n + m); - schema2.Output(n, output.GetName(), output.GetDescription(), type_str, output.GetOption(), false, output.GetMinArity(), output.GetDifferentiationCategory()); - schema2.TypeConstraint(type_str, DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - schema2.SinceVersion(1); - schema2.AllowUncheckedAttributes(); if ((std::string)name == "super_layer") { - schema2.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 1); } else if ((std::string)name == "FixNeuron") { - schema2.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 2); } else { - schema2.TypeAndShapeInferenceFunction(xir_shape_infer); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 3); } - ONNX_NAMESPACE::RegisterSchema(schema2, ORT_API_VERSION); } } } diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index db03354bf4c44..48dcd220a150c 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -1,20 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#include "./tensor_proto.h" -#include "./vai_assert.h" -#include "core/framework/tensorprotoutils.h" #include #include +#include "./vai_assert.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { std::vector unpacked_tensor; - auto s = onnxruntime::utils::UnpackInitializerData(tensor, onnxruntime::Path(), unpacked_tensor); + auto path = onnxruntime::Path::Create(); + auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); mut_tensor.clear_float_data(); mut_tensor.clear_int32_data(); @@ -27,78 +26,51 @@ gsl::span tensor_proto_as_raw( return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); } -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.raw_data().size(); -} - -std::vector tensor_proto_get_shape( - const onnx::TensorProto& tensor_proto) { +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { auto ret = std::vector(); int rank = tensor_proto.dims_size(); if (rank > 0) { - ret.reserve((size_t)rank); - for (auto i = 0; i < rank; ++i) { - ret.push_back(tensor_proto.dims(i)); + auto& dims = tensor_proto.dims(); + for (auto i = 0; i < dims.size(); ++i) { + ret.push_back(dims[i]); } } - return ret; + return vaip_core::DllSafe(ret); } - -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.name(); +static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, const std::vector& shape, + int data_type, const char* data, size_t data_size) { + auto tensor_proto = ONNX_NAMESPACE::TensorProto::Create(); + tensor_proto->set_name(name); + for (auto s : shape) { + 
tensor_proto->add_dims(s); + } + tensor_proto->set_data_type(data_type); + tensor_proto->mutable_raw_data()->assign(data, data_size); + return tensor_proto.release(); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT32); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT32, + reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT64); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, + reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - 
tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT8); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, + reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(float)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + reinterpret_cast(&data[0]), data.size() * sizeof(float)); } } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 00aa388c809c1..292905ca734f1 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -1,31 +1,20 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#pragma once -// -#include "core/common/gsl.h" -#include "onnx/onnx_pb.h" -namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor); -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor); - -std::vector tensor_proto_get_shape( - const ONNX_NAMESPACE::TensorProto& tensor); -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data); - -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data); +#include "vaip/my_ort.h" +#include "vaip/vaip_gsl.h" +#include "vaip/dll_safe.h" +namespace vaip { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor); +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); +const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/capability.h b/onnxruntime/core/providers/vitisai/include/vaip/capability.h index 
d6b5ae34decc2..e7644dbe86354 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/capability.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/capability.h @@ -2,8 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" +#include "core/providers/shared_library/provider_api.h" #include "vaip/custom_op.h" namespace vaip { using namespace ::onnxruntime; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index c446ab3aefcc5..1f8b8802e86b4 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,16 +2,15 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once -#include -#include -#include - +#include "core/providers/shared_library/provider_api.h" +#define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/provider_options.h" #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" -std::vector initialize_vitisai_ep(); -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); +void initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); +std::shared_ptr get_kernel_registry_vitisaiep(); +const std::vector& get_domains_vitisaiep(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index 9def8645709fb..292fb2bb38b2b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ 
-1,25 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once -#include #include "./node.h" +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; void graph_remove_node(Graph& graph, const NodeInput& node_input); -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain); - -void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, size_t initializer_size_threshold); -Node& graph_fuse(Graph& graph, const std::string& name, - const std::string& op_type, - const std::vector& nodes, - const std::vector& inputs, - const std::vector& outputs, +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain); +void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, + size_t initializer_size_threshold); +Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, + const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index d43ef1253715c..46fc4ac9b2a5d 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -9,15 +9,17 @@ #include namespace onnxruntime { -class Model; -class Graph; -class GraphViewer; -class Node; -class NodeArg; +struct Model; +struct Graph; +struct GraphViewer; +struct Node; +struct NodeArg; 
+struct ProviderHost; +struct NodeAttributes; } // namespace onnxruntime namespace ONNX_NAMESPACE { -class AttributeProto; -class TensorProto; +struct AttributeProto; +struct TensorProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -68,6 +70,7 @@ using onnxruntime::GraphViewer; using onnxruntime::Model; using onnxruntime::Node; using onnxruntime::NodeArg; +using onnxruntime::NodeAttributes; struct ModelDeleter { VAIP_DLL_SPEC void operator()(Model* tp) const; }; @@ -75,22 +78,17 @@ using ModelPtr = std::unique_ptr; struct AttributeProtoDeleter { VAIP_DLL_SPEC void operator()(AttributeProto* p) const; }; -using AttributeProtoPtr = - std::unique_ptr; +using AttributeProtoPtr = std::unique_ptr; struct TensorProtoDeleter { VAIP_DLL_SPEC void operator()(TensorProto* tp) const; }; using TensorProtoPtr = std::unique_ptr; -/// I cannot forward declare a using directive, because -/// std::unorderd_map required AttributeProto must be defiend. -class NodeAttributes; struct NodeAttributesDeleter { VAIP_DLL_SPEC void operator()(NodeAttributes* p) const; }; -using NodeAttributesPtr = - std::unique_ptr; +using NodeAttributesPtr = std::unique_ptr; /// get node's input /// when Node* is nullptr, it is a tensor in the initializer. /// node_arg is always non-null. diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node.h b/onnxruntime/core/providers/vitisai/include/vaip/node.h index bad7660f66744..31d9d4bd73b8b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node.h @@ -2,10 +2,6 @@ // Licensed under the MIT License. 
#pragma once - -#include - -#include "core/graph/node_arg.h" #include "vaip/dll_safe.h" #include "vaip/my_ort.h" namespace vaip { @@ -17,8 +13,4 @@ vaip_core::DllSafe> node_get_inputs(const Node& node); /// to support multiple outputs vaip_core::DllSafe> node_get_output_node_args(const Node& node); -/// get output shape -/// index is usually zero, because most operators only have a single output. -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index = 0); - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h index 76432fc5b3a68..fca641c5e11c8 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h @@ -2,9 +2,8 @@ // Licensed under the MIT License. #pragma once -#include #include "vaip/dll_safe.h" -#include +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; @@ -26,9 +25,7 @@ void node_arg_set_shape_i64(const NodeArg& node_arg, void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation); void node_arg_set_element_type(NodeArg& node_arg, - ONNX_NAMESPACE::TensorProto::DataType data_type); -void node_arg_set_shape(NodeArg& node_arg, std::vector shape); - + int data_type); const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph, const NodeArg& node_arg); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h b/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h deleted file mode 100644 index 49cd1aad89f4f..0000000000000 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include - -#include - -#include "core/graph/basic_types.h" -namespace vaip { -using namespace onnxruntime; -class NodeAttr { - public: - NodeAttr(const std::string& name, int64_t value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::string& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const onnx::TensorProto& value); - - onnx::AttributeProto& get(); - - private: - onnx::AttributeProto attribute_proto_; -}; - -class NodeAttributesBuiler { - public: - explicit NodeAttributesBuiler(size_t capacity = 10); - NodeAttributesBuiler(const NodeAttributesBuiler&) = delete; - NodeAttributesBuiler(NodeAttributesBuiler&&) = default; - /// after build, all attrs_ are cleared. - NodeAttributes build(); - /// for efficiency reason, after merge_into, all attrs_ are moved. - void merge_into(Node& node); - void merge_into(NodeAttributes& attrs); - template - NodeAttributesBuiler& add(const std::string& name, T&& value) { - attrs_.emplace_back(name, std::forward(value)); - return *this; - } - - private: - std::vector attrs_; -}; -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 0d7d5f6220d06..ae5f71d66269c 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,6 +13,7 @@ struct OrtApi; namespace vaip_core { struct OrtApiForVaip { + onnxruntime::ProviderHost* host_; const OrtApi* ort_api_; // model Model* (*model_load)(const std::string& file); // [0] @@ -49,7 +50,7 @@ struct OrtApiForVaip { const std::string& description, const std::vector& input_args, const std::vector& output_args, - NodeAttributes& attributes, + const NodeAttributes& attributes, const std::string& 
domain); // [18] void (*graph_save)(const Graph& graph, const std::string& filename, const std::string& dat_filename, @@ -119,8 +120,8 @@ struct OrtApiForVaip { NodeAttributes* (*node_attributes_new)(); // [46] void (*node_attributes_delete)(NodeAttributes* p); // [47] void (*node_attributes_add)(NodeAttributes& p, AttributeProto&& attr); // [48] - AttributeProto* (*node_attributes_get)(NodeAttributes& p, - const std::string& name); // [49] + const AttributeProto* (*node_attributes_get)(const NodeAttributes& p, + const std::string& name); // [49] DllSafe> (*node_attributes_get_keys)( NodeAttributes& p); // [50] /// attr proto @@ -194,5 +195,4 @@ VAIP_DLL_SPEC const OrtApiForVaip* api(); ? ::vaip_core::api()->name \ : (assert(false && #name " is not set"), nullptr)) #endif -VAIP_DLL_SPEC void initialize_ort(); } // namespace vaip_core diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def new file mode 100644 index 0000000000000..4ec2f7914c208 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/symbols.def @@ -0,0 +1,2 @@ +EXPORTS + GetProvider diff --git a/onnxruntime/core/providers/vitisai/version_script.lds b/onnxruntime/core/providers/vitisai/version_script.lds new file mode 100644 index 0000000000000..2c8e9c4b3ed64 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/version_script.lds @@ -0,0 +1,9 @@ +#_init and _fini should be local +VERS_1.0 { + global: + GetProvider; + + # Hide everything else. + local: + *; +}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 5f20b32cd6dc4..6fc09f3495aa1 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -1,91 +1,34 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
-#include "core/graph/graph_utils.h" #include "vitisai_execution_provider.h" #include -#include #include #include -#include "core/common/common.h" - #include "vaip/capability.h" #include "vaip/global_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { - constexpr const char* VITISAI = "VITISAI"; -static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { -#ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().ToPathString(); -#else - using convert_t = std::codecvt_utf8; - std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); -#endif - return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); -} - -struct MyCustomOpKernel : OpKernel { - MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = - op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); - } - - ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } - - Status Compute(OpKernelContext* ctx) const override { - op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); - return Status::OK(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); - - const OrtCustomOp& op_; - void* op_kernel_; -}; - -VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) +VitisAIExecutionProvider::VitisAIExecutionProvider( + const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { - custom_op_domains_ = initialize_vitisai_ep(); - registry_ = std::make_shared(); CreateKernelRegistry(); } void VitisAIExecutionProvider::CreateKernelRegistry() { - for (const auto& domain : custom_op_domains_) { + for (const auto& domain : get_domains_vitisaiep()) { for (const 
auto* op : domain->custom_ops_) { - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)); - def_builder.SetDomain(domain->domain_); - def_builder.SinceVersion(1); - if (op->version > 12) { - auto input_count = op->GetInputTypeCount(op); - for (auto i = 0u; i < input_count; i++) { - def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i); - } - } - def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, - std::unique_ptr& out) -> Status { - out = std::make_unique(info, *op); - return Status::OK(); - }; - std::ignore = registry_->Register(def_builder, kernel_create_fn); vitisai_optypes_.insert(op->GetName(op)); } } } -std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } +std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } std::vector> VitisAIExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { @@ -111,9 +54,9 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; - const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); - assert(attr != nullptr); - size_t index = (size_t)attr->i(); + auto& attrs = fused_node_graph.fused_node.get().GetAttributes(); + assert(attrs.count("index")); + size_t index = attrs.at("index").i(); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index e86b53339d4d2..186427be4fab2 100644 --- 
a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -9,8 +9,7 @@ #include #include -#include "core/framework/execution_provider.h" -#include "core/framework/customregistry.h" +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" // we cannot include vaip/vaip.hpp here because header file referred by @@ -21,7 +20,6 @@ class DllSafe; class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { - // Logical device representation. class VitisAIExecutionProvider : public IExecutionProvider { public: diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 4c416124ca8f2..5895e1973f231 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -11,7 +11,6 @@ #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; namespace onnxruntime { @@ -30,10 +29,37 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr VitisAIProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - initialize_vitisai_ep(); - return std::make_shared(provider_options); -} +struct VitisAI_Provider : Provider { + // Takes a pointer to a provider specific structure to create the factory. 
For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure + std::shared_ptr + CreateExecutionProviderFactory(const void* options) override { + return std::make_shared(GetProviderOptions(options)); + } + // Convert provider options struct to ProviderOptions which is a map + ProviderOptions GetProviderOptions(const void* options) override { + auto vitisai_options = reinterpret_cast(options); + return *vitisai_options; + } + // Update provider options from key-value string configuration + void UpdateProviderOptions(void* options, const ProviderOptions& provider_options) override { + auto vitisai_options = reinterpret_cast(options); + for (const auto& entry : provider_options) { + vitisai_options->insert_or_assign(entry.first, entry.second); + } + }; + // Get provider specific custom op domain list. Provider has the responsibility to release OrtCustomOpDomain instances it creates. + void GetCustomOpDomainList(IExecutionProviderFactory*, std::vector&) override{}; + // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded + void Initialize() override { initialize_vitisai_ep(); } + // Called right before unloading the shared library + void Shutdown() override {} +} g_provider; } // namespace onnxruntime + +extern "C" { + +ORT_API(onnxruntime::Provider*, GetProvider) { + return &onnxruntime::g_provider; +} +} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d77c188f832a7..947bb30d42f70 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2724,6 +2724,7 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::SetDeterministicCompute, &OrtApis::KernelContext_ParallelFor, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, + &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, }; // OrtApiBase can never change as there is no way to know what
version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c1caafa4dcad3..9ce94ba89a942 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -509,4 +509,8 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 45d8006e6b49e..8774b11f2dffc 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -56,6 +56,8 @@ namespace ONNX_NAMESPACE { // We use these names in the provider API because we don't have the protobuf definitions of the RepeatedField* types using int64s = google::protobuf::RepeatedField; +using float32s = google::protobuf::RepeatedField; +using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; @@ -76,6 +78,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/migraphx/migraphx_provider_factory_creator.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" +#include "core/providers/vitisai/vitisai_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_factory.h" #include 
"core/providers/cann/cann_provider_factory.h" @@ -118,6 +121,7 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); +ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} @@ -269,7 +273,10 @@ struct ProviderHostImpl : ProviderHost { Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } - + Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) override { + return utils::UnpackInitializerData(tensor, model_path, unpacked_tensor); + } uint16_t math__floatToHalf(float f) override { return math::floatToHalf(f); } float math__halfToFloat(uint16_t h) override { return math::halfToFloat(h); } @@ -351,12 +358,32 @@ struct ProviderHostImpl : ProviderHost { void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + // Env + 
Env& Env__Default() override { return Env::Default(); } + // Utils::DataTypeUtils (wrapped) const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) override { return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_proto); } // int64s (wrapped) int int64s__size(const ONNX_NAMESPACE::int64s* p) override { return p->size(); } const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) override { return p->Get(index); } + void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) override { p->Reserve(size); }; + const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) override { return p->data(); } + + // float32s + void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) override { p->Reserve(size); }; + const float* float32s__data(const ONNX_NAMESPACE::float32s* p) override { return p->data(); } + int float32s__size(const ONNX_NAMESPACE::float32s* p) override { return p->size(); } + + // StringStringEntryProto + std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_key(); } + std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_value(); } + + // StringStringEntryProtos + void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) override { p->Clear(); }; + ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->Add(); } + int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } + ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional (wrapped) @@ -373,6 +400,7 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::TensorShapeProto& 
TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->shape(); } ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->mutable_shape(); } int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->elem_type(); } + void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) override { p->set_elem_type(value); }; // TypeProto_SparseTensor (wrapped) #if !defined(DISABLE_SPARSE_TENSORS) @@ -425,9 +453,18 @@ struct ProviderHostImpl : ProviderHost { float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->floats(i); } const std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->strings(i); } const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) override { return p->ints(); } + const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) override { return p->floats(); } + ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_ints(); } + ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_floats(); } + void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { p->add_ints(value); }; + void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float value) override { p->add_floats(value); }; + void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { p->add_strings(value); }; + int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) override { return p->i(); } float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) override { return p->f(); } + const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const 
ONNX_NAMESPACE::AttributeProto* p) override { return p->t(); } void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_s(value); } + void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) override { return p->set_f(value); } void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { return p->set_i(value); } const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) override { return p->s(); } void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } @@ -449,6 +486,7 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_value_info(); } ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_initializer(); } ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) override { return p->add_node(); } + std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); } ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); } void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; } @@ -466,6 +504,7 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_graph(); } void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } + ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; // NodeProto (wrapped) std::unique_ptr 
NodeProto__construct() override { return std::make_unique(); } @@ -480,19 +519,34 @@ struct ProviderHostImpl : ProviderHost { void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; } void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; } bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); } + void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); } + const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); } int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims_size(); } const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims(); } + void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); } bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); } int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); } bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); } const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); } + std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); } + int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } + void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) override { p->set_data_type(type); } bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); } void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const 
ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); } + ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_external_data(); }; + void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_float_data(); } + void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int32_data(); } + void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_string_data(); } + void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int64_data(); } + void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_double_data(); } + void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_uint64_data(); } // TensorProtos (wrapped) ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } + int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } + ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); }; // TensorShapeProto_Dimension (wrapped) int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); } @@ -502,6 +556,8 @@ struct ProviderHostImpl : ProviderHost { void TensorShapeProto_Dimension__set_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, int64_t value) override { return p->set_dim_value(value); } bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_value(); } bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_param(); } + const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) 
const override { return p->denotation(); } + void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) override { return p->set_denotation(value); } // TensorShapeProto_Dimensions (wrapped) std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) override { @@ -530,6 +586,95 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + // Type/shape inference callback: output 0 gets its element type from the `data_type` string attribute and its dims from the optional `shape` ints attribute (no `shape` means scalar). Unknown data_type strings leave the output untouched. + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + auto* shape = ctx.getAttribute("shape"); + auto* data_type = ctx.getAttribute("data_type"); + // getAttribute() can return nullptr (see the `shape` check below); without `data_type` there is nothing to infer, so bail out instead of dereferencing null. + if (data_type == nullptr) { + return; + } + int32_t elemType = 0; + if (data_type->s() == "float32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + } else if (data_type->s() == "int8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8; + } else if (data_type->s() == "uint8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } else if (data_type->s() == "int32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } else if (data_type->s() == "int64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else if (data_type->s() == "int1") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; + } else if (data_type->s() == "bfloat16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; + } else if (data_type->s() == "float16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + } else if (data_type->s() == "uint16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; + } else if (data_type->s() == "int16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; + } else { + return; + } + ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); + if (shape != nullptr) { + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, 0,
ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } else { + // set scalar type. + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); + } + } + + static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); + } + + static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_inputs = ctx.getNumInputs(); + + // Run inferencing on the subgraph + auto* graphInferencer = ctx.getGraphAttributeInferencer("body"); + + std::vector input_data; + std::vector subgraph_input_types; + for (size_t i = 0; i < num_inputs; ++i) { + input_data.push_back(ctx.getInputData(i)); + subgraph_input_types.push_back(ctx.getInputType(i)); + } + + auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data); + for (size_t i = 0, end = output_types.size(); i < end; ++i) { + *ctx.getOutputType(i) = *output_types[i]; + } + } + void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override { + auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + const auto& domain_to_version_map = domain_instance.Map(); + if (domain_to_version_map.find(domain) == domain_to_version_map.end()) { + domain_instance.AddDomainToVersion(domain, 1, 1000); + } + auto schema = CreateSchema(domain, {op}); + switch (type) { + case 1: + schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + break; + case 2: + schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + break; + case 3: + schema.TypeAndShapeInferenceFunction(xir_shape_infer); + break; + default: + break; + } + ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION); + } + // ConfigOptions (wrapped) std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& 
config_key) override { return p->GetConfigEntry(config_key); @@ -761,6 +901,9 @@ struct ProviderHostImpl : ProviderHost { void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) override { p->ToProto(proto, update_subgraphs); } const NodeAttributes& Node__GetAttributes(const Node* p) noexcept override { return p->GetAttributes(); } + void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) override { + p->AddAttribute(attr_name, value); + } size_t Node__GetInputEdgesCount(const Node* p) noexcept override { return p->GetInputEdgesCount(); } size_t Node__GetOutputEdgesCount(const Node* p) noexcept override { return p->GetOutputEdgesCount(); } @@ -769,13 +912,19 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputNodesBegin()); } std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputNodesEnd()); } - + std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesBegin()); + } + std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesEnd()); + } std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesBegin()); } std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesEnd()); } void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) override { p->ForEachDef(func, std::move(include_missing_optional_defs)); } const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) noexcept override { return p->GetAttributeNameToMutableSubgraphMap(); } std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const override { return p->GetAttributeNameToSubgraphMap(); } 
+ int Node__NodeType(const Node* p) const noexcept override { return int(p->NodeType()); } // NodeArg (wrapped) const std::string& NodeArg__Name(const NodeArg* p) noexcept override { return p->Name(); } @@ -784,6 +933,7 @@ struct ProviderHostImpl : ProviderHost { const NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept override { return p->ToProto(); } bool NodeArg__Exists(const NodeArg* p) const noexcept override { return p->Exists(); } const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept override { return p->TypeAsProto(); } + Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) override { return p->OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); }; // NodeAttributes (wrapped) std::unique_ptr NodeAttributes__construct() override { return std::make_unique(); } @@ -806,12 +956,20 @@ struct ProviderHostImpl : ProviderHost { } void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) override { return p->insert(v.begin(), v.end()); } void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->emplace(k, v); } + void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->insert_or_assign(k, v); } void NodeAttributes__reserve(NodeAttributes* p, size_t size) override { p->reserve(size); } // Model (wrapped) + std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) override { + return std::make_unique(std::move(model_proto), model_path, nullptr, logger); // a named && parameter is an lvalue: std::move forwards it as the rvalue the caller handed over (nullptr = no local schema registries argument) + } void Model__operator_delete(Model* p) override { delete p; } Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } std::unique_ptr Model__ToProto(Model* p) override {
return std::make_unique(p->ToProto()); } + std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; + const ModelMetaData& Model__MetaData(const Model* p) const noexcept override { return p->MetaData(); }; + Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) override { return Model::Load(file_path, model_proto); } // Graph (wrapped) std::unique_ptr Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique(*p); } @@ -831,6 +989,12 @@ struct ProviderHostImpl : ProviderHost { void Graph__SetOutputs(Graph* p, gsl::span outputs) override { p->SetOutputs(outputs); } const std::vector& Graph__GetInputs(const Graph* p) noexcept override { return p->GetInputs(); } + std::vector Graph__Nodes(const Graph* p) override { + auto& node_refererence = p->Nodes(); + std::vector nodes(p->NumberOfNodes(), nullptr); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); + return nodes; + } bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) override { return p->GetInitializedTensor(tensor_name, value); } const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); } @@ -840,6 +1004,40 @@ struct ProviderHostImpl : ProviderHost { const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } + const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) 
const override { return p->GetProducerNode(node_arg_name); } + const Model& Graph__GetModel(const Graph* p) override { return p->GetModel(); } + void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const override { + p->ReverseDFSFrom(from, enter, leave, comp, stop); + } + Graph& Graph__SetGraphResolveNeeded(Graph* p) override { return p->SetGraphResolveNeeded(); } + void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) override { p->RemoveInitializedTensor(tensor_name); } + + std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const override { + return p->GetConsumerNodes(node_arg_name); + } + void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->AddEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->RemoveEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveNode(Graph* p, NodeIndex index) override { p->RemoveNode(index); } + Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) override { + return p->FuseSubGraph(sub_graph, fused_node_name); + } + void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) override { + p->UpdateProducerNode(node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const override { + return p->GetConstantInitializer(name, check_outer_scope); + } + const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) override { return p->GetAllInitializedTensors(); } int 
Graph__MaxNodeIndex(const Graph* p) const noexcept override { return p->MaxNodeIndex(); } Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); } const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); } @@ -884,11 +1082,14 @@ struct ProviderHostImpl : ProviderHost { void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } // Path (wrapped) PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } + std::unique_ptr Path__construct() override { return std::make_unique(); } + void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1274,6 +1475,7 @@ static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_provi #endif ); static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); +static ProviderLibrary s_library_vitisai(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_vitisai") LIBRARY_EXTENSION); static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION); static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION #ifndef _WIN32 @@ -1302,6 +1504,7 @@ static ProviderLibrary 
s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_p void UnloadSharedProviders() { s_library_dnnl.Unload(); + s_library_vitisai.Unload(); s_library_openvino.Unload(); s_library_tensorrt.Unload(); s_library_cuda.Unload(); @@ -1489,6 +1692,10 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(co return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); } +std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +} + ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { return reinterpret_cast(s_library_openvino.Get().GetInfo()); } @@ -2371,3 +2578,34 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid ORT_UNUSED_PARAMETER(ptr); #endif } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + API_IMPL_BEGIN + onnxruntime::ProviderOptions provider_options; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Provider options key/value cannot be empty"); + } + + // arbitrary length to validate the key/value. adjust if/when needed. + // TODO: are any other input validation checks required here (and in the other functions that process + // provider options)? 
+ if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 1024) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Maximum string length for a provider options key/value is 1024."); + } + + provider_options[provider_options_keys[i]] = provider_options_values[i]; + } + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 86b3d01c640a3..9477b563fe0c4 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -148,12 +148,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); -#endif - } else if (strcmp(provider_name, "VitisAI") == 0) { -#if defined(USE_VITISAI) - options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options)); -#else - status = create_not_supported_status(); #endif } else { ORT_UNUSED_PARAMETER(options); @@ -499,4 +493,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options_keys); + 
ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("VitisAI"); +} #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d2cd6140b838e..f5c9b609c348a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -983,7 +983,7 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kVitisAIExecutionProvider) { -#if USE_VITISAI +#ifdef USE_VITISAI const auto it = provider_options_map.find(type); if (it == provider_options_map.end()) { LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider"; diff --git a/setup.py b/setup.py index e94165fdf9b05..67d34b065ad03 100644 --- a/setup.py +++ b/setup.py @@ -298,6 +298,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) libs.extend(["libonnxruntime_providers_openvino.so"]) + libs.extend(["libonnxruntime_providers_vitisai.so"]) libs.append(providers_cuda_or_rocm) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) @@ -310,6 +311,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_dnnl.dylib"]) libs.extend(["libonnxruntime_providers_tensorrt.dylib"]) libs.extend(["libonnxruntime_providers_cuda.dylib"]) + libs.extend(["libonnxruntime_providers_vitisai.dylib"]) if nightly_build: libs.extend(["libonnxruntime_pywrapper.dylib"]) else: @@ -320,6 +322,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_tensorrt.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_vitisai.dll"]) # DirectML Libs libs.extend(["DirectML.dll"]) if nightly_build: