Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dlopen in favor of direct dependency of ORT on Linux #724

Merged
merged 34 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linux-cpu-arm64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ jobs:
run: |
docker run --rm \
--volume $GITHUB_WORKSPACE:/onnxruntime_src \
-w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "/onnxruntime_src/build/cpu/test/unit_tests"
-w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/onnxruntime_src/build/cpu/ /onnxruntime_src/build/cpu/test/unit_tests"
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR})

# we keep the shared libraries disconnected on Android as they will come from separate AARs and we don't want to force
# the ORT version to match in both.
if(NOT ANDROID)
if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux"))
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
target_link_libraries(onnxruntime-genai PRIVATE ${ONNXRUNTIME_LIB})
endif()

Expand Down
18 changes: 17 additions & 1 deletion src/logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "json.h"
#include <iostream>
#include <fstream>
#include <cstdarg>

namespace Generators {

Expand Down Expand Up @@ -87,4 +88,19 @@ std::ostream& Log(std::string_view label, std::string_view string) {
return *gp_stream;
}

} // namespace Generators
std::ostream& Log(std::string_view label, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
va_list args_copy;
va_copy(args_copy, args);
size_t len = vsnprintf(0, 0, fmt, args_copy);
if (len <= 0) {
throw std::runtime_error("Invalid format");
}
std::unique_ptr<char[]> buf(new char[len + 1]);
vsnprintf(buf.get(), len + 1, fmt, args);
va_end(args);
return Log(label, std::string(buf.get(), buf.get() + len));
}

} // namespace Generators
3 changes: 2 additions & 1 deletion src/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ std::ostream& operator<<(std::ostream& stream, SGR sgr_code);

std::ostream& Log(std::string_view label, std::string_view text = {});

} // namespace Generators
std::ostream& Log(std::string_view label, const char* fmt, ...);
} // namespace Generators
150 changes: 117 additions & 33 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,114 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#include <string>
#include <vector>
#include <unordered_map>
#include <array>

#include "onnxruntime_c_api.h"
#include "../span.h"
#include "../logging.h"

#if defined(__ANDROID__)
#include <android/log.h>
#include <dlfcn.h>

#define TAG "GenAI"

#define LOG_DEBUG(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
#define LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
#define LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
#define LOG_FATAL(...) __android_log_print(ANDROID_LOG_FATAL, TAG, __VA_ARGS__)

#elif defined(__linux__)
#include <dlfcn.h>

#ifndef PATH_MAX
#define PATH_MAX (4096)
#endif

#define LOG_DEBUG(...) Generators::Log("debug", __VA_ARGS__)
#define LOG_INFO(...) Generators::Log("info", __VA_ARGS__)
#define LOG_WARN(...) Generators::Log("warning", __VA_ARGS__)
#define LOG_ERROR(...) Generators::Log("error", __VA_ARGS__)
#define LOG_FATAL(...) Generators::Log("fatal", __VA_ARGS__)

#endif

/** \brief Free functions and a few helpers are defined inside this namespace. Otherwise all types are the C API types
*
*/
namespace Ort {

using OrtApiBaseFn = const OrtApiBase* (*)(void);

/// Before using this C++ wrapper API, you MUST call Ort::InitApi to set the below 'api' variable
inline const OrtApi* api{};

#if defined(__linux__)
inline void* LoadDynamicLibraryIfExists(const std::string& path) {
LOG_INFO("Attempting to dlopen %s native library", path.c_str());
void* ort_lib_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
if (ort_lib_handle == nullptr) {
return nullptr;
}

#if !defined(__ANDROID__) // RTLD_DI_ORIGIN not available on Android
char pathname[PATH_MAX];
dlinfo((void*)ort_lib_handle, RTLD_DI_ORIGIN, &pathname);
LOG_INFO("Loaded native library at %s", pathname);
#endif
return ort_lib_handle;
}

inline std::string GetCurrentModuleDir() {
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
Dl_info dl_info;
dladdr((void*)GetCurrentModuleDir, &dl_info);
std::string module_name(dl_info.dli_fname);
std::string module_directory{};

const size_t last_slash_idx = module_name.rfind('/');
if (std::string::npos != last_slash_idx) {
module_directory = module_name.substr(0, last_slash_idx);
}
return module_directory;
}

inline void InitApiWithDynamicFn(OrtApiBaseFn ort_api_base_fn) {
if (ort_api_base_fn == nullptr) {
throw std::runtime_error("OrtGetApiBase not found");
}

const OrtApiBase* ort_api_base = ort_api_base_fn();
if (ort_api_base == nullptr) {
throw std::runtime_error("OrtGetApiBase() returned nullptr");
}

// loop from the ORT version GenAI was built with, down to the minimum ORT version we require.
// as long as the libonnxruntime.so we loaded supports one of those we're good.
constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that
for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) {
api = ort_api_base->GetApi(i);
if (api) {
LOG_INFO("ORT API Version %d was found.", i);
break;
}
}

if (!api) {
LOG_WARN("The loaded library did not have an ORT API version between %d and %d.",
ORT_API_VERSION, genai_min_ort_api_version);
throw std::runtime_error("Failed to load onnxruntime. Please make sure you installed the correct version");
}
}
#endif

inline void InitApi() {
#if defined(__ANDROID__)
if (api) {
// api was already set.
return;
}

#if defined(__linux__)
// If the GenAI library links against the onnxruntime library, it will have a dependency on a specific
// version of OrtGetApiBase.
//
Expand All @@ -108,49 +198,43 @@ inline void InitApi() {
//

const std::string path = "libonnxruntime.so"; // "libonnxruntime4j_jni.so" is also an option if we have issues
__android_log_print(ANDROID_LOG_INFO, "GenAI", "Attempting to dlopen %s native library", path.c_str());
void* ort_lib_handle = LoadDynamicLibraryIfExists(path);

using OrtApiBaseFn = const OrtApiBase* (*)(void);
OrtApiBaseFn ort_api_base_fn = nullptr;

void* ort_lib_handle = dlopen(path.c_str(), RTLD_LOCAL);
#if !defined(__ANDROID__)
// Search for pip installation
if (ort_lib_handle == nullptr) {
__android_log_assert("ort_lib_handle != nullptr", "GenAI", "Failed to load %s", path.c_str());
}

ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase");
if (ort_api_base_fn == nullptr) {
__android_log_assert("ort_api_base_fn != nullptr", "GenAI", "OrtGetApiBase not found");
const std::array<std::string, 4> target_libraries = {
std::string("libonnxruntime.so"),
std::string("libonnxruntime.so.1.18.0"),
std::string("libonnxruntime.so.1.19.0"),
std::string("libonnxruntime.so.1.20.0")};
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved

std::string current_module_dir = GetCurrentModuleDir();
for (const std::string& lib_name : target_libraries) {
std::string pip_path{current_module_dir + "/../onnxruntime/capi/" + lib_name};
ort_lib_handle = LoadDynamicLibraryIfExists(pip_path);
if (ort_lib_handle != nullptr) {
break;
}
}
}
#endif

const OrtApiBase* ort_api_base = ort_api_base_fn();
if (ort_api_base == nullptr) {
__android_log_assert("ort_api_base != nullptr", "GenAI", "OrtGetApiBase() returned nullptr");
if (ort_lib_handle == nullptr) {
throw std::runtime_error(std::string("Failed to load ") + path.c_str() + ": " + dlerror());
}

// loop from the ORT version GenAI was built with, down to the minimum ORT version we require.
// as long as the libonnxruntime.so we loaded supports one of those we're good.
constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that
for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) {
api = ort_api_base->GetApi(i);
if (!api) {
__android_log_print(ANDROID_LOG_INFO, "GenAI", "ORT API Version %d was not found.", i);
} else {
__android_log_print(ANDROID_LOG_INFO, "GenAI", "ORT API Version %d was found", i);
break;
}
OrtApiBaseFn ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase");
if (ort_api_base_fn == nullptr) {
throw std::runtime_error(std::string("Failed to load symbol OrtGetApiBase: ") + dlerror());
}

if (!api) {
__android_log_assert("api != nullptr", "GenAI",
"%s did not have an ORT API version between %d and %d.",
path.c_str(), ORT_API_VERSION, genai_min_ort_api_version);
}
#else // defined(__ANDROID__)
InitApiWithDynamicFn(ort_api_base_fn);
#else // defined(__linux__)
api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
if (!api)
throw std::runtime_error("Onnxruntime is installed but is too old, please install a newer version");
#endif // defined(__ANDROID__)
#endif // defined(__linux__)
}

/** \brief All C++ methods that can fail will throw an exception of this type
Expand Down
7 changes: 6 additions & 1 deletion src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ file(GLOB python_srcs CONFIGURE_DEPENDS
pybind11_add_module(python ${python_srcs})
target_include_directories(python PRIVATE ${ORT_HEADER_DIR})
target_link_directories(python PRIVATE ${ORT_LIB_DIR})
target_link_libraries(python PRIVATE onnxruntime-genai-static ${ONNXRUNTIME_LIB})
target_link_libraries(python PRIVATE onnxruntime-genai-static)

if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux"))
target_link_libraries(python PRIVATE ${ONNXRUNTIME_LIB})
endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
set_property(TARGET python APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN/../onnxruntime/capi/")
endif()
Expand Down
5 changes: 4 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ target_include_directories(unit_tests PRIVATE
target_link_directories(unit_tests PRIVATE ${ORT_LIB_DIR})
target_link_libraries(unit_tests PRIVATE
onnxruntime-genai-static
${ONNXRUNTIME_LIB}
GTest::gtest_main
)

if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linux"))
target_link_libraries(unit_tests PRIVATE ${ONNXRUNTIME_LIB})
endif()

if(USE_CUDA AND CMAKE_CUDA_COMPILER)
file(GLOB cuda_test_srcs CONFIGURE_DEPENDS
"${CMAKE_CURRENT_SOURCE_DIR}/*.cu"
Expand Down
Loading