Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Aug 31, 2024
1 parent 26b59a8 commit cb4acff
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ pub enum Error {
#[error("Could't get device ID from memory info: {0}")]
GetDeviceId(ErrorInternal),
#[error("Training API is not enabled in this build of ONNX Runtime.")]
TrainingNotEnabled
TrainingNotEnabled,
#[error("This ONNX Runtime does not support \"vv-bin\" format (note: load/link `voicevox_onnxruntime` instead of `onnxruntime`)")]
VvBinNotSupported
}

impl Error {
Expand Down
26 changes: 24 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ thread_local! {
#[cfg_attr(docsrs, doc(cfg(feature = "__init-for-voicevox")))]
#[derive(Debug)]
pub struct EnvHandle {
is_voicevox_onnxruntime: bool,
_env: std::sync::Arc<Environment>,
api: AssertSendSync<NonNull<ort_sys::OrtApi>>,
#[cfg(feature = "load-dynamic")]
Expand Down Expand Up @@ -259,13 +260,21 @@ pub fn try_init_from(filename: &std::ffi::OsStr, tp_options: Option<EnvironmentG

let get_api: extern_system_fn! { unsafe fn(u32) -> *const ort_sys::OrtApi } = (*base).GetApi.expect("`GetApi` must be present in `OrtApiBase`");
let api = get_api(ORT_API_VERSION);
dbg!(CStr::from_ptr((*api).GetBuildInfoString.unwrap()()));
(dylib, api)
};
let api = AssertSendSync(NonNull::new(api.cast_mut()).unwrap_or_else(|| panic!("`GetApi({ORT_API_VERSION})`が失敗しました")));

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api, dylib })
let is_voicevox_onnxruntime = is_voicevox_onnxruntime(api.0);

Ok(EnvHandle {
is_voicevox_onnxruntime,
_env,
api,
dylib
})
})
}

Expand All @@ -291,7 +300,9 @@ pub fn try_init(tp_options: Option<EnvironmentGlobalThreadPoolOptions>) -> anyho

let _env = create_env(api.0, tp_options)?;

Ok(EnvHandle { _env, api })
let is_voicevox_onnxruntime = is_voicevox_onnxruntime(api.0);

Ok(EnvHandle { is_voicevox_onnxruntime, _env, api })
})
}

Expand All @@ -317,6 +328,17 @@ fn create_env(api: NonNull<ort_sys::OrtApi>, tp_options: Option<EnvironmentGloba
}
}

#[cfg(feature = "__init-for-voicevox")]
fn is_voicevox_onnxruntime(api: NonNull<ort_sys::OrtApi>) -> bool {
unsafe {
let build_info = api.as_ref().GetBuildInfoString.expect("`GetBuildInfoString` must be present")();
CStr::from_ptr(build_info)
.to_str()
.expect("should be UTF-8")
.starts_with("VOICEVOX ORT Build Info: ")
}
}

pub(crate) static G_ORT_API: OnceLock<AtomicPtr<ort_sys::OrtApi>> = OnceLock::new();

/// Returns a pointer to the global [`ort_sys::OrtApi`] object.
Expand Down
4 changes: 4 additions & 0 deletions src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,11 @@ impl SessionBuilder {
Ok(session)
}

#[cfg(feature = "__init-for-voicevox")]
pub fn commit_from_vv_bin(self, bin: &[u8]) -> Result<Session> {
if !crate::EnvHandle::get().expect("should be present").is_voicevox_onnxruntime {
return Err(Error::VvBinNotSupported);
}
ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), c"session.use_vv_bin".as_ptr(), c"1".as_ptr())];
self.commit_from_memory(bin)
}
Expand Down

0 comments on commit cb4acff

Please sign in to comment.