Skip to content

Commit

Permalink
add: VOICEVOX CORE用の初期化経路を構築する (#6)
Browse files Browse the repository at this point in the history
* add: VOICEVOX CORE用の初期化経路を構築する

* エラーと警告のメッセージを日本語に

* `G_ORT_API_FOR_ENV_BUILD`を`LocalKey`にする

* Minor refactor
  • Loading branch information
qryxip authored Jul 1, 2024
1 parent b6c41c6 commit 07c047c
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 2 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ strip = true
codegen-units = 1

[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]

Expand Down Expand Up @@ -80,9 +80,16 @@ vitis = [ "voicevox-ort-sys/vitis" ]
cann = [ "voicevox-ort-sys/cann" ]
qnn = [ "voicevox-ort-sys/qnn" ]

# 動的ライブラリの読み込みから`OrtEnv`の作成までを、VOICEVOX独自の方法で行えるようにする。
#
# ortとしての通常の初期化の経路は禁止される。
__init-for-voicevox = []

[dependencies]
anyhow = "1.0"
ndarray = { version = "0.15", optional = true }
thiserror = "1.0"
once_cell = "1.19.0"
voicevox-ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" }
libloading = { version = "0.8", optional = true }

Expand All @@ -101,7 +108,6 @@ js-sys = "0.3"
web-sys = "0.3"

[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
image = "0.25"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
Expand Down
1 change: 1 addition & 0 deletions ort-sys/VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1.17.3
10 changes: 10 additions & 0 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
copy_libraries(&lib_dir.join("lib"), &out_dir);
}

const OUR_VERSION: &str = include_str!("./VERSION_NUMBER");
let their_version = fs::read_to_string(lib_dir.join("VERSION_NUMBER")).unwrap_or_else(|e| panic!("`VERSION_NUMBER`を読めませんでした: {e}"));
assert_eq!(OUR_VERSION.trim_end(), their_version.trim_end(), "`VERSION_NUMBER`が異なります");

(lib_dir, true)
}
#[cfg(not(feature = "download-binaries"))]
Expand Down Expand Up @@ -421,6 +425,12 @@ fn real_main(link: bool) {
}

fn main() {
if cfg!(feature = "download-binaries") {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let version = include_str!("./VERSION_NUMBER").trim_end();
fs::write(out_dir.join("downloaded_version.rs"), format!("#[macro_export] macro_rules! downloaded_version(() => ({version:?}));")).unwrap();
}

if env::var("DOCS_RS").is_ok() {
return;
}
Expand Down
2 changes: 2 additions & 0 deletions ort-sys/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pub mod dirs;

include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs"));
6 changes: 6 additions & 0 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,17 @@ impl EnvironmentBuilder {

/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
// drop global reference to previous environment
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
drop(env_arc);
}
self.commit_()
}

