diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 7168a99fe1f93..cd6d0669c67f5 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1670,7 +1670,7 @@ endif()
 #Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next.
 #The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake.
-set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME})
+set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME})
 if (onnxruntime_USE_WINML)
   # WINML uses and depends on the shared lib. Note: You can build WINML without DML and you will get a
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index f2be742458313..aac78390cdced 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -207,6 +207,7 @@ set(onnxruntime_INTERNAL_LIBRARIES
   onnxruntime_optimizer
   onnxruntime_providers
   ${onnxruntime_tvm_libs}
+  onnxruntime_lora
   onnxruntime_framework
   onnxruntime_graph
   onnxruntime_util
diff --git a/cmake/onnxruntime_lora.cmake b/cmake/onnxruntime_lora.cmake
new file mode 100644
index 0000000000000..7ba48454d997e
--- /dev/null
+++ b/cmake/onnxruntime_lora.cmake
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+file(GLOB onnxruntime_lora_srcs CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_ROOT}/lora/adapter_format/*.h"
+  "${ONNXRUNTIME_ROOT}/lora/*.h"
+  "${ONNXRUNTIME_ROOT}/lora/*.cc"
+  )
+
+source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_lora_srcs})
+
+onnxruntime_add_static_library(onnxruntime_lora ${onnxruntime_lora_srcs})
+onnxruntime_add_include_to_target(onnxruntime_lora onnx flatbuffers::flatbuffers Boost::mp11 ${GSL_TARGET})
+target_link_libraries(onnxruntime_lora onnxruntime_framework)
+
+if(onnxruntime_ENABLE_INSTRUMENT)
+  target_compile_definitions(onnxruntime_lora PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
+endif()
+
+target_include_directories(onnxruntime_lora PRIVATE ${ONNXRUNTIME_ROOT})
+add_dependencies(onnxruntime_lora ${onnxruntime_EXTERNAL_DEPENDENCIES})
+set_target_properties(onnxruntime_lora PROPERTIES FOLDER "ONNXRuntime")
+
+if (NOT onnxruntime_BUILD_SHARED_LIB)
+  install(TARGETS onnxruntime_lora
+          ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+          LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+          RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+          FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif()
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 3a87a78d2b16e..cb69886ce671a 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -71,9 +71,7 @@ onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_p
 if(MSVC)
   target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-  if(onnxruntime_ENABLE_TRAINING)
-    target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
-  endif()
+  target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
 endif()
 if(HAS_CAST_FUNCTION_TYPE)
   target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
@@ -186,6 +184,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE
     onnxruntime_providers
     onnxruntime_util
     ${onnxruntime_tvm_libs}
+    onnxruntime_lora
     onnxruntime_framework
     onnxruntime_util
     onnxruntime_graph
diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake
index b51c875951135..47cf2dfc5e7aa 100644
--- a/cmake/onnxruntime_session.cmake
+++ b/cmake/onnxruntime_session.cmake
@@ -30,7 +30,8 @@ endif()
 source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs})
 onnxruntime_add_static_library(onnxruntime_session ${onnxruntime_session_srcs})
-onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface nlohmann_json::nlohmann_json)
+onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnxruntime_lora onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface nlohmann_json::nlohmann_json)
+target_link_libraries(onnxruntime_session PRIVATE onnxruntime_lora)
 if(onnxruntime_ENABLE_INSTRUMENT)
   target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
 endif()
diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake
index b633a9c2de378..f1a5f908eb245 100644
--- a/cmake/onnxruntime_training.cmake
+++ b/cmake/onnxruntime_training.cmake
@@ -139,7 +139,7 @@ if (onnxruntime_BUILD_UNIT_TESTS)
       target_compile_options(onnxruntime_training_mnist PUBLIC "-Wno-maybe-uninitialized")
     endif()
   endif()
-  target_link_libraries(onnxruntime_training_mnist PRIVATE onnxruntime_training_runner onnxruntime_training ${ONNXRUNTIME_LIBS}
-  ${onnxruntime_EXTERNAL_LIBRARIES})
+  target_link_libraries(onnxruntime_training_mnist PRIVATE onnxruntime_training_runner onnxruntime_lora onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES})
   set_target_properties(onnxruntime_training_mnist PROPERTIES FOLDER "ONNXRuntimeTest")
 
   # squeezenet
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 27172c6544df3..0148861d42761 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -263,6 +263,11 @@ file(GLOB onnxruntime_test_flatbuffers_src CONFIGURE_DEPENDS
   "${TEST_SRC_DIR}/flatbuffers/*.h"
 )
 
+file(GLOB onnxruntime_test_lora_src CONFIGURE_DEPENDS
+  "${TEST_SRC_DIR}/lora/*.cc"
+  "${TEST_SRC_DIR}/lora/*.h"
+)
+
 if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
 
   file(GLOB onnxruntime_test_ir_src CONFIGURE_DEPENDS
@@ -612,6 +617,7 @@ set(ONNXRUNTIME_TEST_LIBS
   onnxruntime_providers
   onnxruntime_util
   ${onnxruntime_tvm_libs}
+  onnxruntime_lora
   onnxruntime_framework
   onnxruntime_util
   onnxruntime_graph
@@ -782,7 +788,7 @@ endif()
 set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
               ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantization_src}
-              ${onnxruntime_test_flatbuffers_src})
+              ${onnxruntime_test_flatbuffers_src} ${onnxruntime_test_lora_src})
 
 if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
   file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS
@@ -1514,6 +1520,7 @@ endif()
       onnxruntime_optimizer
       onnxruntime_providers
       onnxruntime_util
+      onnxruntime_lora
       onnxruntime_framework
       onnxruntime_util
       onnxruntime_graph
@@ -1634,7 +1641,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
     list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER})
   endif()
   if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
-    list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
+    list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
   endif()
 
   AddTest(DYN
           TARGET onnxruntime_customopregistration_test
@@ -1753,7 +1760,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten"
   set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils)
 
   if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
-    list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
+    list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
   endif()
 
   if(NOT WIN32)
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 0686b66876d9f..3a1576065205f 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -102,6 +102,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
     onnx
     onnx_proto
     onnxruntime_common
+    onnxruntime_lora
     onnxruntime_flatbuffers
     onnxruntime_framework
     onnxruntime_graph
@@ -179,6 +180,7 @@ else()
     onnx
     onnx_proto
     onnxruntime_common
+    onnxruntime_lora
     onnxruntime_flatbuffers
     onnxruntime_framework
     onnxruntime_graph
diff --git a/cmake/winml_unittests.cmake b/cmake/winml_unittests.cmake
index b655e60a8aec9..68acac584f2c0 100644
--- a/cmake/winml_unittests.cmake
+++ b/cmake/winml_unittests.cmake
@@ -166,7 +166,7 @@ function (get_winml_test_model_src
     "${winml_test_src_path}/model/*.cpp")
   set(${output_winml_test_model_src} ${winml_test_model_src} PARENT_SCOPE)
   set(${winml_test_model_libs} onnx_test_data_proto onnx_test_runner_common onnxruntime_common onnxruntime_mlas
-    onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
+    onnxruntime_graph onnxruntime_test_utils onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
 endfunction()
 
 file(GLOB winml_test_common_src CONFIGURE_DEPENDS
diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h
index 789c3b13f2c3e..fab65e8fee692 100644
--- a/include/onnxruntime/core/framework/run_options.h
+++ b/include/onnxruntime/core/framework/run_options.h
@@ -5,9 +5,17 @@
 #include
 #include
+
+#include "core/common/inlined_containers_fwd.h"
 #include "core/session/onnxruntime_c_api.h"
 #include "core/framework/config_options.h"
 
+namespace onnxruntime {
+namespace lora {
+class LoraAdapter;
+}
+}  // namespace onnxruntime
+
 /**
  * Configuration information for a Run call.
  */
@@ -40,6 +48,8 @@ struct OrtRunOptions {
   // /include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
   onnxruntime::ConfigOptions config_options;
 
+  onnxruntime::InlinedVector<const onnxruntime::lora::LoraAdapter*> active_adapters;
+
   OrtRunOptions() = default;
   ~OrtRunOptions() = default;
 };
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 3aa98bb020452..39e0361b7ff4f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -304,6 +304,7 @@ ORT_RUNTIME_CLASS(Op);
 ORT_RUNTIME_CLASS(OpAttr);
 ORT_RUNTIME_CLASS(Logger);
 ORT_RUNTIME_CLASS(ShapeInferContext);
+ORT_RUNTIME_CLASS(LoraAdapter);
 
 #ifdef _WIN32
 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -4670,6 +4671,57 @@ struct OrtApi {
                   _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
                   _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
                   size_t num_external_initializer_files);
+
+  /** \brief Create an OrtLoraAdapter
+   *
+   * The function attempts to locate the file specified by adapter_file_path, read it, and create an OrtLoraAdapter
+   * instance. The adapter_file_path should be a valid absolute path to a file in the Lora Adapter
+   * format. The function validates the format at load time. The file is always memory mapped, unless
+   * the platform does not support memory mapping, in which case the file is read into memory.
+   *
+   * \param[in] adapter_file_path adapter file path.
+   * \param[in] allocator optional pointer to a device allocator. If specified,
+   *            data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU.
+   *            The data would still be copied to device if required by the model at inference time.
+   * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with
+   *             OrtApi::ReleaseLoraAdapter.
+   */
+  ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
+                  _Outptr_ OrtLoraAdapter** out);
+
+  /** \brief Create an OrtLoraAdapter
+   *
+   * The function copies the bytes from the array and creates an OrtLoraAdapter instance.
+   *
+   * \param[in] bytes pointer to a valid Lora Adapter format buffer.
+   * \param[in] num_bytes length of the bytes buffer.
+   * \param[in] allocator optional pointer to a device allocator. If specified,
+   *            data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU.
+   *            The data would still be copied to device if required by the model at inference time.
+   * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with
+   *             OrtApi::ReleaseLoraAdapter.
+   */
+  ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
+                  _Outptr_ OrtLoraAdapter** out);
+
+  /** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter
+   */
+  ORT_CLASS_RELEASE(LoraAdapter);
+
+  /** \brief Add the Lora Adapter to the list of active adapters.
+   *
+   * The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with
+   * OrtApi::CreateLoraAdapter or OrtApi::CreateLoraAdapterFromArray. The active adapters will be used by the
+   * session to run the model. The instance of OrtRunOptions can then be used to customize the Run() calls.
+   * More than one OrtLoraAdapter can be active at the same time, but Lora Parameters that belong to different
+   * adapters that are active at the same time must not overlap.
+   * This setting does not affect RunWithBinding.
+   *
+   * \param[in] options OrtRunOptions instance
+   * \param[in] adapter OrtLoraAdapter instance
+   */
+  ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
 };
 
 /*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 29a229f427163..12a6a5c87c0aa 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -508,6 +508,7 @@ ORT_DEFINE_RELEASE(CustomOpDomain);
 ORT_DEFINE_RELEASE(ThreadingOptions);
 ORT_DEFINE_RELEASE(Env);
 ORT_DEFINE_RELEASE(RunOptions);
+ORT_DEFINE_RELEASE(LoraAdapter);
 ORT_DEFINE_RELEASE(Session);
 ORT_DEFINE_RELEASE(SessionOptions);
 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
@@ -736,6 +737,32 @@ struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
   void Add(const OrtCustomOp* op);  ///< Wraps CustomOpDomain_Add
 };
 
+/// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file
+struct LoraAdapter : detail::Base<OrtLoraAdapter> {
+  using Base = detail::Base<OrtLoraAdapter>;
+  using Base::Base;
+
+  explicit LoraAdapter(std::nullptr_t) {}  ///< Create an empty LoraAdapter object, must be assigned a valid one to be used
+
+  /// \brief Wraps OrtApi::CreateLoraAdapter
+  ///
+  /// The function attempts to load the adapter from the specified file
+  /// \param adapter_path The path to the Lora adapter
+  /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
+  ///                  be copied to device if required by the model at inference time.
+  static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
+                                       OrtAllocator* allocator);
+
+  /// \brief Wraps OrtApi::CreateLoraAdapterFromArray
+  ///
+  /// The function attempts to load the adapter from the specified byte array.
+  /// \param bytes The byte array containing adapter data in the LoraAdapter format
+  /// \param num_bytes The number of bytes in the byte array
+  /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
+  ///                  be copied to device if required by the model at inference time.
+  static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
+                                                OrtAllocator* allocator);
+};
+
 /** \brief RunOptions
  *
  */
@@ -766,6 +793,14 @@ struct RunOptions : detail::Base<OrtRunOptions> {
    *  Wraps OrtApi::RunOptionsUnsetTerminate
    */
   RunOptions& UnsetTerminate();
+
+  /** \brief Add the LoraAdapter to the list of active adapters.
+   *  The setting does not affect RunWithBinding() calls.
+   *
+   *  Wraps OrtApi::RunOptionsAddActiveLoraAdapter
+   *  \param adapter The LoraAdapter to be used as the active adapter
+   */
+  RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);
 };
 
 namespace detail {
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index d3a8cade4d28f..7401cb2438121 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -557,6 +557,20 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
   ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
 }
 
+inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
+                                                  OrtAllocator* allocator) {
+  OrtLoraAdapter* p;
+  ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p));
+  return LoraAdapter{p};
+}
+
+inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
+                                                           OrtAllocator* allocator) {
+  OrtLoraAdapter* p;
+  ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p));
+  return LoraAdapter{p};
+}
+
 inline RunOptions::RunOptions() {
   ThrowOnError(GetApi().CreateRunOptions(&p_));
 }
@@ -609,6 +623,11 @@ inline RunOptions& RunOptions::UnsetTerminate() {
   return *this;
 }
 
+inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) {
+  ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter));
+  return *this;
+}
+
 namespace detail {
 template <typename T>
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index e4d85c9d7b975..0e9a924bde4bb 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -23,6 +23,7 @@
     from onnxruntime.capi._pybind_state import ExecutionMode  # noqa: F401
     from onnxruntime.capi._pybind_state import ExecutionOrder  # noqa: F401
     from onnxruntime.capi._pybind_state import GraphOptimizationLevel  # noqa: F401
+    from onnxruntime.capi._pybind_state import LoraAdapter  # noqa: F401
     from onnxruntime.capi._pybind_state import ModelMetadata  # noqa: F401
     from onnxruntime.capi._pybind_state import NodeArg  # noqa: F401
     from onnxruntime.capi._pybind_state import OrtAllocatorType  # noqa: F401
@@ -56,6 +57,7 @@
 if import_capi_exception:
     raise import_capi_exception
 
+from onnxruntime.capi.onnxruntime_inference_collection import AdapterFormat  # noqa: F401
 from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession  # noqa: F401
 from onnxruntime.capi.onnxruntime_inference_collection import IOBinding  # noqa: F401
 from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice  # noqa: F401
diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h
index 7b7c226819e79..efdfdb45abbaa 100644
--- a/onnxruntime/core/framework/config_options.h
+++ b/onnxruntime/core/framework/config_options.h
@@ -19,7 +19,7 @@ struct ConfigOptions {
   // Gets the config string associated with the given config_key.
   // If not found, an empty optional is returned.
-  optional<std::string> GetConfigEntry(const std::string& config_key) const noexcept;
+  std::optional<std::string> GetConfigEntry(const std::string& config_key) const noexcept;
 
   // Check if this instance of ConfigOptions has a config using the given config_key.
   // Returns true if found and copies the value into config_value.
diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc
index 95c111009c791..cb07cc22b1b2f 100644
--- a/onnxruntime/core/framework/run_options.cc
+++ b/onnxruntime/core/framework/run_options.cc
@@ -5,9 +5,11 @@
 #include "core/session/onnxruntime_c_api.h"
 #include "core/session/ort_apis.h"
 #include "core/framework/error_code_helper.h"
+
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(disable : 26409)
 #endif
+
 ORT_API_STATUS_IMPL(OrtApis::CreateRunOptions, _Outptr_ OrtRunOptions** out) {
   API_IMPL_BEGIN
   *out = new OrtRunOptions();
@@ -60,3 +62,12 @@ ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options,
                     _In_z_ const char* config_key, _In_z_ const char* config_value) {
   return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value));
 }
+
+ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options,
+                    _In_ const OrtLoraAdapter* adapter) {
+  API_IMPL_BEGIN
+  auto* lora_adapter = reinterpret_cast<const onnxruntime::lora::LoraAdapter*>(adapter);
+  options->active_adapters.push_back(lora_adapter);
+  return nullptr;
+  API_IMPL_END
+}
diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc
new file mode 100644
index 0000000000000..466edce187a56
--- /dev/null
+++ b/onnxruntime/core/session/lora_adapters.cc
@@ -0,0 +1,166 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/session/lora_adapters.h"
+#include "lora/adapter_format_utils.h"
+
+#include
+
+#include "core/framework/data_transfer.h"
+#include "core/framework/error_code_helper.h"
+#include "core/session/onnxruntime_c_api.h"
+#include "core/session/allocator_adapters.h"
+#include "core/session/ort_apis.h"
+
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_factory.h"
+#endif
+
+namespace onnxruntime {
+
+#ifdef USE_CUDA
+ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+#endif
+
+namespace lora {
+
+LoraAdapter::Param::Param(OrtValue ort_value_mapped) noexcept
+    : ort_value_mapped_(std::move(ort_value_mapped)) {}
+
+LoraAdapter::Param::Param(OrtValue ort_value_mapped, OrtValue ort_value_device) noexcept
+    : ort_value_mapped_(std::move(ort_value_mapped)), ort_value_device_(std::move(ort_value_device)) {
+}
+
+void LoraAdapter::Load(const std::filesystem::path& file_path) {
+  auto buffer = adapters::utils::LoadLoraAdapterBytes(file_path);
+  Load(std::move(buffer));
+}
+
+void LoraAdapter::Load(std::vector<uint8_t> buffer) {
+  adapter_ = adapters::utils::ValidateAndGetAdapterFromBytes(buffer);
+  buffer_.emplace<BufferHolder>(std::move(buffer));
+  InitializeParamsValues();
+}
+
+void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
+  auto [mapped_memory, file_size] = adapters::utils::MemoryMapAdapterFile(file_path);
+  auto u8_span = ReinterpretAsSpan<const uint8_t>(gsl::make_span(mapped_memory.get(), file_size));
+  adapter_ = adapters::utils::ValidateAndGetAdapterFromBytes(u8_span);
+  buffer_.emplace<MemMapHolder>(std::move(mapped_memory), file_size);
+  InitializeParamsValues();
+}
+
+static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
+  std::unique_ptr<IDataTransfer> data_transfer;
+
+  if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
+    return data_transfer;
+  }
+
+  if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
+#ifdef USE_CUDA
+    auto* cuda_provider_info = TryGetProviderInfo_CUDA();
+    if (cuda_provider_info != nullptr) {
+      data_transfer = cuda_provider_info->CreateGPUDataTransfer();
+    }
+#endif
+  }
+
+  return data_transfer;
+}
+
+static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped,
+                                     const AllocatorPtr& device_allocator,
+                                     const IDataTransfer& data_transfer,
+                                     OrtValue& out) {
+  OrtValue result;
+  const auto& src = ort_value_mapped.Get<Tensor>();
+  Tensor on_device(src.DataType(), src.Shape(), device_allocator);
+  ORT_RETURN_IF_ERROR(data_transfer.CopyTensor(src, on_device));
+  Tensor::InitOrtValue(std::move(on_device), result);
+  out = std::move(result);
+  return Status::OK();
+}
+
+void LoraAdapter::InitializeParamsValues() {
+  if (adapter_ == nullptr) {
+    ORT_THROW("Adapter is not loaded yet.");
+  }
+
+  std::unique_ptr<IDataTransfer> data_transfer;
+  if (device_allocator_) {
+    data_transfer = GetDataTransfer(device_allocator_->Info());
+    if (data_transfer == nullptr) {
+      ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator");
+    }
+  }
+
+  const auto* params = adapter_->parameters();
+  ORT_ENFORCE(params != nullptr, "Params absent");
+  std::unordered_map<std::string, Param> params_values;
+  params_values.reserve(params->size());
+  // Re-work in two separate loops due to compiler issues
+  if (data_transfer) {
+    for (const auto* param : *params) {
+      auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
+      OrtValue ort_value_ondevice;
+      ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_,
+                                                *data_transfer, ort_value_ondevice));
+      Param lora_param(std::move(ort_value), std::move(ort_value_ondevice));
+      params_values.emplace(std::move(name), std::move(lora_param));
+    }
+  } else {
+    for (const auto* param : *params) {
+      auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
+      Param lora_param(std::move(ort_value));
+      params_values.emplace(std::move(name), std::move(lora_param));
+    }
+  }
+
+  params_values_.swap(params_values);
+}
+
+}  // namespace lora
+}  // namespace onnxruntime
+
+ORT_API_STATUS_IMPL(OrtApis::CreateLoraAdapter, _In_ const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
+                    _Outptr_ OrtLoraAdapter** adapter) {
+  API_IMPL_BEGIN
+
+  std::unique_ptr<onnxruntime::lora::LoraAdapter> lora_adapter;
+  if (allocator != nullptr) {
+    auto alloc_ptr = std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
+    lora_adapter = std::make_unique<onnxruntime::lora::LoraAdapter>(std::move(alloc_ptr));
+  } else {
+    lora_adapter = std::make_unique<onnxruntime::lora::LoraAdapter>();
+  }
+  // For platforms that do not support memory mapping, we can #ifdef this to ->Load(adapter_file_path)
+  lora_adapter->MemoryMap(adapter_file_path);
+  *adapter = reinterpret_cast<OrtLoraAdapter*>(lora_adapter.release());
+  return nullptr;
+  API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes,
+                    _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** adapter) {
+  API_IMPL_BEGIN
+
+  std::unique_ptr<onnxruntime::lora::LoraAdapter> lora_adapter;
+  if (allocator != nullptr) {
+    auto alloc_ptr = std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
+    lora_adapter = std::make_unique<onnxruntime::lora::LoraAdapter>(std::move(alloc_ptr));
+  } else {
+    lora_adapter = std::make_unique<onnxruntime::lora::LoraAdapter>();
+  }
+
+  std::vector<uint8_t> buffer(num_bytes);
+  memcpy(buffer.data(), bytes, num_bytes);
+  lora_adapter->Load(std::move(buffer));
+  *adapter = reinterpret_cast<OrtLoraAdapter*>(lora_adapter.release());
+  return nullptr;
+  API_IMPL_END
+}
+
+ORT_API(void, OrtApis::ReleaseLoraAdapter, _Frees_ptr_opt_ OrtLoraAdapter* adapter) {
+  delete reinterpret_cast<onnxruntime::lora::LoraAdapter*>(adapter);
+}
diff --git a/onnxruntime/core/session/lora_adapters.h b/onnxruntime/core/session/lora_adapters.h
new file mode 100644
index 0000000000000..77534b2bb7d15
--- /dev/null
+++ b/onnxruntime/core/session/lora_adapters.h
@@ -0,0 +1,176 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/common/inlined_containers.h"
+#include "core/framework/allocator.h"
+#include "core/framework/ort_value.h"
+#include "core/platform/env.h"
+
+#include "lora/adapter_format_utils.h"
+
+#include <filesystem>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+
+namespace onnxruntime {
+namespace lora {
+
+///
+/// Container to hold and access Lora Parameters
+///
+class LoraAdapter {
+ public:
+  LoraAdapter() = default;
+  explicit LoraAdapter(AllocatorPtr device_allocator)
+      : device_allocator_(std::move(device_allocator)) {}
+  ~LoraAdapter() = default;
+  LoraAdapter(const LoraAdapter&) = delete;
+  LoraAdapter& operator=(const LoraAdapter&) = delete;
+
+  LoraAdapter(LoraAdapter&&) = default;
+  LoraAdapter& operator=(LoraAdapter&&) = default;
+
+  ///
+  /// Represents a named lora parameter (tensor)
+  ///
+  class Param {
+   public:
+    Param() = default;
+    explicit Param(OrtValue ort_value_mapped) noexcept;
+    Param(OrtValue ort_value_mapped, OrtValue ort_value_device) noexcept;
+
+    const OrtValue& GetMapped() const noexcept {
+      return ort_value_mapped_;
+    }
+
+    // For python interface
+    OrtValue& GetMapped() noexcept {
+      return ort_value_mapped_;
+    }
+
+    const OrtValue& GetDeviceOrMapped() const noexcept {
+      if (ort_value_device_.IsAllocated()) {
+        return ort_value_device_;
+      }
+      return ort_value_mapped_;
+    }
+
+   private:
+    OrtValue ort_value_mapped_;
+    OrtValue ort_value_device_;
+  };
+
+  using param_const_iterator = std::unordered_map<std::string, Param>::const_iterator;
+  using param_iterator = std::unordered_map<std::string, Param>::iterator;
+
+  ///
+  /// Obtain a range of the iterators
+  ///
+  std::pair<param_const_iterator, param_const_iterator> GetParamIterators() const {
+    return std::make_pair(params_values_.cbegin(), params_values_.cend());
+  }
+
+  std::pair<param_iterator, param_iterator> GetParamIterators() {
+    return std::make_pair(params_values_.begin(), params_values_.end());
+  }
+
+  ///
+  /// Loads parameters into memory from an adapter file and validates its format.
+  ///
+  /// file_path - file name that can be opened
+  void Load(const std::filesystem::path& file_path);
+
+  ///
+  /// Loads parameters from serialized bytes and validates the format.
+  ///
+  void Load(std::vector<uint8_t> buffer);
+
+  ///
+  /// Memory maps the adapter file into memory and validates its format.
+  ///
+  void MemoryMap(const std::filesystem::path& file_path);
+
+  ///
+  /// Returns the number of parameters in the adapter.
+  /// The number is expected to be even, as lora params come in pairs.
+  ///
+  /// Returns the size of the params_values_ container
+  size_t GetParamNum() const {
+    return params_values_.size();
+  }
+
+  ///
+  /// Gets lora format version
+  ///
+  int FormatVersion() const noexcept {
+    return adapter_->format_version();
+  }
+
+  ///
+  /// Gets adapter version
+  ///
+  int AdapterVersion() const noexcept {
+    return adapter_->adapter_version();
+  }
+
+  ///
+  /// Gets model version for which the adapter was created
+  ///
+  int ModelVersion() const noexcept {
+    return adapter_->model_version();
+  }
+
+  ///
+  /// Outputs Lora Parameters on CPU, their names and values
+  /// into the supplied output iterators.
+  ///
+  /// names_out - output iterator that accepts const char*
+  /// tensor_out - output iterator that accepts const OrtValue*
+  template <class NamesOutputIter, class TensorOutputIter>
+  void OutputAdapterParameters(NamesOutputIter names_out,
+                               TensorOutputIter tensor_out) const {
+    for (const auto& [name, param] : params_values_) {
+      *names_out = name.c_str();
+      ++names_out;
+      *tensor_out = &param.GetDeviceOrMapped();
+      ++tensor_out;
+    }
+  }
+
+ private:
+  void InitializeParamsValues();
+
+  struct BufferHolder {
+    explicit BufferHolder(std::vector<uint8_t> buffer) : buffer_(std::move(buffer)) {}
+    std::vector<uint8_t> buffer_;
+  };
+
+  struct MemMapHolder {
+    MemMapHolder(Env::MappedMemoryPtr mapped_memory, size_t file_size)
+        : mapped_memory_(std::move(mapped_memory)), file_size_(file_size) {}
+    Env::MappedMemoryPtr mapped_memory_;
+    size_t file_size_;
+  };
+
+  std::variant<std::monostate, MemMapHolder, BufferHolder> buffer_;
+
+  AllocatorPtr device_allocator_;
+  const adapters::Adapter* adapter_{nullptr};
+  std::unordered_map<std::string, Param> params_values_;
+};
+
+}  // namespace lora
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 1a5484ddc0055..64546e634694f 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -39,6 +39,8 @@
 #include "core/platform/ort_mutex.h"
 #include "core/common/string_helper.h"
 
+#include "core/session/lora_adapters.h"
+
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_factory.h"
 #include "core/providers/cuda/cuda_execution_provider_info.h"
@@ -813,6 +815,34 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
   API_IMPL_END
 }
 
+namespace {
+// Checks if there are active lora adapters and adjusts input spans.
+void CheckAndAdjustInputSpansForLora(const OrtRunOptions& run_options,
+                                     InlinedVector<const char*>& input_names_with_lora,
+                                     InlinedVector<const OrtValue*>& inputs_with_lora,
+                                     gsl::span<const char* const>& input_names,
+                                     gsl::span<const OrtValue* const>& inputs) {
+  size_t total_lora_params = 0;
+  for (const lora::LoraAdapter* ad : run_options.active_adapters) {
+    total_lora_params += ad->GetParamNum();
+  }
+
+  input_names_with_lora.reserve(input_names.size() + total_lora_params);
+  inputs_with_lora.reserve(inputs.size() + total_lora_params);
+  std::copy(input_names.begin(), input_names.end(), std::back_inserter(input_names_with_lora));
+  std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputs_with_lora));
+
+  for (const lora::LoraAdapter* ad : run_options.active_adapters) {
+    ad->OutputAdapterParameters(std::back_inserter(input_names_with_lora),
+                                std::back_inserter(inputs_with_lora));
+  }
+
+  input_names = gsl::make_span(input_names_with_lora);
+  inputs = gsl::make_span(inputs_with_lora);
+}
+
+}  // namespace
+
 ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                     _In_reads_(input_len) const char* const* input_names,
                     _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
@@ -821,18 +851,31 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu
   API_IMPL_BEGIN
   auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
 
-  gsl::span<const char* const> input_names_span(input_names, input_len);
-  gsl::span<const OrtValue* const> input_span(input, input_len);
-  gsl::span<const char* const> output_name_span(output_names, output_names_len);
-  gsl::span<OrtValue*> output_span(output, output_names_len);
+  auto input_names_span = gsl::make_span(input_names, input_len);
+  auto input_span = gsl::make_span(input, input_len);
+  auto output_name_span = gsl::make_span(output_names, output_names_len);
+  auto output_span = gsl::make_span(output, output_names_len);
 
   Status status;
-  if (run_options) {
-    status = session->Run(*run_options,
-                          input_names_span,
-                          input_span,
-                          output_name_span,
-                          output_span);
+  if (run_options != nullptr) {
+    if (!run_options->active_adapters.empty()) {
+      InlinedVector<const char*> input_names_with_lora;
+      InlinedVector<const OrtValue*> input_with_lora;
+
+      CheckAndAdjustInputSpansForLora(*run_options, input_names_with_lora, input_with_lora, input_names_span, input_span);
+
+      status = session->Run(*run_options,
+                            input_names_span,
+                            input_span,
+                            output_name_span,
+                            output_span);
+    } else {
+      status = session->Run(*run_options,
+                            input_names_span,
+                            input_span,
+                            output_name_span,
+                            output_span);
+    }
   } else {
     const RunOptions default_run_options;
     status = session->Run(default_run_options,
@@ -854,10 +897,14 @@ ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const
   API_IMPL_BEGIN
   auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
 
-  gsl::span<const char* const> input_names_span(input_names, input_len);
-  gsl::span<const OrtValue* const> input_span(input, input_len);
-  gsl::span<const char* const> output_name_span(output_names, output_names_len);
-  gsl::span<OrtValue*> output_span(output, output_names_len);
+  if (run_options != nullptr && !run_options->active_adapters.empty()) {
+    LOGS(*session->GetLogger(), WARNING) << "RunAsync() active adapters specified, but won't have an effect";
+  }
+
+  auto input_names_span = gsl::make_span(input_names, input_len);
+  auto input_span = gsl::make_span(input, input_len);
+  auto output_name_span = gsl::make_span(output_names, output_names_len);
+  auto output_span = gsl::make_span(output, output_names_len);
 
   return ToOrtStatus(session->RunAsync(run_options,
                                        input_names_span,
@@ -885,6 +932,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunWithBinding, _Inout_ OrtSession* sess, _In_ cons
     OrtRunOptions default_run_options;
     status = session->Run(default_run_options, *binding_ptr->binding_);
   } else {
+    if (!run_options->active_adapters.empty()) {
+      LOGS(*session->GetLogger(), WARNING)
+          << "RunWithBinding() has active adapters specified, but won't have an effect";
+    }
     status = session->Run(*run_options, *binding_ptr->binding_);
   }
   if (!status.IsOK()) {
@@ -2730,6 +2781,10 @@ static constexpr OrtApi ort_api_1_to_20 = {
     &OrtApis::KernelInfoGetAllocator,
     &OrtApis::AddExternalInitializersFromFilesInMemory,
     // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information)
+    &OrtApis::CreateLoraAdapter,
+    &OrtApis::CreateLoraAdapterFromArray,
+    &OrtApis::ReleaseLoraAdapter,
+    &OrtApis::RunOptionsAddActiveLoraAdapter,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
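
The following usage sketch is editorial and not part of the patch: it shows how the C++ wrappers layered over the new C API entry points above are intended to fit together. The model and adapter file names, tensor contents, and input/output names are hypothetical placeholders.

```
// Hypothetical end-to-end sketch of the LoraAdapter C++ bindings added in this change.
#include <array>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  // A nullptr allocator keeps the parameters on CPU (memory mapped where
  // supported); they are copied to a device later only if the model needs it.
  auto adapter = Ort::LoraAdapter::CreateLoraAdapter(ORT_TSTR("adapter.onnx_adapter"), nullptr);

  // Activate the adapter; its parameters are appended to the inputs inside Run().
  Ort::RunOptions run_options;
  run_options.AddActiveLoraAdapter(adapter);

  std::array<float, 4> data{1.f, 2.f, 3.f, 4.f};
  std::array<int64_t, 2> shape{1, 4};
  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(mem_info, data.data(), data.size(),
                                                     shape.data(), shape.size());

  const char* input_names[] = {"input"};
  const char* output_names[] = {"output"};
  auto outputs = session.Run(run_options, input_names, &input, 1, output_names, 1);
  return 0;
}
```

Note that, per the diff above, activation lives on RunOptions rather than the session, so the same session can serve Run() calls with different adapters (or none) concurrently.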
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index fcae173e6c162..9054246873232 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -523,4 +523,12 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi
 ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
 
 ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
+
+ORT_API_STATUS_IMPL(CreateLoraAdapter, _In_ const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
+                    _Outptr_ OrtLoraAdapter** out);
+ORT_API_STATUS_IMPL(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
+                    _Outptr_ OrtLoraAdapter** out);
+ORT_API(void, ReleaseLoraAdapter, _Frees_ptr_opt_ OrtLoraAdapter*);
+ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
+
 }  // namespace OrtApis
diff --git a/onnxruntime/lora/adapter_format/README.md b/onnxruntime/lora/adapter_format/README.md
new file mode 100644
index 0000000000000..65011c93729f8
--- /dev/null
+++ b/onnxruntime/lora/adapter_format/README.md
@@ -0,0 +1,36 @@
+# Lora Parameters Flatbuffer Schemas
+This directory contains the [ONNXRuntime Lora Parameter format schema](adapter_schema.fbs) and [the generated C++ header file](adapter_schema.fbs.h) for the
+Lora Parameters file format. This file format is defined as a means to deliver Lora parameters so they can be read by ONNXRuntime C++ code.
+
+The format is generally designed to house a single Lora adapter with named Lora parameters.
+
+The [ONNXRuntime Lora Parameter file format schema](adapter_schema.fbs) uses the [FlatBuffers](https://github.com/google/flatbuffers) serialization library.
+
+Please do not directly modify the generated C++ header file for the [ONNXRuntime Lora Parameter file format](adapter_schema.fbs.h).
+
+Use the flatc compiler for this purpose.
+
+e.g.
+  - Windows Debug build
+    - \build\Windows\Debug\_deps\flatbuffers-build\Debug\flatc.exe
+  - Linux Debug build
+    - /build/Linux/Debug/_deps/flatbuffers-build/flatc
+
+It is possible to use another flatc as well, e.g., from a separate installation.
+
+To update the flatbuffers schemas and generated files:
+1. Modify the [ONNXRuntime Lora Parameter file format schema](adapter_schema.fbs).
+2. Run [compile_schema.py](./compile_schema.py) to generate the C++ bindings.
+
+   ```
+   python onnxruntime/lora/adapter_format/compile_schema.py --flatc <path to flatc>
+   ```
+
+# Lora format version history
+In [adapter_format_version.h](../adapter_format_version.h), see `IsAdapterFormatVersionSupported()` for the supported versions and
+`kAdapterFormatVersion` for the current version.
+
+## Version 1
+History begins.
+
+Initial FlatBuffers support for Lora Parameters. This includes a definition of the Tensor entity,
+so that each parameter can be saved as a named tensor in the file format.
diff --git a/onnxruntime/lora/adapter_format/adapter_schema.fbs b/onnxruntime/lora/adapter_format/adapter_schema.fbs
new file mode 100644
index 0000000000000..da1f8dcf5da92
--- /dev/null
+++ b/onnxruntime/lora/adapter_format/adapter_schema.fbs
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace onnxruntime.adapters;
+
+// Tensor
+enum TensorDataType : int32 {
+  UNDEFINED = 0,
+  FLOAT = 1,
+  UINT8 = 2,
+  INT8 = 3,
+  UINT16 = 4,
+  INT16 = 5,
+  INT32 = 6,
+  INT64 = 7,
+  STRING = 8,
+  BOOL = 9,
+  FLOAT16 = 10,
+  DOUBLE = 11,
+  UINT32 = 12,
+  UINT64 = 13,
+  COMPLEX64 = 14,
+  COMPLEX128 = 15,
+  BFLOAT16 = 16,
+  FLOAT8E4M3FN = 17,
+  FLOAT8E4M3FNUZ = 18,
+  FLOAT8E5M2 = 19,
+  FLOAT8E5M2FNUZ = 20,
+}
+
+// For simplicity, we will only have one data field
+// - raw_data for all primitive types.
+// We do not foresee strings as parameters.
+table Parameter {
+  name:string;
+
+  dims:[int64];
+  data_type:TensorDataType;
+
+  raw_data:[uint8] (force_align : 8);
+}
+
+table Adapter {
+  format_version:int;
+  adapter_version:int;
+  model_version:int;
+  parameters:[Parameter];
+}
+
+root_type Adapter;
+file_identifier "TORT";
diff --git a/onnxruntime/lora/adapter_format/adapter_schema.fbs.h b/onnxruntime/lora/adapter_format/adapter_schema.fbs.h
new file mode 100644
index 0000000000000..c1d5412acbbde
--- /dev/null
+++ b/onnxruntime/lora/adapter_format/adapter_schema.fbs.h
@@ -0,0 +1,338 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+#ifndef FLATBUFFERS_GENERATED_ADAPTERSCHEMA_ONNXRUNTIME_ADAPTERS_H_
+#define FLATBUFFERS_GENERATED_ADAPTERSCHEMA_ONNXRUNTIME_ADAPTERS_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
+                  FLATBUFFERS_VERSION_MINOR == 5 &&
+                  FLATBUFFERS_VERSION_REVISION == 26,
+              "Non-compatible flatbuffers version included");
+
+namespace onnxruntime {
+namespace adapters {
+
+struct Parameter;
+struct ParameterBuilder;
+
+struct Adapter;
+struct AdapterBuilder;
+
+enum class TensorDataType : int32_t {
+  UNDEFINED = 0,
+  FLOAT = 1,
+  UINT8 = 2,
+  INT8 = 3,
+  UINT16 = 4,
+  INT16 = 5,
+  INT32 = 6,
+  INT64 = 7,
+  STRING = 8,
+  BOOL = 9,
+  FLOAT16 = 10,
+  DOUBLE = 11,
+  UINT32 = 12,
+  UINT64 = 13,
+  COMPLEX64 = 14,
+  COMPLEX128 = 15,
+  BFLOAT16 = 16,
+  FLOAT8E4M3FN = 17,
+  FLOAT8E4M3FNUZ = 18,
+  FLOAT8E5M2 = 19,
+  FLOAT8E5M2FNUZ = 20,
+  MIN = UNDEFINED,
+  MAX = FLOAT8E5M2FNUZ
+};
+
+inline const TensorDataType (&EnumValuesTensorDataType())[21] {
+  static const TensorDataType values[] = {
+      TensorDataType::UNDEFINED,
+      TensorDataType::FLOAT,
+      TensorDataType::UINT8,
+      TensorDataType::INT8,
+      TensorDataType::UINT16,
+      TensorDataType::INT16,
+      TensorDataType::INT32,
+      TensorDataType::INT64,
+      TensorDataType::STRING,
+      TensorDataType::BOOL,
+      TensorDataType::FLOAT16,
+      TensorDataType::DOUBLE,
+      TensorDataType::UINT32,
+      TensorDataType::UINT64,
+      TensorDataType::COMPLEX64,
+      TensorDataType::COMPLEX128,
+      TensorDataType::BFLOAT16,
+      TensorDataType::FLOAT8E4M3FN,
+      TensorDataType::FLOAT8E4M3FNUZ,
+      TensorDataType::FLOAT8E5M2,
+      TensorDataType::FLOAT8E5M2FNUZ};
+  return values;
+}
+
+inline const char* const* EnumNamesTensorDataType() {
+  static const char* const names[22] = {
+      "UNDEFINED",
+      "FLOAT",
+      "UINT8",
+      "INT8",
+      "UINT16",
+      "INT16",
+      "INT32",
+      "INT64",
+      "STRING",
+      "BOOL",
+      "FLOAT16",
+      "DOUBLE",
+      "UINT32",
+      "UINT64",
+      "COMPLEX64",
+      "COMPLEX128",
+      "BFLOAT16",
+      "FLOAT8E4M3FN",
+      "FLOAT8E4M3FNUZ",
+      "FLOAT8E5M2",
+      "FLOAT8E5M2FNUZ",
+      nullptr};
+  return names;
+}
+
+inline const char* EnumNameTensorDataType(TensorDataType e) {
+  if (::flatbuffers::IsOutRange(e, TensorDataType::UNDEFINED, TensorDataType::FLOAT8E5M2FNUZ)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesTensorDataType()[index];
+}
+
+struct Parameter FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ParameterBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NAME = 4,
+    VT_DIMS = 6,
+    VT_DATA_TYPE = 8,
+    VT_RAW_DATA = 10
+  };
+  const ::flatbuffers::String* name() const {
+    return GetPointer<const ::flatbuffers::String*>(VT_NAME);
+  }
+  const ::flatbuffers::Vector<int64_t>* dims() const {
+    return GetPointer<const ::flatbuffers::Vector<int64_t>*>(VT_DIMS);
+  }
+  onnxruntime::adapters::TensorDataType data_type() const {
+    return static_cast<onnxruntime::adapters::TensorDataType>(GetField<int32_t>(VT_DATA_TYPE, 0));
+  }
+  const ::flatbuffers::Vector<uint8_t>* raw_data() const {
+    return GetPointer<const ::flatbuffers::Vector<uint8_t>*>(VT_RAW_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier& verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_NAME) &&
+           verifier.VerifyString(name()) &&
+           VerifyOffset(verifier, VT_DIMS) &&
+           verifier.VerifyVector(dims()) &&
+           VerifyField<int32_t>(verifier, VT_DATA_TYPE, 4) &&
+           VerifyOffset(verifier, VT_RAW_DATA) &&
+           verifier.VerifyVector(raw_data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct ParameterBuilder {
+  typedef Parameter Table;
+  ::flatbuffers::FlatBufferBuilder& fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_name(::flatbuffers::Offset<::flatbuffers::String> name) {
+    fbb_.AddOffset(Parameter::VT_NAME, name);
+  }
+  void add_dims(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> dims) {
+    fbb_.AddOffset(Parameter::VT_DIMS, dims);
+  }
+  void add_data_type(onnxruntime::adapters::TensorDataType data_type) {
+    fbb_.AddElement<int32_t>(Parameter::VT_DATA_TYPE, static_cast<int32_t>(data_type), 0);
+  }
+  void add_raw_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> raw_data) {
+    fbb_.AddOffset(Parameter::VT_RAW_DATA, raw_data);
+  }
+  explicit ParameterBuilder(::flatbuffers::FlatBufferBuilder& _fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<Parameter> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<Parameter>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<Parameter> CreateParameter(
+    ::flatbuffers::FlatBufferBuilder& _fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> name = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> dims = 0,
+    onnxruntime::adapters::TensorDataType data_type = onnxruntime::adapters::TensorDataType::UNDEFINED,
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> raw_data = 0) {
+  ParameterBuilder builder_(_fbb);
+  builder_.add_raw_data(raw_data);
+  builder_.add_data_type(data_type);
+  builder_.add_dims(dims);
+  builder_.add_name(name);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Parameter> CreateParameterDirect(
+    ::flatbuffers::FlatBufferBuilder& _fbb,
+    const char* name = nullptr,
+    const std::vector<int64_t>* dims = nullptr,
+    onnxruntime::adapters::TensorDataType data_type = onnxruntime::adapters::TensorDataType::UNDEFINED,
+    const std::vector<uint8_t>* raw_data = nullptr) {
+  auto name__ = name ? _fbb.CreateString(name) : 0;
+  auto dims__ = dims ? _fbb.CreateVector<int64_t>(*dims) : 0;
+  if (raw_data) {
+    _fbb.ForceVectorAlignment(raw_data->size(), sizeof(uint8_t), 8);
+  }
+  auto raw_data__ = raw_data ? _fbb.CreateVector<uint8_t>(*raw_data) : 0;
+  return onnxruntime::adapters::CreateParameter(
+      _fbb,
+      name__,
+      dims__,
+      data_type,
+      raw_data__);
+}
+
+struct Adapter FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef AdapterBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_FORMAT_VERSION = 4,
+    VT_ADAPTER_VERSION = 6,
+    VT_MODEL_VERSION = 8,
+    VT_PARAMETERS = 10
+  };
+  int32_t format_version() const {
+    return GetField<int32_t>(VT_FORMAT_VERSION, 0);
+  }
+  int32_t adapter_version() const {
+    return GetField<int32_t>(VT_ADAPTER_VERSION, 0);
+  }
+  int32_t model_version() const {
+    return GetField<int32_t>(VT_MODEL_VERSION, 0);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>* parameters() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>*>(VT_PARAMETERS);
+  }
+  bool Verify(::flatbuffers::Verifier& verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_FORMAT_VERSION, 4) &&
+           VerifyField<int32_t>(verifier, VT_ADAPTER_VERSION, 4) &&
+           VerifyField<int32_t>(verifier, VT_MODEL_VERSION, 4) &&
+           VerifyOffset(verifier, VT_PARAMETERS) &&
+           verifier.VerifyVector(parameters()) &&
+           verifier.VerifyVectorOfTables(parameters()) &&
+           verifier.EndTable();
+  }
+};
+
+struct AdapterBuilder {
+  typedef Adapter Table;
+  ::flatbuffers::FlatBufferBuilder& fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_format_version(int32_t format_version) {
+    fbb_.AddElement<int32_t>(Adapter::VT_FORMAT_VERSION, format_version, 0);
+  }
+  void add_adapter_version(int32_t adapter_version) {
+    fbb_.AddElement<int32_t>(Adapter::VT_ADAPTER_VERSION, adapter_version, 0);
+  }
+  void add_model_version(int32_t model_version) {
+    fbb_.AddElement<int32_t>(Adapter::VT_MODEL_VERSION, model_version, 0);
+  }
+  void add_parameters(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>> parameters) {
+    fbb_.AddOffset(Adapter::VT_PARAMETERS, parameters);
+  }
+  explicit AdapterBuilder(::flatbuffers::FlatBufferBuilder& _fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<Adapter> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<Adapter>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<Adapter> CreateAdapter(
+    ::flatbuffers::FlatBufferBuilder& _fbb,
+    int32_t format_version = 0,
+    int32_t adapter_version = 0,
+    int32_t model_version = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>> parameters = 0) {
+  AdapterBuilder builder_(_fbb);
+  builder_.add_parameters(parameters);
+  builder_.add_model_version(model_version);
+  builder_.add_adapter_version(adapter_version);
+  builder_.add_format_version(format_version);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Adapter> CreateAdapterDirect(
+    ::flatbuffers::FlatBufferBuilder& _fbb,
+    int32_t format_version = 0,
+    int32_t adapter_version = 0,
+    int32_t model_version = 0,
+    const std::vector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>* parameters = nullptr) {
+  auto parameters__ = parameters ? _fbb.CreateVector<::flatbuffers::Offset<onnxruntime::adapters::Parameter>>(*parameters) : 0;
+  return onnxruntime::adapters::CreateAdapter(
+      _fbb,
+      format_version,
+      adapter_version,
+      model_version,
+      parameters__);
+}
+
+inline const onnxruntime::adapters::Adapter* GetAdapter(const void* buf) {
+  return ::flatbuffers::GetRoot<onnxruntime::adapters::Adapter>(buf);
+}
+
+inline const onnxruntime::adapters::Adapter* GetSizePrefixedAdapter(const void* buf) {
+  return ::flatbuffers::GetSizePrefixedRoot<onnxruntime::adapters::Adapter>(buf);
+}
+
+inline const char* AdapterIdentifier() {
+  return "TORT";
+}
+
+inline bool AdapterBufferHasIdentifier(const void* buf) {
+  return ::flatbuffers::BufferHasIdentifier(
+      buf, AdapterIdentifier());
+}
+
+inline bool SizePrefixedAdapterBufferHasIdentifier(const void* buf) {
+  return ::flatbuffers::BufferHasIdentifier(
+      buf, AdapterIdentifier(), true);
+}
+
+inline bool VerifyAdapterBuffer(
+    ::flatbuffers::Verifier& verifier) {
+  return verifier.VerifyBuffer<onnxruntime::adapters::Adapter>(AdapterIdentifier());
+}
+
+inline bool VerifySizePrefixedAdapterBuffer(
+    ::flatbuffers::Verifier& verifier) {
+  return verifier.VerifySizePrefixedBuffer<onnxruntime::adapters::Adapter>(AdapterIdentifier());
+}
+
+inline void FinishAdapterBuffer(
+    ::flatbuffers::FlatBufferBuilder& fbb,
+    ::flatbuffers::Offset<onnxruntime::adapters::Adapter> root) {
+  fbb.Finish(root, AdapterIdentifier());
+}
+
+inline void FinishSizePrefixedAdapterBuffer(
+    ::flatbuffers::FlatBufferBuilder& fbb,
+    ::flatbuffers::Offset<onnxruntime::adapters::Adapter> root) {
+  fbb.FinishSizePrefixed(root, AdapterIdentifier());
+}
+
+}  // namespace adapters
+}  // namespace onnxruntime
+
+#endif  // FLATBUFFERS_GENERATED_ADAPTERSCHEMA_ONNXRUNTIME_ADAPTERS_H_
diff --git a/onnxruntime/lora/adapter_format/compile_schema.py b/onnxruntime/lora/adapter_format/compile_schema.py
new file mode 100644
index 0000000000000..4536c48391dda
--- /dev/null
+++ b/onnxruntime/lora/adapter_format/compile_schema.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import argparse
+import pathlib
+import subprocess
+
+SCRIPT_DIR = pathlib.Path(__file__).parent.resolve()
+
+
+def generate_cpp(flatc: pathlib.Path, schema_path: pathlib.Path):
+    # run flatc to generate C++ code
+    cmd = [str(flatc), "--cpp", "--scoped-enums", "--filename-suffix", ".fbs", str(schema_path)]
+    subprocess.run(cmd, check=True, cwd=SCRIPT_DIR)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate language bindings for the ORT flatbuffers schema.",
+        usage="Provide the path to the flatbuffers flatc executable. "
+        "Script can be executed from anywhere but must be located in its original "
+        "directory in the ONNX Runtime enlistment.",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--flatc",
+        required=True,
+        type=pathlib.Path,
+        help="Path to flatbuffers flatc executable. "
+        "Can be found in the build directory under _deps/flatbuffers-build//",
+    )
+
+    all_languages = ["cpp"]
+    parser.add_argument(
+        "-l",
+        "--language",
+        action="append",
+        dest="languages",
+        choices=all_languages,
+        help="Specify which language bindings to generate.",
+    )
+
+    args = parser.parse_args()
+    languages = args.languages if args.languages is not None else all_languages
+    flatc = args.flatc.resolve(strict=True)
+    schema_path = SCRIPT_DIR / "adapter_schema.fbs"
+
+    if "cpp" in languages:
+        generate_cpp(flatc, schema_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc
new file mode 100644
index 0000000000000..9a6f8f3b7b1c8
--- /dev/null
+++ b/onnxruntime/lora/adapter_format_utils.cc
@@ -0,0 +1,142 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "adapter_format_utils.h"
+#include "adapter_format_version.h"
+
+#include "core/framework/allocator.h"
+#include "core/common/common.h"
+#include "core/common/span_utils.h"
+#include "core/framework/ortdevice.h"
+#include "core/framework/ortmemoryinfo.h"
+#include "core/framework/ort_value.h"
+#include "core/framework/tensor.h"
+
+#include
+
+namespace onnxruntime {
+namespace adapters {
+namespace utils {
+
+bool IsAdapterFormatModelBytes(const void* bytes, size_t num_bytes) {
+  return num_bytes > 8 &&  // check buffer is large enough to contain identifier so we don't read random memory
+         AdapterBufferHasIdentifier(bytes);
+}
+
+void LoadStringFromLoraFormat(std::string& dst, const flatbuffers::String* fbs_string) {
+  if (fbs_string) {
+    dst = fbs_string->str();
+  }
+}
+
+std::vector<uint8_t> LoadLoraAdapterBytes(const std::filesystem::path& file_path) {
+  Env& env = Env::Default();
+
+  size_t file_size = 0;
+  ORT_THROW_IF_ERROR(env.GetFileLength(file_path.c_str(), file_size));
+
+  std::vector<uint8_t> result;
+  result.resize(file_size);
+
+  // The API accepts char span, so we need to reinterpret the uint8_t span as char span
+  auto dest_span = ReinterpretAsSpan<char>(AsSpan(result));
+  ORT_THROW_IF_ERROR(env.ReadFileIntoBuffer(file_path.c_str(), 0, file_size, dest_span));
+
+  return result;
+}
+
+std::pair<Env::MappedMemoryPtr, size_t> MemoryMapAdapterFile(const std::filesystem::path& file_path) {
+  Env& env = Env::Default();
+
+  size_t file_size = 0;
+  ORT_THROW_IF_ERROR(env.GetFileLength(file_path.c_str(), file_size));
+
+  Env::MappedMemoryPtr result;
+  ORT_THROW_IF_ERROR(env.MapFileIntoMemory(file_path.c_str(), 0, file_size, result));
+
+  return {std::move(result), file_size};
+}
+
+const Adapter* ValidateAndGetAdapterFromBytes(gsl::span<const uint8_t> bytes) {
+  if (!IsAdapterFormatModelBytes(bytes.data(), bytes.size())) {
+    ORT_THROW("The buffer does not appear to be a valid lora parameter format");
+  }
+
+  flatbuffers::Verifier verifier(bytes.data(), bytes.size());
+  if (!VerifyAdapterBuffer(verifier)) {
+    ORT_THROW("The buffer fails lora adapter format verification");
+  }
+
+  auto* adapter = GetAdapter(bytes.data());
+  if (!IsAdapterFormatVersionSupported(adapter->format_version())) {
+    ORT_THROW("Unsupported lora format version");
+  }
+
+  return adapter;
+}
+
+void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name,
+                       TensorDataType data_type, gsl::span<const int64_t> shape,
+                       gsl::span<const uint8_t> data,
+                       flatbuffers::Offset<Parameter>& fbs_tensor) {
+  auto name_str = (name.empty()) ? 0 : flat_builder.CreateString(name.data(), name.size());
+  auto shape_vec = flat_builder.CreateVector(shape.data(), shape.size());
+  auto data_vec = flat_builder.CreateVector(data.data(), data.size());
+
+  fbs_tensor = CreateParameter(flat_builder, name_str, shape_vec, data_type, data_vec);
+}
+
+std::pair<std::string, OrtValue> CreateOrtValueOverLoraParameter(const Parameter& param) {
+  OrtValue result;
+
+  std::string name;
+  LoadStringFromLoraFormat(name, param.name());
+
+  const auto data_type = param.data_type();
+  gsl::span<const int64_t> shape_span(param.dims()->data(), param.dims()->size());
+
+  static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator);
+
+  auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(data_type))->GetElementType();
+  // const_cast is necessary due to the Tensor class API
+  Tensor::InitOrtValue(elem_type,
+                       TensorShape(shape_span),
+                       const_cast<uint8_t*>(param.raw_data()->data()),
+                       cpu_meminfo,
+                       result);
+
+  return std::make_pair(std::move(name), std::move(result));
+}
+
+void AdapterFormatBuilder::AddParameter(const std::string& name, TensorDataType data_type,
+                                        gsl::span<const int64_t> shape, gsl::span<const uint8_t> data) {
+  flatbuffers::Offset<Parameter> fbs_param;
+  SaveLoraParameter(builder_, name, data_type, shape, data, fbs_param);
+  params_.push_back(fbs_param);
+}
+
+std::vector<uint8_t> AdapterFormatBuilder::Finish(int adapter_version, int model_version) {
+  FinishImpl(adapter_version, model_version);
+
+  std::vector<uint8_t> result;
+  result.reserve(builder_.GetSize());
+  gsl::span<const uint8_t> buffer(builder_.GetBufferPointer(), builder_.GetSize());
+  std::copy(buffer.begin(), buffer.end(), std::back_inserter(result));
+  return result;
+}
+
+gsl::span<uint8_t> AdapterFormatBuilder::FinishWithSpan(int adapter_version, int model_version) {
+  FinishImpl(adapter_version, model_version);
+  return gsl::make_span(builder_.GetBufferPointer(), builder_.GetSize());
+}
+
+void AdapterFormatBuilder::FinishImpl(int adapter_version, int model_version) {
+  auto fbs_params = builder_.CreateVector(params_);
+  auto fbs_adapter = CreateAdapter(builder_, kAdapterFormatVersion, adapter_version,
+                                   model_version, fbs_params);
+  builder_.Finish(fbs_adapter, AdapterIdentifier());
+}
+
+}  // namespace utils
+}  // namespace adapters
+}  // namespace onnxruntime
diff --git a/onnxruntime/lora/adapter_format_utils.h b/onnxruntime/lora/adapter_format_utils.h
new file mode 100644
index 0000000000000..21a68e6846ac1
--- /dev/null
+++ b/onnxruntime/lora/adapter_format_utils.h
@@ -0,0 +1,125 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/flatbuffers.h"
+#include "core/framework/allocator.h"
+#include "core/platform/env.h"
+
+#include
+#include
+
+#include "adapter_format/adapter_schema.fbs.h"
+
+#include
+#include
+#include
+#include
+
+struct OrtValue;
+
+namespace onnxruntime {
+namespace adapters {
+namespace utils {
+
+///
+/// Helper class to serialize a Lora adapter
+///
+class AdapterFormatBuilder {
+ public:
+  AdapterFormatBuilder() = default;
+
+  ///
+  /// Appends a parameter tensor to the adapter builder
+  ///
+  /// name - parameter name
+  void AddParameter(const std::string& name, adapters::TensorDataType data_type,
+                    gsl::span<const int64_t> shape, gsl::span<const uint8_t> data);
+
+  ///
+  /// Finishes serialization and returns a serialized byte vector
+  ///
+  std::vector<uint8_t> Finish(int adapter_version, int model_version);
+
+  ///
+  /// Finishes serialization and returns a span over the internal buffer.
+  /// </summary>
+  /// <param name="adapter_version"></param>
+  /// <param name="model_version"></param>
+  /// <returns></returns>
+  gsl::span<uint8_t> FinishWithSpan(int adapter_version, int model_version);
+
+ private:
+  void FinishImpl(int adapter_version, int model_version);
+
+  flatbuffers::FlatBufferBuilder builder_;
+  std::vector<flatbuffers::Offset<Parameter>> params_;
+};
+
+/// <summary>
+/// Checks whether the buffer starts with the adapter format identifier
+/// </summary>
+/// <param name="bytes"></param>
+/// <param name="num_bytes"></param>
+/// <returns></returns>
+bool IsAdapterFormatModelBytes(const void* bytes, size_t num_bytes);
+
+void LoadStringFromLoraFormat(std::string& dst, const flatbuffers::String* fbs_string);
+
+/// <summary>
+/// The function loads the lora adapter bytes from the file system
+/// </summary>
+/// <param name="file_path">file path</param>
+/// <returns>bytes in a vector</returns>
+/// <exception cref="OnnxRuntimeException">If the path can not be found</exception>
+std::vector<uint8_t> LoadLoraAdapterBytes(const std::filesystem::path& file_path);
+
+/// <summary>
+/// This function memory maps the adapter file in memory
+/// </summary>
+/// <param name="file_path"></param>
+/// <returns>memory handle and file size in a tuple</returns>
+std::pair<Env::MappedMemoryPtr, size_t> MemoryMapAdapterFile(const std::filesystem::path& file_path);
+
+/// <summary>
+/// Validates underlying format and the format version
+/// </summary>
+/// <param name="bytes"></param>
+/// <returns>Adapter ptr</returns>
+const Adapter* ValidateAndGetAdapterFromBytes(gsl::span<const uint8_t> bytes);
+
+/// <summary>
+/// Serializes tensor data into flatbuffer
+/// </summary>
+/// <param name="flat_builder"></param>
+/// <param name="name">parameter name</param>
+/// <param name="data_type"></param>
+/// <param name="shape"></param>
+/// <param name="data"></param>
+/// <param name="fbs_tensor">output offset</param>
+void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name,
+                       TensorDataType data_type,
+                       gsl::span<const int64_t> shape, gsl::span<const uint8_t> data,
+                       flatbuffers::Offset<Parameter>& fbs_tensor);
+
+/// <summary>
+/// Create an OrtValue on top of the flatbuffer tensor
+/// No copying of data is done here. The caller is responsible for managing the lifetime of flatbuffer
+/// structures.
+///
+/// In this scenario, one can memory map the entire flatbuffer tensor data into OrtValue without copying.
+/// </summary>
+/// <param name="param"></param>
+/// <returns></returns>
+std::pair<std::string, OrtValue> CreateOrtValueOverLoraParameter(const Parameter& param);
+}  // namespace utils
+}  // namespace adapters
+}  // namespace onnxruntime
diff --git a/onnxruntime/lora/adapter_format_version.h b/onnxruntime/lora/adapter_format_version.h
new file mode 100644
index 0000000000000..e7cfc781d2e95
--- /dev/null
+++ b/onnxruntime/lora/adapter_format_version.h
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <algorithm>
+#include <array>
+
+namespace onnxruntime {
+namespace adapters {
+
+// The current format version for saving lora parameters in flatbuffers format.
+// Once this version is updated, the kSupportedAdapterFormatVersions in IsAdapterFormatVersionSupported
+// below will also need to be updated.
+// See onnxruntime/lora/adapter_format/README.md for more details on versioning.
+// Version 1 - history begins
+constexpr const int kAdapterFormatVersion = 1;
+
+// Check if the given lora format version is supported in this build
+inline bool IsAdapterFormatVersionSupported(const int lora_format_version) {
+  // The lora format versions we will support in this build
+  // This may contain more versions than just kAdapterFormatVersion, depending on compatibility
+  static constexpr std::array kSupportedAdapterFormatVersions{
+      kAdapterFormatVersion,
+  };
+
+  const auto it =
+      std::find(kSupportedAdapterFormatVersions.begin(), kSupportedAdapterFormatVersions.end(), lora_format_version);
+  return it != kSupportedAdapterFormatVersions.cend();
+}
+
+}  // namespace adapters
+}  // namespace onnxruntime
diff --git a/onnxruntime/python/convert_npz_to_onnx_adapter.py b/onnxruntime/python/convert_npz_to_onnx_adapter.py
new file mode 100644
index 0000000000000..94bfe69e34cf3
--- /dev/null
+++ b/onnxruntime/python/convert_npz_to_onnx_adapter.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# This script converts .npz files to .onnx_adapter files
+
+import argparse
+import os
+import sys
+
+import numpy as np
+
+import onnxruntime as ort
+
+
+def get_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--npz_file_path", type=str, required=True)
+    parser.add_argument("--output_file_path", type=str, required=True)
+    parser.add_argument("--adapter_version", type=int, required=True)
+    parser.add_argument("--model_version", type=int, required=True)
+    return parser.parse_args()
+
+
+def export_lora_parameters(
+    npz_file_path: os.PathLike, adapter_version: int, model_version: int, output_file_path: os.PathLike
+):
+    """Converts lora parameters in npz format to the onnx_adapter format."""
+    adapter_format = ort.AdapterFormat()
+    adapter_format.set_adapter_version(adapter_version)
+    adapter_format.set_model_version(model_version)
+    name_to_ort_value = {}
+    with np.load(npz_file_path) as data:
+        for name, np_arr in data.items():
+            ort_value = ort.OrtValue.ortvalue_from_numpy(np_arr)
+            name_to_ort_value[name] = ort_value
+
+    adapter_format.set_parameters(name_to_ort_value)
+    adapter_format.export_adapter(output_file_path)
+
+
+def main() -> int:
+    args = get_args()
+    export_lora_parameters(args.npz_file_path, args.adapter_version, args.model_version, args.output_file_path)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index c3cfe2c97ae95..d0304160dc68d 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -32,6 +32,52 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
         raise Exception("Unsupported device type: " + device_type)
 
 
+class AdapterFormat:
+    """
+    This class is used to create adapter files from python structures
+    """
+
+    def __init__(self, adapter=None) -> None:
+        if adapter is None:
+            self._adapter = C.AdapterFormat()
+        else:
+            self._adapter = adapter
+
+    @staticmethod
+    def read_adapter(file_path: os.PathLike) -> AdapterFormat:
+        return AdapterFormat(C.AdapterFormat.read_adapter(file_path))
+
+    def export_adapter(self, file_path: os.PathLike):
+        """
+        This function writes a file at the specified location
+        in onnxruntime adapter format containing Lora parameters.
+
+        :param file_path: absolute path for the adapter
+        """
+        self._adapter.export_adapter(file_path)
+
+    def get_format_version(self):
+        return self._adapter.format_version
+
+    def set_adapter_version(self, adapter_version: int):
+        self._adapter.adapter_version = adapter_version
+
+    def get_adapter_version(self):
+        return self._adapter.adapter_version
+
+    def set_model_version(self, model_version: int):
+        self._adapter.model_version = model_version
+
+    def get_model_version(self):
+        return self._adapter.model_version
+
+    def set_parameters(self, params: dict[str, OrtValue]):
+        self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
+
+    def get_parameters(self) -> dict[str, OrtValue]:
+        return {k: OrtValue(v) for k, v in self._adapter.parameters.items()}
+
+
 def check_and_normalize_provider_args(
     providers: Sequence[str | tuple[str, dict[Any, Any]]] | None,
     provider_options: Sequence[dict[Any, Any]] | None,
@@ -711,6 +757,20 @@ def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
             numpy_obj if device_type.lower() == "cpu" else None,
         )
 
+    @staticmethod
+    def ortvalue_from_numpy_with_onnxtype(data: Sequence[int], onnx_element_type: int):
+        """
+        This method creates an instance of OrtValue on top of the numpy array.
+        No data copy is made and the lifespan of the resulting OrtValue should never
+        exceed the lifespan of the source numpy array. The API attempts to reinterpret
+        the data type, which is expected to have the same element size. This is useful
+        when we want to use an ONNX data type that is not supported by numpy.
+
+        :param data: numpy array.
+        :param onnx_element_type: a valid onnx TensorProto::DataType enum value
+        """
+        return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnxtype(data, onnx_element_type), data)
+
     @staticmethod
     def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type="cpu", device_id=0):
         """
diff --git a/onnxruntime/python/onnxruntime_pybind_lora.cc b/onnxruntime/python/onnxruntime_pybind_lora.cc
new file mode 100644
index 0000000000000..af8365418e5ea
--- /dev/null
+++ b/onnxruntime/python/onnxruntime_pybind_lora.cc
@@ -0,0 +1,151 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
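Taken together, the Python pieces above form a small write/read API. The following is a minimal round-trip sketch, assuming the bindings build as in this PR; the file name and parameter values are illustrative only, not part of the change:

import numpy as np
import onnxruntime as ort

# Wrap a numpy array as an OrtValue and store it as an adapter parameter.
param = ort.OrtValue.ortvalue_from_numpy(np.arange(8, dtype=np.float32).reshape(2, 4))

fmt = ort.AdapterFormat()
fmt.set_adapter_version(1)
fmt.set_model_version(1)
fmt.set_parameters({"lora_param_a": param})
fmt.export_adapter("example.onnx_adapter")  # hypothetical output path

# Read it back; parameters come back as a name -> OrtValue dictionary.
restored = ort.AdapterFormat.read_adapter("example.onnx_adapter")
assert restored.get_adapter_version() == 1
print({name: v.shape() for name, v in restored.get_parameters().items()})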
+#include "python/onnxruntime_pybind_exceptions.h" +#include "python/onnxruntime_pybind_mlvalue.h" +#include "python/onnxruntime_pybind_state_common.h" + +#define NO_IMPORT_ARRAY +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API +#include "python/numpy_helper.h" + +#include "core/graph/onnx_protobuf.h" + +#include "core/framework/ort_value.h" +#include "core/framework/tensor.h" + +#include "lora/adapter_format_version.h" +#include "lora/adapter_format_utils.h" +#include "core/session/lora_adapters.h" + +#include + +#include + +namespace onnxruntime { +namespace python { + +namespace py = pybind11; +namespace { +/// +/// Class that supports writing and reading adapters +/// in innxruntime format +/// +struct PyAdapterFormatReaderWriter { + PyAdapterFormatReaderWriter() = default; + PyAdapterFormatReaderWriter(int format_version, int adapter_version, + int model_version, + lora::LoraAdapter&& loaded_adapter, + py::dict&& params) + : format_version_(format_version), + adapter_version_(adapter_version), + model_version_(model_version), + loaded_adater_(std::move(loaded_adapter)), + parameters_(std::move(params)) {} + + int format_version_{adapters::kAdapterFormatVersion}; + int adapter_version_{0}; + int model_version_{0}; + // This container is used when reading the the file so + // OrtValue objects can be backed by it. Not exposed to Python + std::optional loaded_adater_; + // This is a dictionary of string -> OrtValue + // this is populated directly on write and + // built on top of the loaded_adapter on read + py::dict parameters_; +}; + +} // namespace + +/* */ +void addAdapterFormatMethods(pybind11::module& m) { + py::class_ adapter_binding(m, "AdapterFormat"); + adapter_binding.def(py::init()) + .def_property_readonly( + "format_version", + [](const PyAdapterFormatReaderWriter* reader_writer) -> int { return reader_writer->format_version_; }, + R"pbdoc("Enables user to read format version stored in the file")pbdoc") + .def_property( + "adapter_version", + [](const PyAdapterFormatReaderWriter* reader_writer) -> int { return reader_writer->adapter_version_; }, + [](PyAdapterFormatReaderWriter* reader_writer, int adapter_version) -> void { reader_writer->adapter_version_ = adapter_version; }, + R"pbdoc("Enables user to read format version stored in the file")pbdoc") + .def_property( + "adapter_version", + [](const PyAdapterFormatReaderWriter* reader_writer) -> int { return reader_writer->adapter_version_; }, + [](PyAdapterFormatReaderWriter* reader_writer, int adapter_version) -> void { reader_writer->adapter_version_ = adapter_version; }, + R"pbdoc("Enables user to read/write adapter version stored in the file")pbdoc") + .def_property( + "model_version", + [](const PyAdapterFormatReaderWriter* reader_writer) -> int { return reader_writer->model_version_; }, + [](PyAdapterFormatReaderWriter* reader_writer, int model_version) -> void { reader_writer->model_version_ = model_version; }, + R"pbdoc("Enables user to read/write model version this adapter was created for")pbdoc") + .def_property( + "parameters", + [](const PyAdapterFormatReaderWriter* reader_writer) -> py::dict { return reader_writer->parameters_; }, + [](PyAdapterFormatReaderWriter* reader_writer, py::dict& parameters) -> void { + reader_writer->parameters_ = parameters; + }, + R"pbdoc("Enables user to read/write adapter version stored in the file")pbdoc") + .def( + "export_adapter", + [](const PyAdapterFormatReaderWriter* reader_writer, const std::wstring& path) { + 
+            std::filesystem::path file_path(path);
+            std::ofstream file(file_path, std::ios::binary);
+            if (file.fail()) {
+              ORT_THROW("Failed to open file:", file_path, " for writing.");
+            }
+
+            adapters::utils::AdapterFormatBuilder format_builder;
+            for (auto& [n, value] : reader_writer->parameters_) {
+              const std::string param_name = py::str(n);
+              const OrtValue* ort_value = value.cast<const OrtValue*>();
+              const Tensor& tensor = ort_value->Get<Tensor>();
+              const auto data_span =
+                  gsl::make_span(reinterpret_cast<const uint8_t*>(tensor.DataRaw()),
+                                 tensor.SizeInBytes());
+              format_builder.AddParameter(
+                  param_name, static_cast<adapters::TensorDataType>(tensor.GetElementType()),
+                  tensor.Shape().GetDims(), data_span);
+            }
+
+            auto format_span = format_builder.FinishWithSpan(reader_writer->adapter_version_,
+                                                             reader_writer->model_version_);
+            if (file.write(reinterpret_cast<const char*>(format_span.data()), format_span.size()).fail()) {
+              ORT_THROW("Failed to write: ", std::to_string(format_span.size()), " bytes to ", file_path);
+            }
+
+            if (file.flush().fail()) {
+              ORT_THROW("Failed to flush: ", file_path, " on close");
+            }
+          },
+          R"pbdoc("Save adapter parameters into an onnxruntime adapter file format.)pbdoc")
+
+      .def_static(
+          "read_adapter", [](const std::wstring& file_path) -> std::unique_ptr<PyAdapterFormatReaderWriter> {
+            lora::LoraAdapter lora_adapter;
+            lora_adapter.Load(file_path);
+
+            auto [begin, end] = lora_adapter.GetParamIterators();
+            py::dict params;
+            for (; begin != end; ++begin) {
+              auto& [name, param] = *begin;
+              OrtValue& ort_value = param.GetMapped();
+              params[py::str(name)] = py::cast(&ort_value);
+            }
+
+            auto py_adapter = std::make_unique<PyAdapterFormatReaderWriter>(
+                lora_adapter.FormatVersion(), lora_adapter.AdapterVersion(),
+                lora_adapter.ModelVersion(), std::move(lora_adapter), std::move(params));
+
+            return py_adapter;
+          },
+          R"pbdoc(The function returns an instance of the class that contains a dictionary of name -> OrtValue parameters)pbdoc");
+
+  py::class_<lora::LoraAdapter> lora_adapter_binding(m, "LoraAdapter");
+  lora_adapter_binding.def(py::init<>())
+      .def("Load", [](lora::LoraAdapter* adapter, const std::wstring& file_path) { adapter->MemoryMap(file_path); }, R"pbdoc(Memory map the specified file as LoraAdapter)pbdoc");
+}
+
+}  // namespace python
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index 8fdac257297c1..6ed4c42bd4304 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -87,15 +87,15 @@ static TensorShape GetArrayShape(PyArrayObject* pyObject) {
   const int ndim = PyArray_NDIM(pyObject);
   const npy_intp* npy_dims = PyArray_DIMS(pyObject);
   auto span = gsl::make_span(npy_dims, ndim);
-  std::vector<int64_t> dims(span.begin(), span.end());
-  TensorShape shape(std::move(dims));
+  TensorShapeVector shape_vec(span.begin(), span.end());
+  TensorShape shape(shape_vec);
   return shape;
 }
 
 TensorShape GetShape(const py::array& arr) {
   auto span = gsl::make_span(arr.shape(), arr.ndim());
-  std::vector<int64_t> dims(span.begin(), span.end());
-  TensorShape shape(std::move(dims));
+  TensorShapeVector shape_vec(span.begin(), span.end());
+  TensorShape shape(shape_vec);
   return shape;
 }
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index d76b9032afe73..e338634d73bd3 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -142,6 +142,28 @@ void addOrtValueMethods(pybind11::module& m) {
           throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
         }
       })
+      // Create an OrtValue on top of the numpy array, but interpret the data
+      // as a different type with the same element size.
+      .def_static("ortvalue_from_numpy_with_onnxtype", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr<OrtValue> {
+        if (!ONNX_NAMESPACE::TensorProto_DataType_IsValid(onnx_element_type)) {
+          ORT_THROW("Not a valid ONNX Tensor data type: ", onnx_element_type);
+        }
+
+        const auto element_type = DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_type)
+                                      ->GetElementType();
+
+        const auto element_size = element_type->Size();
+        if (narrow<size_t>(data.itemsize()) != element_size) {
+          ORT_THROW("Item size in the incoming array: ", data.itemsize(),
+                    " does not match the size specified by onnxtype: ", element_size);
+        }
+
+        auto cpu_allocator = GetAllocator();
+        auto ort_value = std::make_unique<OrtValue>();
+        Tensor::InitOrtValue(element_type, GetShape(data),
+                             const_cast<void*>(data.data()), cpu_allocator->Info(), *ort_value);
+        return ort_value;
+      })
       // Factory method to create an OrtValue (Tensor) from the given shape and element type with memory on the specified device
       // The memory is left uninitialized
       .def_static("ortvalue_from_shape_and_type", [](const std::vector<int64_t>& shape, py::object& element_type, const OrtDevice& device) {
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index e8bf61612c89b..5ac9c149bbe80 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -32,6 +32,8 @@
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "core/session/provider_bridge_ort.h"
 
+#include "core/session/lora_adapters.h"
+
 #ifdef ENABLE_ATEN
 #include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
 #endif
@@ -53,10 +55,6 @@
 // (This static var is referenced in GetCudaToHostMemCpyFunction())
 const OrtDevice::DeviceType OrtDevice::GPU;
 
-namespace onnxruntime {
-
-}  // namespace onnxruntime
-
 #if defined(_MSC_VER)
 #pragma warning(disable : 4267 4996 4503)
 #endif  // _MSC_VER
@@ -168,6 +166,24 @@ void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtS
   }
 }
 
+void AppendLoraParametersAsInputs(const RunOptions& run_options,
+                                  size_t total_entries,
+                                  NameMLValMap& feeds) {
+  for (const auto* adapter : run_options.active_adapters) {
+    total_entries += adapter->GetParamNum();
+  }
+  feeds.reserve(total_entries + feeds.size());
+
+  // Append necessary inputs for active adapters
+  for (const auto* adapter : run_options.active_adapters) {
+    auto [begin, end] = adapter->GetParamIterators();
+    for (; begin != end; ++begin) {
+      const auto& [name, param] = *begin;
+      feeds.insert(std::make_pair(name, param.GetMapped()));
+    }
+  }
+}
+
 template <typename T>
 static py::object AddNonTensor(const OrtValue& val,
                                const DataTransferManager* /*data_transfer_manager*/,
@@ -1901,7 +1917,12 @@ RunOptions instance. The individual calls will exit gracefully and return an err
 
           return value;
         },
-        R"pbdoc(Get a single run configuration value using the given configuration key.)pbdoc");
+        R"pbdoc(Get a single run configuration value using the given configuration key.)pbdoc")
+      .def(
+          "add_active_adapter", [](RunOptions* options, lora::LoraAdapter* adapter) {
+            options->active_adapters.push_back(adapter);
+          },
+          R"pbdoc(Adds the specified adapter as an active adapter)pbdoc");
 
   py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model used to run the prediction and @@ -2022,7 +2043,12 @@ including arg name, arg type (contains both type and shape).)pbdoc") const std::map& pyfeeds, RunOptions* run_options = nullptr) -> py::list { NameMLValMap feeds; - feeds.reserve(pyfeeds.size()); + if (run_options != nullptr && !run_options->active_adapters.empty()) { + AppendLoraParametersAsInputs(*run_options, pyfeeds.size(), feeds); + } else { + feeds.reserve(pyfeeds.size()); + } + for (const auto& feed : pyfeeds) { // No need to process 'None's sent in by the user // to feed Optional inputs in the graph. @@ -2036,7 +2062,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") } CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value); ThrowIfPyErrOccured(); - feeds.insert(std::make_pair(feed.first, ml_value)); + feeds.insert(std::make_pair(feed.first, std::move(ml_value))); } } @@ -2079,6 +2105,11 @@ including arg name, arg type (contains both type and shape).)pbdoc") PyCallback callback, py::object user_data = {}, RunOptions* run_options = nullptr) -> void { + if (run_options != nullptr && !run_options->active_adapters.empty()) { + LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING) + << "run_async has active adapters specified, but won't have an effect"; + } + std::unique_ptr async_resource = std::make_unique(); async_resource->callback = callback; async_resource->user_data = user_data; @@ -2124,7 +2155,12 @@ including arg name, arg type (contains both type and shape).)pbdoc") /// a Tensor, SparseTensor or a TensorSequence. .def("run_with_ort_values", [](PyInferenceSession* sess, const py::dict& feeds, const std::vector& output_names, RunOptions* run_options = nullptr) -> std::vector { NameMLValMap ort_feeds; - ort_feeds.reserve(feeds.size()); + if (run_options != nullptr && !run_options->active_adapters.empty()) { + AppendLoraParametersAsInputs(*run_options, feeds.size(), ort_feeds); + } else { + ort_feeds.reserve(feeds.size()); + } + // item is always a copy since dict returns a value and not a ref // and Apple XToolChain barks for (const auto& item : feeds) { @@ -2147,6 +2183,11 @@ including arg name, arg type (contains both type and shape).)pbdoc") return fetches; }) .def("run_with_ortvaluevector", [](PyInferenceSession* sess, RunOptions run_options, const std::vector& feed_names, const std::vector& feeds, const std::vector& fetch_names, std::vector& fetches, const std::vector& fetch_devices) -> void { + if (!run_options.active_adapters.empty()) { + LOGS(*sess->GetSessionHandle()->GetLogger(), WARNING) + << "run_with_ortvaluevector has active adapters specified, but won't have an effect"; + } + // release GIL to allow multiple python threads to invoke Run() in parallel. 
py::gil_scoped_release release; OrtPybindThrowIfError(sess->GetSessionHandle()->Run(run_options, feed_names, feeds, fetch_names, &fetches, &fetch_devices)); @@ -2261,6 +2302,7 @@ bool CreateInferencePybindStateModule(py::module& m) { addOrtValueMethods(m); addSparseTensorMethods(m); addIoBindingMethods(m); + addAdapterFormatMethods(m); #if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD) if (!InitProvidersSharedLibrary()) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.h b/onnxruntime/python/onnxruntime_pybind_state.h index 47cde0d4cf193..fc9ef83d7a0d3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.h +++ b/onnxruntime/python/onnxruntime_pybind_state.h @@ -9,6 +9,7 @@ namespace python { void addGlobalMethods(py::module& m, Environment& env); void addObjectMethods(py::module& m, Environment& env); void addOrtValueMethods(pybind11::module& m); +void AddLoraMethods(pybind11::module& m); } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 08e5e4f7b18fa..225931533615d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -395,6 +395,8 @@ void addIoBindingMethods(pybind11::module& m); void addSparseTensorMethods(pybind11::module& m); +void addAdapterFormatMethods(pybind11::module& m); + void addGlobalSchemaFunctions(pybind11::module& m); void addOpKernelSubmodule(pybind11::module& m); diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc new file mode 100644 index 0000000000000..e8291a36447ca --- /dev/null +++ b/onnxruntime/test/lora/lora_test.cc @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
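Because AppendLoraParametersAsInputs simply appends the adapter's tensors to the feeds, activating an adapter should be observationally equivalent to feeding the same parameters by name. A sketch of that equivalence against the two_params_lora_model test data follows; paths are relative to the test data directory and this snippet is illustrative, not part of the PR:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("testdata/lora/two_params_lora_model.onnx")
base_feed = {"input_x": np.ones((4, 4), dtype=np.float32)}

# Feed the LoRA parameters explicitly as graph inputs...
explicit = session.run(None, {
    **base_feed,
    "lora_param_a": np.array([3, 4, 5, 6], dtype=np.float32).reshape(4, 1),
    "lora_param_b": np.array([7, 8, 9, 10], dtype=np.float32).reshape(1, 4),
})

# ...or activate an adapter carrying the same values on RunOptions.
adapter = ort.LoraAdapter()
adapter.Load("testdata/lora/two_params_lora_model.onnx_adapter")
run_options = ort.RunOptions()
run_options.add_active_adapter(adapter)
via_adapter = session.run(None, base_feed, run_options)

np.testing.assert_allclose(explicit[0], via_adapter[0])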
+ +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/data_types_internal.h" +#include "core/framework/to_tensor_proto_element_type.h" + +#include "test/util/include/default_providers.h" + +#include "core/session/lora_adapters.h" +#include "lora/adapter_format_version.h" +#include "lora/adapter_format_utils.h" + +#include "gtest/gtest.h" +#include + +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +namespace { + +constexpr const int kAdapterVersion = 1; +constexpr const int kModelVersion = 1; + +template +struct ReadAndValidateData { + void operator()(const Tensor& parameter) const { + auto data = parameter.DataAsSpan(); + for (size_t i = static_cast(data[0]), size = data.size(); i < size; ++i) { + ASSERT_EQ(static_cast(i), data[i]); + } + } +}; + +template <> +struct ReadAndValidateData { + void operator()(const Tensor& parameter) const { + auto data = parameter.DataAsSpan(); + for (size_t i = static_cast(data[0]), size = data.size(); i < size; ++i) { + ASSERT_FALSE(std::isnan(data[i])); + ASSERT_TRUE(std::isfinite(data[i])); + ASSERT_EQ(static_cast(i), data[i]); + } + } +}; + +template <> +struct ReadAndValidateData { + void operator()(const Tensor& parameter) const { + auto data = parameter.DataAsSpan(); + for (size_t i = static_cast(data[0]), size = data.size(); i < size; ++i) { + ASSERT_FALSE(std::isnan(data[i])); + ASSERT_TRUE(std::isfinite(data[i])); + ASSERT_EQ(static_cast(i), data[i]); + } + } +}; + +template <> +struct ReadAndValidateData { + void operator()(const Tensor& parameter) const { + auto data = parameter.DataAsSpan(); + for (size_t i = static_cast(data[0].ToFloat()), size = data.size(); i < size; ++i) { + ASSERT_FALSE(data[i].IsNaN()); + ASSERT_FALSE(data[i].IsInfinity()); + ASSERT_EQ(static_cast(i), data[i].ToFloat()); + } + } +}; + +template <> +struct ReadAndValidateData { + void operator()(const Tensor& parameter) const { + auto data = parameter.DataAsSpan(); + for (size_t i = static_cast(data[0].ToFloat()), size = data.size(); i < size; ++i) { + ASSERT_FALSE(data[i].IsNaN()); + ASSERT_FALSE(data[i].IsInfinity()); + ASSERT_EQ(static_cast(i), data[i].ToFloat()); + } + } +}; + +auto verify_load = [](const lora::LoraAdapter& adapter) { + ASSERT_EQ(kAdapterVersion, adapter.AdapterVersion()); + ASSERT_EQ(kModelVersion, adapter.ModelVersion()); + + const auto param_num = adapter.GetParamNum(); + ASSERT_EQ(param_num, 2U); + + InlinedVector names; + InlinedVector ort_values; + names.reserve(param_num); + ort_values.reserve(param_num); + + adapter.OutputAdapterParameters(std::back_inserter(names), std::back_inserter(ort_values)); + ASSERT_EQ(param_num, names.size()); + ASSERT_EQ(param_num, ort_values.size()); + + for (size_t i = 0; i < param_num; ++i) { + const auto& name = names[i]; + const auto* ort_value = ort_values[i]; + ASSERT_TRUE(name != nullptr); + ASSERT_TRUE(ort_value->IsTensor()); + + const auto& tensor = ort_value->Get(); + ASSERT_NE(tensor.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + + const auto shape = tensor.Shape().GetDims(); + ASSERT_EQ(2U, shape.size()); + ASSERT_EQ(8, shape[0]); + ASSERT_EQ(4, shape[1]); + + // Read all the elements to make sure they are accessible + // only on CPU + const auto& mem_info = tensor.Location(); + if (mem_info.device.Type() == OrtDevice::CPU) { + utils::MLTypeCallDispatcher + disp(tensor.GetElementType()); + disp.Invoke(tensor); + } + } +}; + +constexpr const std::array param_shape = {8, 4}; + +template +struct CreateParam { + InlinedVector 
operator()() const { + InlinedVector param(32); + std::iota(param.begin(), param.end(), T{0}); + return param; + } +}; + +template +struct GenerateTestParameters { + std::vector operator()() const { + constexpr const auto data_type = utils::ToTensorProtoElementType(); + + InlinedVector param_1(32); + InlinedVector param_2(32); + if constexpr (std::is_same::value || std::is_same::value) { + for (float f = 0.f; f < 32; ++f) { + param_1[static_cast(f)] = static_cast(f); + param_2[static_cast(f)] = static_cast(f + 32); + } + } else { + std::iota(param_1.begin(), param_1.end(), T{0}); + std::iota(param_2.begin(), param_2.end(), T{32}); + } + + adapters::utils::AdapterFormatBuilder adapter_builder; + adapter_builder.AddParameter("param_1", static_cast(data_type), + param_shape, ReinterpretAsSpan(gsl::make_span(param_1))); + adapter_builder.AddParameter("param_2", static_cast(data_type), + param_shape, ReinterpretAsSpan(gsl::make_span(param_2))); + + return adapter_builder.Finish(kAdapterVersion, kModelVersion); + } +}; + +template +struct TestDataType { + void operator()() const { + const auto test_params = GenerateTestParameters()(); + lora::LoraAdapter lora_adapter; + lora_adapter.Load(std::move(test_params)); + verify_load(lora_adapter); + } +}; + +} // namespace + +TEST(LoraAdapterTest, Load) { + // Test different data types + const auto data_types = gsl::make_span(adapters::EnumValuesTensorDataType()); + for (size_t i = 1, size = data_types.size(); i < size; ++i) { + const auto dt = data_types[i]; + + using namespace adapters; + if (dt == TensorDataType::STRING || + dt == TensorDataType::BOOL || + dt == TensorDataType::COMPLEX64 || + dt == TensorDataType::COMPLEX128 || + static_cast(dt) >= static_cast(TensorDataType::BFLOAT16)) + continue; + + onnxruntime::utils::MLTypeCallDispatcher + disp(static_cast(data_types[i])); + disp.Invoke(); + } +} + +#ifdef USE_CUDA +TEST(LoraAdapterTest, VerifyDeviceCopy) { + auto cpu_ep = DefaultCpuExecutionProvider(); + auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; + auto cuda_ep = DefaultCudaExecutionProvider(); + auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0]; + + auto gpu_transfer = cuda_ep->GetDataTransfer(); + + auto test_params = GenerateTestParameters()(); + lora::LoraAdapter adapter(std::move(cuda_allocator)); + adapter.Load(std::move(test_params)); + + auto [begin, end] = adapter.GetParamIterators(); + for (; begin != end; ++begin) { + const auto& [_, param] = *begin; + const auto& tensor_device = param.GetDeviceOrMapped().Get(); + ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::CUDA)); + + const auto& tensor_cpu = param.GetMapped().Get(); + ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); + + Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); + ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device, + copy.Location().device)); + ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy)); + + auto expected_span = tensor_cpu.DataAsSpan(); + auto copy_span = copy.DataAsSpan(); + + ASSERT_EQ(expected_span, copy_span); + } +} +#endif +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 24151932a6681..9419761340517 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1389,6 +1389,13 @@ def test_session_with_ortvalue_input(ortvalue): # The constructed OrtValue should 
still be valid after being used in a session
         self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input))
 
+        # test ort_value creation on top of the bytes
+        float_tensor_data_type = 1  # TensorProto_DataType_FLOAT
+        ort_value_with_type = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(numpy_arr_input, float_tensor_data_type)
+        self.assertTrue(ort_value_with_type.is_tensor())
+        self.assertEqual(float_tensor_data_type, ort_value_with_type.element_type())
+        self.assertEqual([3, 2], ort_value_with_type.shape())
+
         if "CUDAExecutionProvider" in onnxrt.get_available_providers():
             ortvalue2 = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input, "cuda", 0)
             self.assertEqual(ortvalue2.device_name(), "cuda")
@@ -1824,6 +1831,91 @@ def test_multiple_devices(self):
             device1_session.run(output_names=["Plus214_Output_0"], input_feed=image)
             device0_session.run(output_names=["Plus214_Output_0"], input_feed=image)
 
+    def test_adapter_export_read(self):
+        adapter_version = 1
+        model_version = 1
+        file_path = pathlib.Path(os.path.realpath(__file__)).parent
+        file_path = str(file_path / "test_adapter.onnx_adapter")
+
+        float_data_type = 1
+        int64_data_type = 7
+        val = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        param_1 = np.array(val).astype(np.float32).reshape(5, 2)
+        param_2 = np.array(val).astype(np.int64).reshape(2, 5)
+
+        ort_val_1 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_1, float_data_type)
+        ort_val_2 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_2, int64_data_type)
+
+        params = {"param_1": ort_val_1, "param_2": ort_val_2}
+
+        adapter_format = onnxrt.AdapterFormat()
+        adapter_format.set_adapter_version(adapter_version)
+        adapter_format.set_model_version(model_version)
+        adapter_format.set_parameters(params)
+
+        adapter_format.export_adapter(file_path)
+
+        adapter_format_read = onnxrt.AdapterFormat.read_adapter(file_path)
+        os.remove(file_path)
+
+        self.assertEqual(adapter_version, adapter_format_read.get_adapter_version())
+        self.assertEqual(model_version, adapter_format_read.get_model_version())
+
+        actual_params = adapter_format_read.get_parameters()
+        self.assertCountEqual(params, actual_params)
+        for key, value in actual_params.items():
+            self.assertIn(key, params)
+            expected_val = params.get(key)
+            self.assertTrue(value.is_tensor())
+            self.assertEqual(expected_val.element_type(), value.element_type())
+            self.assertEqual(expected_val.shape(), value.shape())
+            np.testing.assert_allclose(expected_val.numpy(), value.numpy())
+
+    def test_run_with_adapter(self):
+        model_path = get_name("lora/two_params_lora_model.onnx")
+        file_path = os.getcwd() + "/" + get_name("lora/two_params_lora_model.onnx_adapter")
+        adapter_path = os.path.abspath(file_path)
+
+        expected_output = np.array(
+            [
+                [154.0, 176.0, 198.0, 220.0],
+                [154.0, 176.0, 198.0, 220.0],
+                [154.0, 176.0, 198.0, 220.0],
+                [154.0, 176.0, 198.0, 220.0],
+            ],
+            dtype=np.float32,
+        )
+
+        adapter = onnxrt.LoraAdapter()
+        adapter.Load(adapter_path)
+
+        run_options = onnxrt.RunOptions()
+        run_options.add_active_adapter(adapter)
+        session = onnxrt.InferenceSession(model_path)
+
+        inputs = {"input_x": np.ones((4, 4), dtype=np.float32)}
+
+        outputs = session.run(None, inputs, run_options)
+        self.assertEqual(len(outputs), 1)
+        self.assertTrue(np.allclose(outputs[0], expected_output))
+
+    def test_run_base_model(self):
+        model_path = get_name("lora/two_params_lora_model.onnx")
+
+        expected_output = np.array(
+            [[28.0, 32.0, 36.0, 40.0], [28.0, 32.0, 36.0, 40.0], [28.0, 32.0, 36.0, 40.0], [28.0, 32.0, 36.0, 40.0]],
+            dtype=np.float32,
+        )
+
+        run_options = onnxrt.RunOptions()
+        session = onnxrt.InferenceSession(model_path)
+
+        inputs = {"input_x": np.ones((4, 4), dtype=np.float32)}
+
+        outputs = session.run(None, inputs, run_options)
+        self.assertEqual(len(outputs), 1)
+        self.assertTrue(np.allclose(outputs[0], expected_output))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=1)
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 7a33bf8a527cd..782992e90bd39 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -4402,6 +4402,120 @@ TEST(CApiTest, RunAsyncFail) {
   EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
 }
 
+static void TestRunWithLoraAdapter(const Ort::LoraAdapter& adapter) {
+  constexpr const ORTCHAR_T* model_path = TSTR("testdata/lora/two_params_lora_model.onnx");
+
+  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+  Ort::RunOptions run_options;
+  run_options.AddActiveLoraAdapter(adapter);
+
+  // Single input
+  constexpr const std::array<int64_t, 2> input_shape = {4, 4};
+  std::vector<float> input_x(16);
+  std::fill(input_x.begin(), input_x.end(), 1.0f);
+  constexpr const char* input_names[] = {"input_x"};
+  constexpr const char* output_names[] = {"output"};
+
+  auto cpu_meminfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+  auto input_x_val = Ort::Value::CreateTensor(
+      cpu_meminfo, input_x.data(), input_x.size(), input_shape.data(), input_shape.size());
+
+  Ort::Value inputs[] = {std::move(input_x_val)};
+
+  Ort::SessionOptions default_session_options;
+
+  constexpr const std::array<float, 16> expected_output = {
+      154.f, 176.f, 198.f, 220.f,
+      154.f, 176.f, 198.f, 220.f,
+      154.f, 176.f, 198.f, 220.f,
+      154.f, 176.f, 198.f, 220.f};
+
+  Ort::Session session(env, model_path, default_session_options);
+
+  auto outputs = session.Run(run_options, input_names, inputs, std::size(input_names), output_names, std::size(output_names));
+  ASSERT_EQ(1U, outputs.size());
+
+  auto tensor_type_shape = outputs[0].GetTensorTypeAndShapeInfo();
+  const auto elements = tensor_type_shape.GetElementCount();
+  ASSERT_EQ(expected_output.size(), elements);
+  const float* data = outputs[0].GetTensorData<float>();
+  for (size_t i = 0; i < elements; ++i) {
+    EXPECT_NEAR(expected_output[i], data[i], 0.06);
+  }
+}
+
+static Ort::LoraAdapter CreateAdapterFromFile() {
+  constexpr const ORTCHAR_T* adapter_path = TSTR("testdata/lora/two_params_lora_model.onnx_adapter");
+  return Ort::LoraAdapter::CreateLoraAdapter(adapter_path, nullptr);
+}
+
+static Ort::LoraAdapter CreateAdapterFromArray() {
+  constexpr const ORTCHAR_T* adapter_path = TSTR("testdata/lora/two_params_lora_model.onnx_adapter");
+  std::ifstream adapter_file(adapter_path, std::ios::binary);
+
+  EXPECT_TRUE(adapter_file.is_open());
+  adapter_file.seekg(0, std::ios::end);
+  const size_t adapter_size = adapter_file.tellg();
+
+  std::vector<uint8_t> buffer(adapter_size);
+  adapter_file.seekg(0, std::ios::beg);
+  adapter_file.read(reinterpret_cast<char*>(buffer.data()), adapter_size);
+  adapter_file.close();
+
+  return Ort::LoraAdapter::CreateLoraAdapterFromArray(buffer.data(), buffer.size(), nullptr);
+}
+
+TEST(CApiTest, RunWithLoraAdapterFromFile) {
+  auto adapter = CreateAdapterFromFile();
+  TestRunWithLoraAdapter(adapter);
+}
+
+TEST(CApiTest, RunWithLoraAdapterFromArray) {
+  auto adapter = CreateAdapterFromArray();
+  TestRunWithLoraAdapter(adapter);
+}
+
+TEST(CApiTest, RunBaseLoraModel) {
+  constexpr const ORTCHAR_T* model_path = TSTR("testdata/lora/two_params_lora_model.onnx");
+  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+  constexpr const std::array<int64_t, 2> input_shape = {4, 4};
+  std::vector<float> input_x(16);
+  std::fill(input_x.begin(), input_x.end(), 1.0f);
+  constexpr const char* input_names[] = {"input_x"};
+  constexpr const char* output_names[] = {"output"};
+
+  auto cpu_meminfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+  auto input_x_val = Ort::Value::CreateTensor(
+      cpu_meminfo, input_x.data(), input_x.size(), input_shape.data(), input_shape.size());
+
+  Ort::Value inputs[] = {std::move(input_x_val)};
+
+  Ort::SessionOptions default_session_options;
+
+  constexpr const std::array<float, 16> expected_output = {
+      28.f, 32.f, 36.f, 40.f,
+      28.f, 32.f, 36.f, 40.f,
+      28.f, 32.f, 36.f, 40.f,
+      28.f, 32.f, 36.f, 40.f};
+
+  Ort::Session session(env, model_path, default_session_options);
+
+  Ort::RunOptions run_options;
+  auto outputs = session.Run(run_options, input_names, inputs, std::size(input_names), output_names, std::size(output_names));
+  ASSERT_EQ(1U, outputs.size());
+
+  auto tensor_type_shape = outputs[0].GetTensorTypeAndShapeInfo();
+  const auto elements = tensor_type_shape.GetElementCount();
+  ASSERT_EQ(expected_output.size(), elements);
+  const float* data = outputs[0].GetTensorData<float>();
+  for (size_t i = 0; i < elements; ++i) {
+    EXPECT_NEAR(expected_output[i], data[i], 0.06);
+  }
+}
+
 struct MockGQA : public OrtCustomOp {
   MockGQA() {
     OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {
diff --git a/onnxruntime/test/testdata/lora/two_params_lora_model.onnx b/onnxruntime/test/testdata/lora/two_params_lora_model.onnx
new file mode 100644
index 0000000000000..66e316a8b71dc
Binary files /dev/null and b/onnxruntime/test/testdata/lora/two_params_lora_model.onnx differ
diff --git a/onnxruntime/test/testdata/lora/two_params_lora_model.onnx_adapter b/onnxruntime/test/testdata/lora/two_params_lora_model.onnx_adapter
new file mode 100644
index 0000000000000..369471f0d74ee
Binary files /dev/null and b/onnxruntime/test/testdata/lora/two_params_lora_model.onnx_adapter differ
diff --git a/onnxruntime/test/testdata/lora/two_params_lora_model.py b/onnxruntime/test/testdata/lora/two_params_lora_model.py
new file mode 100644
index 0000000000000..12706ad71e82e
--- /dev/null
+++ b/onnxruntime/test/testdata/lora/two_params_lora_model.py
@@ -0,0 +1,156 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+import onnx
+
+import onnxruntime as ort
+
+model_path = "two_params_lora_model.onnx"
+adapter_path = "two_params_lora_model.onnx_adapter"
+
+
+def create_model(model_path: os.PathLike):
+    #### Inputs
+    # original input_x and its associated weight
+    input_x = onnx.helper.make_tensor_value_info("input_x", onnx.TensorProto.FLOAT, [4, 4])
+
+    # Inputs overriding default Lora initializers
+    lora_param_a_input = onnx.helper.make_tensor_value_info("lora_param_a", onnx.TensorProto.FLOAT, [4, "dim"])
+    lora_param_b_input = onnx.helper.make_tensor_value_info("lora_param_b", onnx.TensorProto.FLOAT, ["dim", 4])
+
+    ### Outputs
+    output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [4, 4])
+
+    #### Initializers
+    # Base weight tensor proto
+    weight_x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]).reshape(4, 4).astype(np.float32)
+    weight_x_tensor = onnx.helper.make_tensor("weight_x", onnx.TensorProto.FLOAT, [4, 4], weight_x.flatten())
+
+    # tensor proto for default lora parameter A
+    lora_weight_a = np.zeros([4, 0], dtype=np.float32)
+    lora_weight_a_tensor = onnx.helper.make_tensor(
+        "lora_param_a", onnx.TensorProto.FLOAT, [4, 0], lora_weight_a.flatten()
+    )
+
+    # tensor proto for default lora parameter B
+    lora_weight_b = np.zeros([0, 4], dtype=np.float32)
+    lora_weight_b_tensor = onnx.helper.make_tensor(
+        "lora_param_b", onnx.TensorProto.FLOAT, [0, 4], lora_weight_b.flatten()
+    )
+
+    ##### Linear nodes
+    # Create matmul for the base case
+    matmul_x = onnx.helper.make_node("MatMul", ["input_x", "weight_x"], ["mm_output_x"])
+    # Create matmul node for lora_param_a
+    matmul_a = onnx.helper.make_node("MatMul", ["input_x", "lora_param_a"], ["mm_output_a"])
+    # Create matmul for lora_param_b
+    matmul_b = onnx.helper.make_node("MatMul", ["mm_output_a", "lora_param_b"], ["mm_output_b"])
+
+    # Create Add
+    add_node = onnx.helper.make_node("Add", ["mm_output_x", "mm_output_b"], ["output"])
+
+    graph = onnx.helper.make_graph(
+        name="two_params_lora_model",
+        nodes=[matmul_x, matmul_a, matmul_b, add_node],
+        inputs=[input_x, lora_param_a_input, lora_param_b_input],
+        outputs=[output],
+        initializer=[weight_x_tensor, lora_weight_a_tensor, lora_weight_b_tensor],
+    )
+
+    # create a model
+    model = onnx.helper.make_model(graph)
+
+    # onnx.checker.check_model(model, full_check=True)
+
+    onnx.save_model(model, model_path)
+
+
+def create_adapter(adapter_path: os.PathLike):
+    """
+    Creates a test adapter for the model above
+    """
+    param_a = np.array([3, 4, 5, 6]).astype(np.float32).reshape(4, 1)
+    param_b = np.array([7, 8, 9, 10]).astype(np.float32).reshape(1, 4)
+    ort_value_a = ort.OrtValue.ortvalue_from_numpy(param_a)
+    ort_value_b = ort.OrtValue.ortvalue_from_numpy(param_b)
+
+    numpy_a = ort_value_a.numpy()
+    numpy_b = ort_value_b.numpy()
+    assert np.allclose(param_a, numpy_a)
+    assert np.allclose(param_b, numpy_b)
+
+    print(param_a)
+    print(param_b)
+
+    name_to_value = {"lora_param_a": ort_value_a, "lora_param_b": ort_value_b}
+
+    adapter_format = ort.AdapterFormat()
+    adapter_format.set_adapter_version(1)
+    adapter_format.set_model_version(1)
+    adapter_format.set_parameters(name_to_value)
+    adapter_format.export_adapter(adapter_path)
+
+
+def read_adapter(adapter_path: os.PathLike):
+    adapter = ort.AdapterFormat.read_adapter(adapter_path)
+    params = adapter.get_parameters()
+
+    assert "lora_param_a" in params
+    assert "lora_param_b" in params
+
+    numpy_a = params["lora_param_a"].numpy()
+    print(numpy_a)
+
+    numpy_b = params["lora_param_b"].numpy()
+    print(numpy_b)
+
+
+def run_base_model(model_path: os.PathLike):
+    session = ort.InferenceSession(model_path)
+
+    # Run the base case
+    inputs = {"input_x": np.ones((4, 4), dtype=np.float32)}
+
+    outputs = session.run(None, inputs)
+    print(outputs)
+
+
+def run_with_override(model_path: os.PathLike):
+    session = ort.InferenceSession(model_path)
+
+    inputs = {
+        "input_x": np.ones((4, 4), dtype=np.float32),
+        "lora_param_a": np.array([3, 4, 5, 6]).astype(np.float32).reshape(4, 1),
+        "lora_param_b": np.array([7, 8, 9, 10]).astype(np.float32).reshape(1, 4),
+    }
+
+    outputs = session.run(None, inputs)
+    print(outputs)
+
+
+def run_with_adapter(model_path: os.PathLike, adapter_path: os.PathLike):
+    adapter = ort.LoraAdapter()
+    adapter.Load(adapter_path)
+
+    run_options = ort.RunOptions()
+    run_options.add_active_adapter(adapter)
+
+    session = ort.InferenceSession(model_path)
+
+    inputs = {"input_x": np.ones((4, 4), dtype=np.float32)}
+
+    outputs = session.run(None, inputs, run_options)
+
+    print(outputs)
+
+
+if __name__ == "__main__":
+    # create_model(model_path)
+    # run_base_model(model_path)
+    run_with_override(model_path)
+    # create_adapter(adapter_path)
+    # read_adapter(adapter_path)
+    run_with_adapter(model_path, adapter_path)
diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc
index b0d1ed50af126..7ec924b6d9bb4 100644
--- a/orttraining/orttraining/python/orttraining_python_module.cc
+++ b/orttraining/orttraining/python/orttraining_python_module.cc
@@ -317,6 +317,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
   addOrtValueMethods(m);
   addSparseTensorMethods(m);
   addIoBindingMethods(m);
+  addAdapterFormatMethods(m);
 
 #if !defined(__APPLE__) && \
     (!defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS))