diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h index d26eeebc4..9b767ea32 100644 --- a/src/models/onnxruntime_api.h +++ b/src/models/onnxruntime_api.h @@ -116,6 +116,34 @@ inline void InitApi() { } #if defined(__ANDROID__) || defined(__linux__) + inline void InitApiWithDynamicFn(OrtApiBaseFn ort_api_base_fn) { + if (ort_api_base_fn == nullptr) { + throw std::runtime_error("OrtGetApiBase not found"); + } + + const OrtApiBase* ort_api_base = ort_api_base_fn(); + if (ort_api_base == nullptr) { + throw std::runtime_error("OrtGetApiBase() returned nullptr"); + } + + // loop from the ORT version GenAI was built with, down to the minimum ORT version we require. + // as long as the libonnxruntime.so we loaded supports one of those we're good. + constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that + for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) { + api = ort_api_base->GetApi(i); + if (api) { + LOG_INFO("ORT API Version %d was found.", i); + break; + } + } + + if (!api) { + LOG_WARN("The loaded library did not have an ORT API version between %d and %d.", + ORT_API_VERSION, genai_min_ort_api_version); + throw std::runtime_error("Failed to load onnxruntime. Please make sure you installed the correct version"); + } + } + // If the GenAI library links against the onnxruntime library, it will have a dependency on a specific // version of OrtGetApiBase. // @@ -145,8 +173,7 @@ inline void InitApi() { void* ort_lib_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); if (ort_lib_handle == nullptr) { - LOG_FATAL("Failed to load %s: %s", path.c_str(), dlerror()); - exit(EXIT_FAILURE); + throw std::runtime_error(std::string("Failed to load ") + path.c_str() + ": " + dlerror()); } #if !defined(__ANDROID__) // RTLD_DI_ORIGIN not available on Android @@ -157,33 +184,10 @@ inline void InitApi() { ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase"); if (ort_api_base_fn == nullptr) { - LOG_FATAL("OrtGetApiBase not found"); - exit(EXIT_FAILURE); - } - - const OrtApiBase* ort_api_base = ort_api_base_fn(); - if (ort_api_base == nullptr) { - LOG_FATAL("OrtGetApiBase() returned nullptr"); - exit(EXIT_FAILURE); - } - - // loop from the ORT version GenAI was built with, down to the minimum ORT version we require. - // as long as the libonnxruntime.so we loaded supports one of those we're good. - constexpr int genai_min_ort_api_version = 18; // GenAI was first released around the time of ORT 1.18 so use that - for (int i = ORT_API_VERSION; i >= genai_min_ort_api_version; --i) { - api = ort_api_base->GetApi(i); - if (!api) { - LOG_INFO("ORT API Version %d was not found.", i); - } else { - LOG_INFO("ORT API Version %d was found.", i); - break; - } + throw std::runtime_error("OrtGetApiBase not found"); } - if (!api) { - LOG_WARN("%s did not have an ORT API version between %d and %d.", - path.c_str(), ORT_API_VERSION, genai_min_ort_api_version); - } + InitApiWithDynamicFn(ort_api_base_fn); #else // defined(__ANDROID__) api = OrtGetApiBase()->GetApi(ORT_API_VERSION); if (!api)