pub(crate) fn commit_(self) -> Result<()> {
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
Expand Down
184 changes: 184 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ use std::{
};

pub use ort_sys as sys;
#[cfg(feature = "download-binaries")]
pub use ort_sys::downloaded_version;

#[cfg(feature = "load-dynamic")]
pub use self::environment::init_from;
Expand Down Expand Up @@ -73,6 +75,26 @@ pub use self::value::{
ValueRefMut, ValueType, ValueTypeMarker
};

/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。
#[cfg(feature = "load-dynamic")]
#[macro_export]
macro_rules! assert_feature {
(cfg(feature = "load-dynamic"), $msg:literal $(,)?) => {};
(cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => {
::std::compile_error!($msg);
};
}

/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。
#[cfg(not(feature = "load-dynamic"))]
#[macro_export]
macro_rules! assert_feature {
(cfg(feature = "load-dynamic"), $msg:literal $(,)?) => {
::std::compile_error!($msg);
};
(cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => {};
}

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
Expand Down Expand Up @@ -100,6 +122,9 @@ pub(crate) static G_ORT_LIB: OnceLock<Arc<libloading::Library>> = OnceLock::new(

#[cfg(feature = "load-dynamic")]
pub(crate) fn dylib_path() -> &'static String {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
G_ORT_DYLIB_PATH.get_or_init(|| {
let path = match std::env::var("ORT_DYLIB_PATH") {
Ok(s) if !s.is_empty() => s,
Expand All @@ -116,6 +141,13 @@ pub(crate) fn dylib_path() -> &'static String {

#[cfg(feature = "load-dynamic")]
pub(crate) fn lib_handle() -> &'static libloading::Library {
#[cfg(feature = "__init-for-voicevox")]
if true {
return &G_ENV_FOR_VOICEVOX
.get()
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません")
.dylib;
}
G_ORT_LIB.get_or_init(|| {
// resolve path relative to executable
let path: std::path::PathBuf = dylib_path().into();
Expand All @@ -135,6 +167,150 @@ pub(crate) fn lib_handle() -> &'static libloading::Library {
})
}

#[cfg(feature = "__init-for-voicevox")]
static G_ENV_FOR_VOICEVOX: once_cell::sync::OnceCell<EnvHandle> = once_cell::sync::OnceCell::new();

#[cfg(feature = "__init-for-voicevox")]
thread_local! {
static G_ORT_API_FOR_ENV_BUILD: std::cell::Cell<Option<NonNull<ort_sys::OrtApi>>> = const { std::cell::Cell::new(None) };
}

#[cfg(feature = "__init-for-voicevox")]
#[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))]
#[derive(Debug)]
pub struct EnvHandle {
_env: std::sync::Arc<Environment>,
api: AssertSendSync<NonNull<ort_sys::OrtApi>>,
#[cfg(feature = "load-dynamic")]
dylib: libloading::Library
}

#[cfg(feature = "__init-for-voicevox")]
impl EnvHandle {
/// インスタンスが既に作られているならそれを得る。
///
/// 作られていなければ`None`。
pub fn get() -> Option<&'static Self> {
G_ENV_FOR_VOICEVOX.get()
}
}

#[cfg(feature = "__init-for-voicevox")]
#[derive(Clone, Copy, Debug)]
struct AssertSendSync<T>(T);

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Send for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Sync for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))))]
pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use anyhow::bail;
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let (dylib, api) = unsafe {
let dylib = libloading::Library::new(filename)?;

// この下にある`api()`のものをできるだけ真似る

let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = dylib
.get(b"OrtGetApiBase")
.expect("`OrtGetApiBase` must be present in ONNX Runtime dylib");
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());

let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'");

let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
match lib_minor_version.cmp(&MINOR_VERSION) {
std::cmp::Ordering::Less if cfg!(windows) => {
bail!(r"`{dylib:?}`はバージョン{version_string}のONNX Runtimeです。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません");
}
std::cmp::Ordering::Less => bail!(
"`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません",
filename = filename.to_string_lossy(),
),
std::cmp::Ordering::Greater => tracing::warn!(
"`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。対応しているONNX Runtimeのバージョンは1.{MINOR_VERSION}なので、\
互換性の問題があるかもしれません",
filename = filename.to_string_lossy(),
),
std::cmp::Ordering::Equal => {}
};

let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api = get_api(ORT_API_VERSION);
(dylib, api)
};
let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました")));

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api, dylib })
})
}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", any(doc, not(feature = "load-dynamic"))))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", not(feature = "load-dynamic")))))]
pub fn try_init(tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let api = unsafe {
// この下にある`api()`のものをできるだけ真似る
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
assert_ne!(base, ptr::null());
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
get_api(ORT_API_VERSION)
};
let api = NonNull::new(api.cast_mut())
.unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました。おそらく1.{MINOR_VERSION}より古いものがリンクされています"));
let api = AssertSendSync(api);

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api })
})
}

#[cfg(feature = "__init-for-voicevox")]
fn create_env(api: NonNull<ort_sys::OrtApi>, tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<std::sync::Arc<Environment>> {
G_ORT_API_FOR_ENV_BUILD.set(Some(api));
let _unset_api = UnsetOrtApi;

let mut env = EnvironmentBuilder::default().with_name(env!("CARGO_PKG_NAME"));
if let Some(tp_options) = tp_options {
env = env.with_global_thread_pool(tp_options);
}
env.commit_()?;

return Ok(get_environment().expect("失敗しないはず").clone());

struct UnsetOrtApi;

impl Drop for UnsetOrtApi {
fn drop(&mut self) {
G_ORT_API_FOR_ENV_BUILD.set(None);
}
}
}

pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::new();

/// Returns a pointer to the global [`ort_sys::OrtApi`] object.
Expand All @@ -144,6 +320,14 @@ pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::ne
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn api() -> NonNull<ort_sys::OrtApi> {
#[cfg(feature = "__init-for-voicevox")]
if true {
return G_ENV_FOR_VOICEVOX
.get()
.map(|&EnvHandle { api: AssertSendSync(api), .. }| api)
.or_else(|| G_ORT_API_FOR_ENV_BUILD.get())
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません");
}
unsafe {
NonNull::new_unchecked(
G_ORT_API
Expand Down

0 comments on commit 07c047c

Please sign in to comment.