diff --git a/.github/workflows/linux-cpu-arm64-build.yml b/.github/workflows/linux-cpu-arm64-build.yml index aeac8a77b..26b749c5e 100644 --- a/.github/workflows/linux-cpu-arm64-build.yml +++ b/.github/workflows/linux-cpu-arm64-build.yml @@ -95,4 +95,4 @@ jobs: run: | docker run --rm \ --volume $GITHUB_WORKSPACE:/onnxruntime_src \ - -w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "/onnxruntime_src/build/cpu/test/unit_tests" + -w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/onnxruntime_src/build/cpu/ /onnxruntime_src/build/cpu/test/unit_tests" diff --git a/CMakeLists.txt b/CMakeLists.txt index a33e0ad2c..0c816a9c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,7 @@ target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR}) # we keep the shared libraries disconnected on Android as they will come from separate AARs and we don't want to force # the ORT version to match in both. -if(NOT ANDROID) +if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux")) target_link_libraries(onnxruntime-genai PRIVATE ${ONNXRUNTIME_LIB}) endif() diff --git a/src/logging.cpp b/src/logging.cpp index 38ea26877..e92afee82 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -5,6 +5,7 @@ #include "json.h" #include #include +#include namespace Generators { @@ -87,4 +88,19 @@ std::ostream& Log(std::string_view label, std::string_view string) { return *gp_stream; } -} // namespace Generators \ No newline at end of file +std::ostream& Log(std::string_view label, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + va_list args_copy; + va_copy(args_copy, args); + size_t len = vsnprintf(0, 0, fmt, args_copy); + if (len <= 0) { + throw std::runtime_error("Invalid format"); + } + std::unique_ptr buf(new char[len + 1]); + vsnprintf(buf.get(), len + 1, fmt, args); + va_end(args); + return Log(label, std::string(buf.get(), buf.get() + len)); +} + +} // namespace Generators diff --git a/src/logging.h b/src/logging.h index 929451f88..99dc8c3d4 100644 --- a/src/logging.h +++ b/src/logging.h @@ -76,4 +76,5 @@ std::ostream& operator<<(std::ostream& stream, SGR sgr_code); std::ostream& Log(std::string_view label, std::string_view text = {}); -} // namespace Generators \ No newline at end of file +std::ostream& Log(std::string_view label, const char* fmt, ...); +} // namespace Generators diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h index 0ce2acc0a..eebfe0278 100644 --- a/src/models/onnxruntime_api.h +++ b/src/models/onnxruntime_api.h @@ -68,13 +68,37 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o #include #include #include +#include #include "onnxruntime_c_api.h" #include "../span.h" +#include "../logging.h" #if defined(__ANDROID__) #include #include + +#define TAG "GenAI" + +#define LOG_DEBUG(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) +#define LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) +#define LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) +#define LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) +#define LOG_FATAL(...) __android_log_print(ANDROID_LOG_FATAL, TAG, __VA_ARGS__) + +#elif defined(__linux__) +#include + +#ifndef PATH_MAX +#define PATH_MAX (4096) +#endif + +#define LOG_DEBUG(...) Generators::Log("debug", __VA_ARGS__) +#define LOG_INFO(...) Generators::Log("info", __VA_ARGS__) +#define LOG_WARN(...) Generators::Log("warning", __VA_ARGS__) +#define LOG_ERROR(...) Generators::Log("error", __VA_ARGS__) +#define LOG_FATAL(...) Generators::Log("fatal", __VA_ARGS__) + #endif /** \brief Free functions and a few helpers are defined inside this namespace. Otherwise all types are the C API types @@ -82,10 +106,76 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o */ namespace Ort { +using OrtApiBaseFn = const OrtApiBase* (*)(void); + /// Before using this C++ wrapper API, you MUST call Ort::InitApi to set the below 'api' variable inline const OrtApi* api{}; + +#if defined(__linux__) +inline void* LoadDynamicLibraryIfExists(const std::string& path) { + LOG_INFO("Attempting to dlopen %s", path.c_str()); + void* ort_lib_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (ort_lib_handle == nullptr) { + return nullptr; + } + +#if !defined(__ANDROID__) // RTLD_DI_ORIGIN not available on Android + char pathname[PATH_MAX]; + dlinfo((void*)ort_lib_handle, RTLD_DI_ORIGIN, &pathname); + LOG_INFO("Loaded native library at %s", pathname); +#endif + return ort_lib_handle; +} + +inline std::string GetCurrentModuleDir() { + Dl_info dl_info; + dladdr((void*)GetCurrentModuleDir, &dl_info); + std::string module_name(dl_info.dli_fname); + std::string module_directory{}; + + const size_t last_slash_idx = module_name.rfind('/'); + if (std::string::npos != last_slash_idx) { + module_directory = module_name.substr(0, last_slash_idx); + } + return module_directory; +} + +inline void InitApiWithDynamicFn(OrtApiBaseFn ort_api_base_fn) { + if (ort_api_base_fn == nullptr) { + throw std::runtime_error("OrtGetApiBase not found"); + } + + const OrtApiBase* ort_api_base = ort_api_base_fn(); + if (ort_api_base == nullptr) { + throw std::runtime_error("OrtGetApiBase() returned nullptr"); + } + + // loop from the ORT version GenAI was built with, down to the minimum ORT version we require. + // as long as the libonnxruntime.so we loaded supports one of those we're good. + constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that + for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) { + api = ort_api_base->GetApi(i); + if (api) { + LOG_INFO("ORT API Version %d was found.", i); + break; + } + } + + if (!api) { + LOG_WARN("The loaded library did not have an ORT API version between %d and %d.", + ORT_API_VERSION, genai_min_ort_api_version); + throw std::runtime_error("Failed to load onnxruntime. Please make sure you installed the correct version"); + } +} +#endif + inline void InitApi() { -#if defined(__ANDROID__) + if (api) { + // api was already set. + return; + } + +#if defined(__linux__) // If the GenAI library links against the onnxruntime library, it will have a dependency on a specific // version of OrtGetApiBase. // @@ -108,49 +198,56 @@ inline void InitApi() { // const std::string path = "libonnxruntime.so"; // "libonnxruntime4j_jni.so" is also an option if we have issues - __android_log_print(ANDROID_LOG_INFO, "GenAI", "Attempting to dlopen %s native library", path.c_str()); + void* ort_lib_handle = LoadDynamicLibraryIfExists(path); - using OrtApiBaseFn = const OrtApiBase* (*)(void); - OrtApiBaseFn ort_api_base_fn = nullptr; - - void* ort_lib_handle = dlopen(path.c_str(), RTLD_LOCAL); +#if !defined(__ANDROID__) if (ort_lib_handle == nullptr) { - __android_log_assert("ort_lib_handle != nullptr", "GenAI", "Failed to load %s", path.c_str()); - } + const std::array target_libraries = { + std::string("libonnxruntime.so"), + std::string("libonnxruntime.so.1.18.0"), + std::string("libonnxruntime.so.1.19.0"), + std::string("libonnxruntime.so.1.20.0")}; + + // Search parent directory + std::string current_module_dir = GetCurrentModuleDir(); + for (const std::string& lib_name : target_libraries) { + std::string pip_path{current_module_dir + "/" + lib_name}; + ort_lib_handle = LoadDynamicLibraryIfExists(pip_path); + if (ort_lib_handle != nullptr) { + break; + } + } - ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase"); - if (ort_api_base_fn == nullptr) { - __android_log_assert("ort_api_base_fn != nullptr", "GenAI", "OrtGetApiBase not found"); + if (ort_lib_handle == nullptr) { + // Search for pip installation + for (const std::string& lib_name : target_libraries) { + std::string pip_path{current_module_dir + "/../onnxruntime/capi/" + lib_name}; + ort_lib_handle = LoadDynamicLibraryIfExists(pip_path); + if (ort_lib_handle != nullptr) { + break; + } + } + } } +#endif - const OrtApiBase* ort_api_base = ort_api_base_fn(); - if (ort_api_base == nullptr) { - __android_log_assert("ort_api_base != nullptr", "GenAI", "OrtGetApiBase() returned nullptr"); + if (ort_lib_handle == nullptr) { + char* err = dlerror(); + throw std::runtime_error(std::string("Failed to load ") + path.c_str() + ": " + (err != nullptr ? err : "Unknown")); } - // loop from the ORT version GenAI was built with, down to the minimum ORT version we require. - // as long as the libonnxruntime.so we loaded supports one of those we're good. - constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that - for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) { - api = ort_api_base->GetApi(i); - if (!api) { - __android_log_print(ANDROID_LOG_INFO, "GenAI", "ORT API Version %d was not found.", i); - } else { - __android_log_print(ANDROID_LOG_INFO, "GenAI", "ORT API Version %d was found", i); - break; - } + OrtApiBaseFn ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase"); + if (ort_api_base_fn == nullptr) { + char* err = dlerror(); + throw std::runtime_error(std::string("Failed to load symbol OrtGetApiBase: ") + (err != nullptr ? err : "Unknown")); } - if (!api) { - __android_log_assert("api != nullptr", "GenAI", - "%s did not have an ORT API version between %d and %d.", - path.c_str(), ORT_API_VERSION, genai_min_ort_api_version); - } -#else // defined(__ANDROID__) + InitApiWithDynamicFn(ort_api_base_fn); +#else // defined(__linux__) api = OrtGetApiBase()->GetApi(ORT_API_VERSION); if (!api) throw std::runtime_error("Onnxruntime is installed but is too old, please install a newer version"); -#endif // defined(__ANDROID__) +#endif // defined(__linux__) } /** \brief All C++ methods that can fail will throw an exception of this type diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 07dc86cf4..189f4135c 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -7,7 +7,12 @@ file(GLOB python_srcs CONFIGURE_DEPENDS pybind11_add_module(python ${python_srcs}) target_include_directories(python PRIVATE ${ORT_HEADER_DIR}) target_link_directories(python PRIVATE ${ORT_LIB_DIR}) -target_link_libraries(python PRIVATE onnxruntime-genai-static ${ONNXRUNTIME_LIB}) +target_link_libraries(python PRIVATE onnxruntime-genai-static) + +if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux")) +target_link_libraries(python PRIVATE ${ONNXRUNTIME_LIB}) +endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") set_property(TARGET python APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN/../onnxruntime/capi/") endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1d31e0f61..da3502bb4 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -23,10 +23,13 @@ target_include_directories(unit_tests PRIVATE target_link_directories(unit_tests PRIVATE ${ORT_LIB_DIR}) target_link_libraries(unit_tests PRIVATE onnxruntime-genai-static - ${ONNXRUNTIME_LIB} GTest::gtest_main ) +if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux")) +target_link_libraries(unit_tests PRIVATE ${ONNXRUNTIME_LIB}) +endif() + if(USE_CUDA AND CMAKE_CUDA_COMPILER) file(GLOB cuda_test_srcs CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cu"