From 07c047c449b959d8f76593046e139bae520d59c3 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 2 Jul 2024 04:33:15 +0900 Subject: [PATCH] =?UTF-8?q?add:=20VOICEVOX=20CORE=E7=94=A8=E3=81=AE?= =?UTF-8?q?=E5=88=9D=E6=9C=9F=E5=8C=96=E7=B5=8C=E8=B7=AF=E3=82=92=E6=A7=8B?= =?UTF-8?q?=E7=AF=89=E3=81=99=E3=82=8B=20(#6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add: VOICEVOX CORE用の初期化経路を構築する * エラーと警告のメッセージを日本語に * `G_ORT_API_FOR_ENV_BUILD`を`LocalKey`にする * Minor refactor --- Cargo.toml | 10 +- ort-sys/VERSION_NUMBER | 1 + ort-sys/build.rs | 10 ++ ort-sys/src/internal/mod.rs | 2 + src/environment.rs | 6 ++ src/lib.rs | 184 ++++++++++++++++++++++++++++++++++++ 6 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 ort-sys/VERSION_NUMBER diff --git a/Cargo.toml b/Cargo.toml index 31f56d5d..e7b3785e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ strip = true codegen-units = 1 [package.metadata.docs.rs] -features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] +features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] @@ -80,9 +80,16 @@ vitis = [ "voicevox-ort-sys/vitis" ] cann = [ "voicevox-ort-sys/cann" ] qnn = [ "voicevox-ort-sys/qnn" ] +# 動的ライブラリの読み込みから`OrtEnv`の作成までを、VOICEVOX独自の方法で行えるようにする。 +# +# ortとしての通常の初期化の経路は禁止される。 +__init-for-voicevox = [] + [dependencies] +anyhow = "1.0" ndarray = { version = "0.15", optional = true } thiserror = "1.0" +once_cell = "1.19.0" voicevox-ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" } libloading = { version = "0.8", optional = true } @@ -101,7 +108,6 @@ js-sys = "0.3" web-sys = "0.3" [dev-dependencies] -anyhow = "1.0" ureq = "2.1" image = "0.25" test-log = { version = "0.2", default-features = false, features = [ "trace" ] } diff --git a/ort-sys/VERSION_NUMBER b/ort-sys/VERSION_NUMBER new file mode 100644 index 00000000..b9a05a6d --- /dev/null +++ b/ort-sys/VERSION_NUMBER @@ -0,0 +1 @@ +1.17.3 diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 0e853858..8ceefd79 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -392,6 +392,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) { copy_libraries(&lib_dir.join("lib"), &out_dir); } + const OUR_VERSION: &str = include_str!("./VERSION_NUMBER"); + let their_version = fs::read_to_string(lib_dir.join("VERSION_NUMBER")).unwrap_or_else(|e| panic!("`VERSION_NUMBER`を読めませんでした: {e}")); + assert_eq!(OUR_VERSION.trim_end(), their_version.trim_end(), "`VERSION_NUMBER`が異なります"); + (lib_dir, true) } #[cfg(not(feature = "download-binaries"))] @@ -421,6 +425,12 @@ fn real_main(link: bool) { } fn main() { + if cfg!(feature = "download-binaries") { + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let version = include_str!("./VERSION_NUMBER").trim_end(); + fs::write(out_dir.join("downloaded_version.rs"), format!("#[macro_export] macro_rules! downloaded_version(() => ({version:?}));")).unwrap(); + } + if env::var("DOCS_RS").is_ok() { return; } diff --git a/ort-sys/src/internal/mod.rs b/ort-sys/src/internal/mod.rs index 16ec3672..c13a982e 100644 --- a/ort-sys/src/internal/mod.rs +++ b/ort-sys/src/internal/mod.rs @@ -1 +1,3 @@ pub mod dirs; + +include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs")); diff --git a/src/environment.rs b/src/environment.rs index 9b82d0d1..58d10681 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -123,11 +123,17 @@ impl EnvironmentBuilder { /// Commit the environment configuration and set the global environment. pub fn commit(self) -> Result<()> { + if cfg!(feature = "__init-for-voicevox") { + panic!("`__init-for-voicevox`により禁止されています"); + } // drop global reference to previous environment if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } { drop(env_arc); } + self.commit_() + } + pub(crate) fn commit_(self) -> Result<()> { let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); diff --git a/src/lib.rs b/src/lib.rs index ed5a2adf..bd8f560b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,8 @@ use std::{ }; pub use ort_sys as sys; +#[cfg(feature = "download-binaries")] +pub use ort_sys::downloaded_version; #[cfg(feature = "load-dynamic")] pub use self::environment::init_from; @@ -73,6 +75,26 @@ pub use self::value::{ ValueRefMut, ValueType, ValueTypeMarker }; +/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。 +#[cfg(feature = "load-dynamic")] +#[macro_export] +macro_rules! assert_feature { + (cfg(feature = "load-dynamic"), $msg:literal $(,)?) => {}; + (cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => { + ::std::compile_error!($msg); + }; +} + +/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。 +#[cfg(not(feature = "load-dynamic"))] +#[macro_export] +macro_rules! assert_feature { + (cfg(feature = "load-dynamic"), $msg:literal $(,)?) => { + ::std::compile_error!($msg); + }; + (cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => {}; +} + #[cfg(not(all(target_arch = "x86", target_os = "windows")))] macro_rules! extern_system_fn { ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*); @@ -100,6 +122,9 @@ pub(crate) static G_ORT_LIB: OnceLock> = OnceLock::new( #[cfg(feature = "load-dynamic")] pub(crate) fn dylib_path() -> &'static String { + if cfg!(feature = "__init-for-voicevox") { + panic!("`__init-for-voicevox`により禁止されています"); + } G_ORT_DYLIB_PATH.get_or_init(|| { let path = match std::env::var("ORT_DYLIB_PATH") { Ok(s) if !s.is_empty() => s, @@ -116,6 +141,13 @@ pub(crate) fn dylib_path() -> &'static String { #[cfg(feature = "load-dynamic")] pub(crate) fn lib_handle() -> &'static libloading::Library { + #[cfg(feature = "__init-for-voicevox")] + if true { + return &G_ENV_FOR_VOICEVOX + .get() + .expect("`try_init_from`または`try_init`で初期化されていなくてはなりません") + .dylib; + } G_ORT_LIB.get_or_init(|| { // resolve path relative to executable let path: std::path::PathBuf = dylib_path().into(); @@ -135,6 +167,150 @@ pub(crate) fn lib_handle() -> &'static libloading::Library { }) } +#[cfg(feature = "__init-for-voicevox")] +static G_ENV_FOR_VOICEVOX: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); + +#[cfg(feature = "__init-for-voicevox")] +thread_local! { + static G_ORT_API_FOR_ENV_BUILD: std::cell::Cell>> = const { std::cell::Cell::new(None) }; +} + +#[cfg(feature = "__init-for-voicevox")] +#[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))] +#[derive(Debug)] +pub struct EnvHandle { + _env: std::sync::Arc, + api: AssertSendSync>, + #[cfg(feature = "load-dynamic")] + dylib: libloading::Library +} + +#[cfg(feature = "__init-for-voicevox")] +impl EnvHandle { + /// インスタンスが既に作られているならそれを得る。 + /// + /// 作られていなければ`None`。 + pub fn get() -> Option<&'static Self> { + G_ENV_FOR_VOICEVOX.get() + } +} + +#[cfg(feature = "__init-for-voicevox")] +#[derive(Clone, Copy, Debug)] +struct AssertSendSync(T); + +// SAFETY: `OrtApi`はスレッドセーフとされているはず +#[cfg(feature = "__init-for-voicevox")] +unsafe impl Send for AssertSendSync> {} + +// SAFETY: `OrtApi`はスレッドセーフとされているはず +#[cfg(feature = "__init-for-voicevox")] +unsafe impl Sync for AssertSendSync> {} + +/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。 +/// +/// 一度成功したら以後は同じ参照を返す。 +#[cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))))] +pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option) -> anyhow::Result<&'static EnvHandle> { + use anyhow::bail; + use ort_sys::ORT_API_VERSION; + + G_ENV_FOR_VOICEVOX.get_or_try_init(|| { + let (dylib, api) = unsafe { + let dylib = libloading::Library::new(filename)?; + + // この下にある`api()`のものをできるだけ真似る + + let base_getter: libloading::Symbol *const ort_sys::OrtApiBase> = dylib + .get(b"OrtGetApiBase") + .expect("`OrtGetApiBase` must be present in ONNX Runtime dylib"); + let base: *const ort_sys::OrtApiBase = base_getter(); + assert_ne!(base, ptr::null()); + + let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } = + (*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`"); + let version_string = get_version_string(); + let version_string = CStr::from_ptr(version_string).to_string_lossy(); + tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'"); + + let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::().unwrap_or(0)); + match lib_minor_version.cmp(&MINOR_VERSION) { + std::cmp::Ordering::Less if cfg!(windows) => { + bail!(r"`{dylib:?}`はバージョン{version_string}のONNX Runtimeです。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません"); + } + std::cmp::Ordering::Less => bail!( + "`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません", + filename = filename.to_string_lossy(), + ), + std::cmp::Ordering::Greater => tracing::warn!( + "`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。対応しているONNX Runtimeのバージョンは1.{MINOR_VERSION}なので、\ + 互換性の問題があるかもしれません", + filename = filename.to_string_lossy(), + ), + std::cmp::Ordering::Equal => {} + }; + + let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`"); + let api = get_api(ORT_API_VERSION); + (dylib, api) + }; + let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました"))); + + let _env = create_env(api.0, tp_options)?; + + Ok(EnvHandle { _env, api, dylib }) + }) +} + +/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。 +/// +/// 一度成功したら以後は同じ参照を返す。 +#[cfg(all(feature = "__init-for-voicevox", any(doc, not(feature = "load-dynamic"))))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", not(feature = "load-dynamic")))))] +pub fn try_init(tp_options: Option) -> anyhow::Result<&'static EnvHandle> { + use ort_sys::ORT_API_VERSION; + + G_ENV_FOR_VOICEVOX.get_or_try_init(|| { + let api = unsafe { + // この下にある`api()`のものをできるだけ真似る + let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase(); + assert_ne!(base, ptr::null()); + let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`"); + get_api(ORT_API_VERSION) + }; + let api = NonNull::new(api.cast_mut()) + .unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました。おそらく1.{MINOR_VERSION}より古いものがリンクされています")); + let api = AssertSendSync(api); + + let _env = create_env(api.0, tp_options)?; + + Ok(EnvHandle { _env, api }) + }) +} + +#[cfg(feature = "__init-for-voicevox")] +fn create_env(api: NonNull, tp_options: Option) -> anyhow::Result> { + G_ORT_API_FOR_ENV_BUILD.set(Some(api)); + let _unset_api = UnsetOrtApi; + + let mut env = EnvironmentBuilder::default().with_name(env!("CARGO_PKG_NAME")); + if let Some(tp_options) = tp_options { + env = env.with_global_thread_pool(tp_options); + } + env.commit_()?; + + return Ok(get_environment().expect("失敗しないはず").clone()); + + struct UnsetOrtApi; + + impl Drop for UnsetOrtApi { + fn drop(&mut self) { + G_ORT_API_FOR_ENV_BUILD.set(None); + } + } +} + pub(crate) static G_ORT_API: OnceLock> = OnceLock::new(); /// Returns a pointer to the global [`ort_sys::OrtApi`] object. @@ -144,6 +320,14 @@ pub(crate) static G_ORT_API: OnceLock> = OnceLock::ne /// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. /// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. pub fn api() -> NonNull { + #[cfg(feature = "__init-for-voicevox")] + if true { + return G_ENV_FOR_VOICEVOX + .get() + .map(|&EnvHandle { api: AssertSendSync(api), .. }| api) + .or_else(|| G_ORT_API_FOR_ENV_BUILD.get()) + .expect("`try_init_from`または`try_init`で初期化されていなくてはなりません"); + } unsafe { NonNull::new_unchecked( G_ORT_API