From cb4acff0c978f3615a4dbaf2164d026185b88134 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 1 Sep 2024 04:47:55 +0900 Subject: [PATCH] =?UTF-8?q?VOICEVOX/voicevox=5Fcore#722=20=E7=94=A8?= =?UTF-8?q?=E3=82=A8=E3=83=A9=E3=83=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/error.rs | 4 +++- src/lib.rs | 26 ++++++++++++++++++++++++-- src/session/builder.rs | 4 ++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/error.rs b/src/error.rs index 2e84580..2017c57 100644 --- a/src/error.rs +++ b/src/error.rs @@ -263,7 +263,9 @@ pub enum Error { #[error("Could't get device ID from memory info: {0}")] GetDeviceId(ErrorInternal), #[error("Training API is not enabled in this build of ONNX Runtime.")] - TrainingNotEnabled + TrainingNotEnabled, + #[error("This ONNX Runtime does not support \"vv-bin\" format (note: load/link `voicevox_onnxruntime` instead of `onnxruntime`)")] + VvBinNotSupported } impl Error { diff --git a/src/lib.rs b/src/lib.rs index 8e83d6d..8e4fcdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,6 +185,7 @@ thread_local! { #[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))] #[derive(Debug)] pub struct EnvHandle { + is_voicevox_onnxruntime: bool, _env: std::sync::Arc, api: AssertSendSync>, #[cfg(feature = "load-dynamic")] @@ -259,13 +260,21 @@ pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`"); let api = get_api(ORT_API_VERSION); + dbg!(CStr::from_ptr((*api).GetBuildInfoString.unwrap()())); (dylib, api) }; let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました"))); let _env = create_env(api.0, tp_options)?; - Ok(EnvHandle { _env, api, dylib }) + let is_voicevox_onnxruntime = is_voicevox_onnxruntime(api.0); + + Ok(EnvHandle { + is_voicevox_onnxruntime, + _env, + api, + dylib + }) }) } @@ -291,7 +300,9 @@ pub fn try_init(tp_options: Option) -> anyho let _env = create_env(api.0, tp_options)?; - Ok(EnvHandle { _env, api }) + let is_voicevox_onnxruntime = is_voicevox_onnxruntime(api.0); + + Ok(EnvHandle { is_voicevox_onnxruntime, _env, api }) }) } @@ -317,6 +328,17 @@ fn create_env(api: NonNull, tp_options: Option) -> bool { + unsafe { + let build_info = api.as_ref().GetBuildInfoString.expect("`GetBuildInfoString` must be present")(); + CStr::from_ptr(build_info) + .to_str() + .expect("should be UTF-8") + .starts_with("VOICEVOX ORT Build Info: ") + } +} + pub(crate) static G_ORT_API: OnceLock> = OnceLock::new(); /// Returns a pointer to the global [`ort_sys::OrtApi`] object. diff --git a/src/session/builder.rs b/src/session/builder.rs index 35f4600..f2b2576 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -446,7 +446,11 @@ impl SessionBuilder { Ok(session) } + #[cfg(feature = "__init-for-voicevox")] pub fn commit_from_vv_bin(self, bin: &[u8]) -> Result { + if !crate::EnvHandle::get().expect("should be present").is_voicevox_onnxruntime { + return Err(Error::VvBinNotSupported); + } ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), c"session.use_vv_bin".as_ptr(), c"1".as_ptr())]; self.commit_from_memory(bin) }