Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: VOICEVOX CORE用の初期化経路を構築する #6

Merged
merged 4 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ strip = true
codegen-units = 1

[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]

Expand Down Expand Up @@ -80,9 +80,16 @@ vitis = [ "voicevox-ort-sys/vitis" ]
cann = [ "voicevox-ort-sys/cann" ]
qnn = [ "voicevox-ort-sys/qnn" ]

# 動的ライブラリの読み込みから`OrtEnv`の作成までを、VOICEVOX独自の方法で行えるようにする。
#
# ortとしての通常の初期化の経路は禁止される。
__init-for-voicevox = []

[dependencies]
anyhow = "1.0"
ndarray = { version = "0.15", optional = true }
thiserror = "1.0"
once_cell = "1.19.0"
voicevox-ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" }
libloading = { version = "0.8", optional = true }

Expand All @@ -101,7 +108,6 @@ js-sys = "0.3"
web-sys = "0.3"

[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
image = "0.25"
test-log = { version = "0.2", default-features = false, features = [ "trace" ] }
Expand Down
1 change: 1 addition & 0 deletions ort-sys/VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1.17.3
10 changes: 10 additions & 0 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
copy_libraries(&lib_dir.join("lib"), &out_dir);
}

const OUR_VERSION: &str = include_str!("./VERSION_NUMBER");
let their_version = fs::read_to_string(lib_dir.join("VERSION_NUMBER")).unwrap_or_else(|e| panic!("`VERSION_NUMBER`を読めませんでした: {e}"));
assert_eq!(OUR_VERSION.trim_end(), their_version.trim_end(), "`VERSION_NUMBER`が異なります");

(lib_dir, true)
}
#[cfg(not(feature = "download-binaries"))]
Expand Down Expand Up @@ -421,6 +425,12 @@ fn real_main(link: bool) {
}

fn main() {
if cfg!(feature = "download-binaries") {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let version = include_str!("./VERSION_NUMBER").trim_end();
fs::write(out_dir.join("downloaded_version.rs"), format!("#[macro_export] macro_rules! downloaded_version(() => ({version:?}));")).unwrap();
}

if env::var("DOCS_RS").is_ok() {
return;
}
Expand Down
2 changes: 2 additions & 0 deletions ort-sys/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pub mod dirs;

include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs"));
6 changes: 6 additions & 0 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,17 @@ impl EnvironmentBuilder {

/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
// drop global reference to previous environment
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
drop(env_arc);
}
self.commit_()
}

pub(crate) fn commit_(self) -> Result<()> {
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
Expand Down
184 changes: 184 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ use std::{
};

pub use ort_sys as sys;
#[cfg(feature = "download-binaries")]
pub use ort_sys::downloaded_version;

#[cfg(feature = "load-dynamic")]
pub use self::environment::init_from;
Expand Down Expand Up @@ -73,6 +75,26 @@ pub use self::value::{
ValueRefMut, ValueType, ValueTypeMarker
};

/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。
#[cfg(feature = "load-dynamic")]
#[macro_export]
macro_rules! assert_feature {
(cfg(feature = "load-dynamic"), $msg:literal $(,)?) => {};
(cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => {
::std::compile_error!($msg);
};
}

/// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。
#[cfg(not(feature = "load-dynamic"))]
#[macro_export]
macro_rules! assert_feature {
(cfg(feature = "load-dynamic"), $msg:literal $(,)?) => {
::std::compile_error!($msg);
};
(cfg(not(feature = "load-dynamic")), $msg:literal $(,)?) => {};
}

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
Expand Down Expand Up @@ -100,6 +122,9 @@ pub(crate) static G_ORT_LIB: OnceLock<Arc<libloading::Library>> = OnceLock::new(

#[cfg(feature = "load-dynamic")]
pub(crate) fn dylib_path() -> &'static String {
if cfg!(feature = "__init-for-voicevox") {
panic!("`__init-for-voicevox`により禁止されています");
}
G_ORT_DYLIB_PATH.get_or_init(|| {
let path = match std::env::var("ORT_DYLIB_PATH") {
Ok(s) if !s.is_empty() => s,
Expand All @@ -116,6 +141,13 @@ pub(crate) fn dylib_path() -> &'static String {

#[cfg(feature = "load-dynamic")]
pub(crate) fn lib_handle() -> &'static libloading::Library {
#[cfg(feature = "__init-for-voicevox")]
if true {
return &G_ENV_FOR_VOICEVOX
.get()
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません")
.dylib;
}
G_ORT_LIB.get_or_init(|| {
// resolve path relative to executable
let path: std::path::PathBuf = dylib_path().into();
Expand All @@ -135,6 +167,150 @@ pub(crate) fn lib_handle() -> &'static libloading::Library {
})
}

#[cfg(feature = "__init-for-voicevox")]
static G_ENV_FOR_VOICEVOX: once_cell::sync::OnceCell<EnvHandle> = once_cell::sync::OnceCell::new();

#[cfg(feature = "__init-for-voicevox")]
thread_local! {
static G_ORT_API_FOR_ENV_BUILD: std::cell::Cell<Option<NonNull<ort_sys::OrtApi>>> = const { std::cell::Cell::new(None) };
}

#[cfg(feature = "__init-for-voicevox")]
#[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))]
#[derive(Debug)]
pub struct EnvHandle {
_env: std::sync::Arc<Environment>,
api: AssertSendSync<NonNull<ort_sys::OrtApi>>,
#[cfg(feature = "load-dynamic")]
dylib: libloading::Library
}

#[cfg(feature = "__init-for-voicevox")]
impl EnvHandle {
/// インスタンスが既に作られているならそれを得る。
///
/// 作られていなければ`None`。
pub fn get() -> Option<&'static Self> {
G_ENV_FOR_VOICEVOX.get()
}
}

#[cfg(feature = "__init-for-voicevox")]
#[derive(Clone, Copy, Debug)]
struct AssertSendSync<T>(T);

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Send for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

// SAFETY: `OrtApi`はスレッドセーフとされているはず
#[cfg(feature = "__init-for-voicevox")]
unsafe impl Sync for AssertSendSync<NonNull<ort_sys::OrtApi>> {}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", feature = "load-dynamic"))))]
pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use anyhow::bail;
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let (dylib, api) = unsafe {
let dylib = libloading::Library::new(filename)?;

// この下にある`api()`のものをできるだけ真似る

let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = dylib
.get(b"OrtGetApiBase")
.expect("`OrtGetApiBase` must be present in ONNX Runtime dylib");
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());

let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'");
qryxip marked this conversation as resolved.
Show resolved Hide resolved

let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
match lib_minor_version.cmp(&MINOR_VERSION) {
std::cmp::Ordering::Less if cfg!(windows) => {
bail!(r"`{dylib:?}`はバージョン{version_string}のONNX Runtimeです。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません");
}
std::cmp::Ordering::Less => bail!(
"`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。ONNX Runtimeはバージョン1.{MINOR_VERSION}でなくてはなりません",
filename = filename.to_string_lossy(),
),
std::cmp::Ordering::Greater => tracing::warn!(
"`{filename}`で指定されたONNX Runtimeはバージョン{version_string}です。対応しているONNX Runtimeのバージョンは1.{MINOR_VERSION}なので、\
互換性の問題があるかもしれません",
filename = filename.to_string_lossy(),
),
std::cmp::Ordering::Equal => {}
};

let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api = get_api(ORT_API_VERSION);
(dylib, api)
};
let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました")));

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api, dylib })
})
}

/// VOICEVOX CORE用に、`OrtEnv`の作成までをやる。
///
/// 一度成功したら以後は同じ参照を返す。
#[cfg(all(feature = "__init-for-voicevox", any(doc, not(feature = "load-dynamic"))))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "__init-for-voicevox", not(feature = "load-dynamic")))))]
pub fn try_init(tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<&'static EnvHandle> {
use ort_sys::ORT_API_VERSION;

G_ENV_FOR_VOICEVOX.get_or_try_init(|| {
let api = unsafe {
// この下にある`api()`のものをできるだけ真似る
let base: *const ort_sys::OrtApiBase = ort_sys::OrtGetApiBase();
assert_ne!(base, ptr::null());
let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
get_api(ORT_API_VERSION)
};
let api = NonNull::new(api.cast_mut())
.unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました。おそらく1.{MINOR_VERSION}より古いものがリンクされています"));
let api = AssertSendSync(api);

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api })
})
}

#[cfg(feature = "__init-for-voicevox")]
fn create_env(api: NonNull<ort_sys::OrtApi>, tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyhow::Result<std::sync::Arc<Environment>> {
G_ORT_API_FOR_ENV_BUILD.set(Some(api));
let _unset_api = UnsetOrtApi;

let mut env = EnvironmentBuilder::default().with_name(env!("CARGO_PKG_NAME"));
if let Some(tp_options) = tp_options {
env = env.with_global_thread_pool(tp_options);
}
env.commit_()?;

return Ok(get_environment().expect("失敗しないはず").clone());

struct UnsetOrtApi;

impl Drop for UnsetOrtApi {
fn drop(&mut self) {
G_ORT_API_FOR_ENV_BUILD.set(None);
}
}
}

pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::new();

/// Returns a pointer to the global [`ort_sys::OrtApi`] object.
Expand All @@ -144,6 +320,14 @@ pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::ne
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn api() -> NonNull<ort_sys::OrtApi> {
#[cfg(feature = "__init-for-voicevox")]
if true {
return G_ENV_FOR_VOICEVOX
.get()
.map(|&EnvHandle { api: AssertSendSync(api), .. }| api)
.or_else(|| G_ORT_API_FOR_ENV_BUILD.get())
.expect("`try_init_from`または`try_init`で初期化されていなくてはなりません");
}
unsafe {
NonNull::new_unchecked(
G_ORT_API
Expand Down
Loading