From 5d178022498f7c95031cdfe19d06fc38aa7ad93f Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 24 Jul 2023 09:11:28 +0900 Subject: [PATCH 01/16] [wip] Add `gpu_num_sessions` options to `load_voice_model` --- Cargo.lock | 222 +++++++-- crates/voicevox_core/Cargo.toml | 5 + .../voicevox_core/benches/decode-with-gpu.rs | 78 +++ .../src/engine/synthesis_engine.rs | 4 - crates/voicevox_core/src/inference_core.rs | 85 ++-- crates/voicevox_core/src/status.rs | 454 ++++++++++++------ crates/voicevox_core/src/voice_synthesizer.rs | 51 +- .../include/voicevox_core.h | 13 +- crates/voicevox_core_c_api/src/c_impls.rs | 34 +- .../src/compatible_engine.rs | 3 +- crates/voicevox_core_c_api/src/helpers.rs | 25 +- crates/voicevox_core_c_api/src/lib.rs | 22 +- .../voicevox_core_c_api/tests/e2e/symbols.rs | 17 +- .../tests/e2e/testcases/simple_tts.rs | 7 +- .../e2e/testcases/tts_via_audio_query.rs | 7 +- .../tests/e2e/testcases/user_dict_load.rs | 7 +- .../python/voicevox_core/_rust.pyi | 5 +- crates/voicevox_core_python_api/src/lib.rs | 16 +- 18 files changed, 757 insertions(+), 298 deletions(-) create mode 100644 crates/voicevox_core/benches/decode-with-gpu.rs diff --git a/Cargo.lock b/Cargo.lock index a783bb6d7..3a6091d07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anyhow" version = "1.0.65" @@ -352,7 +358,7 @@ version = "0.60.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cexpr", "clang-sys", "clap 3.2.22", @@ -386,6 +392,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" + [[package]] name = "block-buffer" version = "0.9.0" @@ -481,6 +493,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1db59621ec70f09c5e9b597b220c7a2b43611f4710dc03ceb8748637775692c" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cbindgen" version = "0.24.3" @@ -543,6 +561,33 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.2.5" @@ -579,7 +624,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" dependencies = [ "atty", - "bitflags", + "bitflags 1.3.2", "clap_lex 0.2.4", "indexmap 1.9.1", "strsim", @@ -594,7 +639,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b1a0a4208c6c483b952ad35c6eed505fc13b46f08f631b81e828084a9318d74" dependencies = [ "atty", - "bitflags", + "bitflags 1.3.2", "clap_derive", "clap_lex 0.3.0", "once_cell", @@ -685,7 +730,7 @@ version = "0.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "318d6c16e73b3a900eb212ad6a82fc7d298c5ab8184c7a9998646455bc474a16" dependencies = [ - "bitflags", + "bitflags 1.3.2", "concolor-query", "is-terminal", ] @@ -798,6 +843,44 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap 4.0.10", + "criterion-plot", + "futures", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + [[package]] name = "crossbeam-channel" version = "0.5.6" @@ -1125,17 +1208,6 @@ dependencies = [ "serde", ] -[[package]] -name = "errno" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" -dependencies = [ - "errno-dragonfly", - "libc", - "winapi", -] - [[package]] name = "errno" version = "0.3.1" @@ -1446,6 +1518,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hashbrown" version = "0.12.3" @@ -1763,14 +1841,13 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "is-terminal" -version = "0.4.2" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28dfb6c8100ccc63462345b67d1bbc3679177c75ee4bf59bf29c8b1d110b8189" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.2.6", - "io-lifetimes", - "rustix 0.36.7", - "windows-sys 0.42.0", + "hermit-abi 0.3.2", + "rustix 0.38.4", + "windows-sys 0.48.0", ] [[package]] @@ -1866,9 +1943,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.142" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "libloading" @@ -1930,15 +2007,15 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.3.8" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" [[package]] name = "lock_api" @@ -2276,6 +2353,12 @@ dependencies = [ "zip", ] +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + [[package]] name = "opaque-debug" version = "0.3.0" @@ -2474,6 +2557,34 @@ version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d7ddaed09e0eb771a79ab0fd64609ba0afb0a8366421957936ad14cbd13630" +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.3.0" @@ -2808,7 +2919,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -2817,7 +2928,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -2958,29 +3069,28 @@ dependencies = [ [[package]] name = "rustix" -version = "0.36.7" +version = "0.37.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4fdebc4b395b7fbb9ab11e462e20ed9051e7b16e42d24042c776eca0ac81b03" +checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" dependencies = [ - "bitflags", - "errno 0.2.8", + "bitflags 1.3.2", + "errno", "io-lifetimes", "libc", - "linux-raw-sys 0.1.4", - "windows-sys 0.42.0", + "linux-raw-sys 0.3.8", + "windows-sys 0.48.0", ] [[package]] name = "rustix" -version = "0.37.19" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" dependencies = [ - "bitflags", - "errno 0.3.1", - "io-lifetimes", + "bitflags 2.3.3", + "errno", "libc", - "linux-raw-sys 0.3.8", + "linux-raw-sys 0.4.3", "windows-sys 0.48.0", ] @@ -3017,6 +3127,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.20" @@ -3613,6 +3732,16 @@ dependencies = [ "syn 1.0.102", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -3982,6 +4111,7 @@ dependencies = [ "async_zip", "cfg-if", "const-default", + "criterion", "derive-getters", "derive-new", "duplicate", @@ -4085,6 +4215,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 9e5febe22..a6f746d22 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -4,6 +4,10 @@ version.workspace = true edition.workspace = true publish.workspace = true +[[bench]] +name = "decode-with-gpu" +harness = false + [features] default = [] directml = ["onnxruntime/directml"] @@ -43,6 +47,7 @@ git = "https://github.com/VOICEVOX/open_jtalk-rs.git" rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5" [dev-dependencies] +criterion = { version = "0.5.1", features = ["async_tokio"] } flate2 = "1.0.24" heck = "0.4.0" pretty_assertions = "1.3.0" diff --git a/crates/voicevox_core/benches/decode-with-gpu.rs b/crates/voicevox_core/benches/decode-with-gpu.rs new file mode 100644 index 000000000..85fe1952b --- /dev/null +++ b/crates/voicevox_core/benches/decode-with-gpu.rs @@ -0,0 +1,78 @@ +use std::{num::NonZeroU16, sync::Arc}; + +use criterion::{criterion_group, criterion_main, Criterion}; +use test_util::OPEN_JTALK_DIC_DIR; +use tokio::{join, runtime::Runtime}; +use voicevox_core::{ + AccelerationMode, AudioQueryModel, InitializeOptions, LoadVoiceModelOptions, OpenJtalk, + StyleId, SynthesisOptions, Synthesizer, VoiceModel, +}; + +criterion_main!(benches); +criterion_group!(benches, benchmark); + +fn benchmark(c: &mut Criterion) { + let (synthesizer, aq) = &Runtime::new().unwrap().block_on(setup()).unwrap(); + + let decode = || async { + synthesizer + .synthesis( + aq, + StyleId::new(0), + &SynthesisOptions { + enable_interrogative_upspeak: true, + }, + ) + .await + .unwrap() + }; + + c.bench_function("decode_parallel", |b| { + b.to_async(Runtime::new().unwrap()) + .iter(|| async { join!(decode(), decode(), decode(), decode()) }) + }); + + c.bench_function("decode_sequential", |b| { + b.to_async(Runtime::new().unwrap()).iter(|| async { + for _ in 0..4 { + decode().await; + } + }) + }); +} + +async fn setup() -> voicevox_core::Result<(Synthesizer, AudioQueryModel)> { + let syntesizer = Synthesizer::new_with_initialize( + Arc::new(OpenJtalk::new_with_initialize(OPEN_JTALK_DIC_DIR).unwrap()), + &InitializeOptions { + acceleration_mode: AccelerationMode::Gpu, + cpu_num_threads: 4, + ..Default::default() + }, + ) + .await?; + + let model = &VoiceModel::from_path(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../model/sample.vvm", + )) + .await?; + syntesizer + .load_voice_model( + model, + &LoadVoiceModelOptions { + gpu_num_sessions: NonZeroU16::new(4).unwrap(), + }, + ) + .await?; + + let aq = syntesizer + .audio_query( + "寿限無寿限無五劫の擦り切れ海砂利水魚の水行末雲来末", + StyleId::new(0), + &Default::default(), + ) + .await?; + + Ok((syntesizer, aq)) +} diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index 0adce4fe2..96d2353e4 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -30,10 +30,6 @@ impl SynthesisEngine { &self.inference_core } - pub fn inference_core_mut(&mut self) -> &mut InferenceCore { - &mut self.inference_core - } - pub async fn create_accent_phrases( &self, text: &str, diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 578e57453..a2f6c0ce1 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,9 +1,8 @@ +use std::num::NonZeroU16; + use self::status::*; use super::*; -use onnxruntime::{ - ndarray, - session::{AnyArray, NdArray}, -}; +use onnxruntime::{ndarray, session::NdArray}; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -18,11 +17,13 @@ impl InferenceCore { load_all_models: bool, ) -> Result { if !use_gpu || Self::can_support_gpu_feature()? { - let mut status = Status::new(use_gpu, cpu_num_threads); + let status = Status::new(use_gpu, cpu_num_threads); if load_all_models { for model in &VoiceModel::get_all_models().await? { - status.load_model(model).await?; + status + .load_model(model, NonZeroU16::new(1).unwrap()) + .await?; } } Ok(Self { status }) @@ -43,14 +44,14 @@ impl InferenceCore { } } - pub async fn load_model(&mut self, model: &VoiceModel) -> Result<()> { - self.status.load_model(model).await + pub async fn load_model(&self, model: &VoiceModel, gpu_num_sessions: NonZeroU16) -> Result<()> { + self.status.load_model(model, gpu_num_sessions).await } - pub fn unload_model(&mut self, voice_model_id: &VoiceModelId) -> Result<()> { + pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { self.status.unload_model(voice_model_id) } - pub fn metas(&self) -> &VoiceModelMeta { + pub fn metas(&self) -> VoiceModelMeta { self.status.metas() } @@ -71,15 +72,13 @@ impl InferenceCore { return Err(Error::InvalidStyleId { style_id }); } - let mut phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector)); - let mut speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); - - let input_tensors: Vec<&mut dyn AnyArray> = - vec![&mut phoneme_vector_array, &mut speaker_id_array]; + let phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector)); + let speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); let mut output = self .status - .predict_duration_session_run(style_id, input_tensors)?; + .predict_duration_session_run(style_id, phoneme_vector_array, speaker_id_array) + .await?; for output_item in output.iter_mut() { if *output_item < PHONEME_LENGTH_MINIMAL { @@ -106,31 +105,29 @@ impl InferenceCore { return Err(Error::InvalidStyleId { style_id }); } - let mut length_array = NdArray::new(ndarray::arr0(length as i64)); - let mut vowel_phoneme_vector_array = NdArray::new(ndarray::arr1(vowel_phoneme_vector)); - let mut consonant_phoneme_vector_array = - NdArray::new(ndarray::arr1(consonant_phoneme_vector)); - let mut start_accent_vector_array = NdArray::new(ndarray::arr1(start_accent_vector)); - let mut end_accent_vector_array = NdArray::new(ndarray::arr1(end_accent_vector)); - let mut start_accent_phrase_vector_array = + let length_array = NdArray::new(ndarray::arr0(length as i64)); + let vowel_phoneme_vector_array = NdArray::new(ndarray::arr1(vowel_phoneme_vector)); + let consonant_phoneme_vector_array = NdArray::new(ndarray::arr1(consonant_phoneme_vector)); + let start_accent_vector_array = NdArray::new(ndarray::arr1(start_accent_vector)); + let end_accent_vector_array = NdArray::new(ndarray::arr1(end_accent_vector)); + let start_accent_phrase_vector_array = NdArray::new(ndarray::arr1(start_accent_phrase_vector)); - let mut end_accent_phrase_vector_array = - NdArray::new(ndarray::arr1(end_accent_phrase_vector)); - let mut speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); - - let input_tensors: Vec<&mut dyn AnyArray> = vec![ - &mut length_array, - &mut vowel_phoneme_vector_array, - &mut consonant_phoneme_vector_array, - &mut start_accent_vector_array, - &mut end_accent_vector_array, - &mut start_accent_phrase_vector_array, - &mut end_accent_phrase_vector_array, - &mut speaker_id_array, - ]; + let end_accent_phrase_vector_array = NdArray::new(ndarray::arr1(end_accent_phrase_vector)); + let speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); self.status - .predict_intonation_session_run(style_id, input_tensors) + .predict_intonation_session_run( + style_id, + length_array, + vowel_phoneme_vector_array, + consonant_phoneme_vector_array, + start_accent_vector_array, + end_accent_vector_array, + start_accent_phrase_vector_array, + end_accent_phrase_vector_array, + speaker_id_array, + ) + .await } pub async fn decode( @@ -161,23 +158,21 @@ impl InferenceCore { padding_size, ); - let mut f0_array = NdArray::new( + let f0_array = NdArray::new( ndarray::arr1(&f0_with_padding) .into_shape([length_with_padding, 1]) .unwrap(), ); - let mut phoneme_array = NdArray::new( + let phoneme_array = NdArray::new( ndarray::arr1(&phoneme_with_padding) .into_shape([length_with_padding, phoneme_size]) .unwrap(), ); - let mut speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); - - let input_tensors: Vec<&mut dyn AnyArray> = - vec![&mut f0_array, &mut phoneme_array, &mut speaker_id_array]; + let speaker_id_array = NdArray::new(ndarray::arr1(&[style_id.raw_id() as i64])); self.status - .decode_session_run(style_id, input_tensors) + .decode_session_run(style_id, f0_array, phoneme_array, speaker_id_array) + .await .map(|output| Self::trim_padding_from_output(output, padding_size)) } diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 68f839a6b..3388d2375 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,11 +1,13 @@ use super::*; +use itertools::iproduct; use once_cell::sync::Lazy; use onnxruntime::{ environment::Environment, - session::{AnyArray, Session}, + ndarray::{Ix0, Ix1, Ix2}, + session::{NdArray, Session}, GraphOptimizationLevel, LoggingLevel, }; -use std::sync::Mutex; +use std::{collections::VecDeque, iter, num::NonZeroU16, sync::Arc}; use std::{env, path::Path}; use tracing::error; @@ -19,18 +21,9 @@ cfg_if! { use std::collections::BTreeMap; pub struct Status { - models: StatusModels, - merged_metas: VoiceModelMeta, + loaded_models: std::sync::Mutex, light_session_options: SessionOptions, // 軽いモデルはこちらを使う heavy_session_options: SessionOptions, // 重いモデルはこちらを使う - id_relations: BTreeMap, -} - -struct StatusModels { - metas: BTreeMap, - predict_duration: BTreeMap>>, - predict_intonation: BTreeMap>>, - decode: BTreeMap>>, } #[derive(new, Getters)] @@ -58,38 +51,21 @@ static ENVIRONMENT: Lazy = Lazy::new(|| { .unwrap() }); -#[allow(unsafe_code)] -unsafe impl Send for Status {} - -#[allow(unsafe_code)] -unsafe impl Sync for Status {} - impl Status { pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { Self { - models: StatusModels { - metas: BTreeMap::new(), - predict_duration: BTreeMap::new(), - predict_intonation: BTreeMap::new(), - decode: BTreeMap::new(), - }, - merged_metas: VoiceModelMeta::default(), + loaded_models: Default::default(), light_session_options: SessionOptions::new(cpu_num_threads, false), heavy_session_options: SessionOptions::new(cpu_num_threads, use_gpu), - id_relations: BTreeMap::default(), } } - pub async fn load_model(&mut self, model: &VoiceModel) -> Result<()> { - for speaker in model.metas().iter() { - for style in speaker.styles().iter() { - if self.id_relations.contains_key(style.id()) { - Err(Error::AlreadyLoadedModel { - path: model.path().clone(), - })?; - } - } - } + pub async fn load_model(&self, model: &VoiceModel, gpu_num_sessions: NonZeroU16) -> Result<()> { + self.loaded_models + .lock() + .unwrap() + .ensure_not_contains(model)?; + let models = model.read_inference_models().await?; let predict_duration_session = self.new_session( @@ -102,81 +78,46 @@ impl Status { &self.light_session_options, model.path(), )?; - let decode_model = self.new_session( - models.decode_model(), - &self.heavy_session_options, - model.path(), + let decode_models = iter::repeat_with(|| { + self.new_session( + models.decode_model(), + &self.heavy_session_options, + model.path(), + ) + }) + .take(gpu_num_sessions.get().into()) + .collect::, _>>()?; + + self.loaded_models.lock().unwrap().insert( + model, + LightOrtSessions { + predict_duration: predict_duration_session.into(), + predict_intonation: predict_intonation_session.into(), + }, + decode_models.into_iter().map(|decode| HeavyOrtSession { + decode: decode.into(), + }), )?; - self.models - .metas - .insert(model.id().clone(), model.metas().clone()); - - for speaker in model.metas().iter() { - for style in speaker.styles().iter() { - self.id_relations.insert(*style.id(), model.id().clone()); - } - } - self.set_metas(); - - self.models - .predict_duration - .insert(model.id().clone(), Mutex::new(predict_duration_session)); - self.models - .predict_intonation - .insert(model.id().clone(), Mutex::new(predict_intonation_session)); - - self.models - .decode - .insert(model.id().clone(), Mutex::new(decode_model)); - Ok(()) } - pub fn unload_model(&mut self, voice_model_id: &VoiceModelId) -> Result<()> { - if self.is_loaded_model(voice_model_id) { - self.models.predict_intonation.remove(voice_model_id); - self.models.predict_duration.remove(voice_model_id); - self.models.decode.remove(voice_model_id); - - let remove_style_ids = self - .id_relations - .iter() - .filter(|&(_, loaded_model_id)| loaded_model_id == voice_model_id) - .map(|(&style_id, _)| style_id) - .collect::>(); - - for style_id in remove_style_ids.iter() { - self.id_relations.remove(style_id); - } - self.set_metas(); - Ok(()) - } else { - Err(Error::UnloadedModel { - model_id: voice_model_id.clone(), - }) - } - } - - fn set_metas(&mut self) { - let mut meta = VoiceModelMeta::default(); - for m in self.models.metas.values() { - meta.extend_from_slice(m); - } - self.merged_metas = meta; + pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { + self.loaded_models.lock().unwrap().remove(voice_model_id) } - pub fn metas(&self) -> &VoiceModelMeta { - &self.merged_metas + pub fn metas(&self) -> VoiceModelMeta { + self.loaded_models.lock().unwrap().metas() } pub fn is_loaded_model(&self, voice_model_id: &VoiceModelId) -> bool { - self.models.predict_duration.contains_key(voice_model_id) - && self.models.predict_intonation.contains_key(voice_model_id) - && self.models.decode.contains_key(voice_model_id) + self.loaded_models + .lock() + .unwrap() + .contains_voice_model(voice_model_id) } pub fn is_loaded_model_by_style_id(&self, style_id: StyleId) -> bool { - self.id_relations.contains_key(&style_id) + self.loaded_models.lock().unwrap().contains_style(style_id) } fn new_session( @@ -223,68 +164,274 @@ impl Status { } pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { - self.id_relations.contains_key(&style_id) + self.is_loaded_model_by_style_id(style_id) } - pub fn predict_duration_session_run( + pub async fn predict_duration_session_run( &self, style_id: StyleId, - inputs: Vec<&mut dyn AnyArray>, + mut phoneme_vector_array: NdArray, + mut speaker_id_array: NdArray, ) -> Result> { - if let Some(model_id) = self.id_relations.get(&style_id) { - if let Some(model) = self.models.predict_duration.get(model_id) { - if let Ok(output_tensors) = model.lock().unwrap().run(inputs) { - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - } else { - Err(Error::InferenceFailed) - } - } else { - Err(Error::InvalidStyleId { style_id }) - } - } else { - Err(Error::InvalidStyleId { style_id }) - } + let light_sessions = self + .loaded_models + .lock() + .unwrap() + .light_sessions(style_id)?; + + tokio::task::spawn_blocking(move || { + let LightOrtSessions { + predict_duration, .. + } = &mut *light_sessions.lock().unwrap(); + + let output_tensors = predict_duration + .get_mut() + .run(vec![&mut phoneme_vector_array, &mut speaker_id_array]) + .map_err(|_| Error::InferenceFailed)?; + Ok(output_tensors[0].as_slice().unwrap().to_owned()) + }) + .await + .unwrap() } - pub fn predict_intonation_session_run( + #[allow(clippy::too_many_arguments)] + pub async fn predict_intonation_session_run( &self, style_id: StyleId, - inputs: Vec<&mut dyn AnyArray>, + mut length_array: NdArray, + mut vowel_phoneme_vector_array: NdArray, + mut consonant_phoneme_vector_array: NdArray, + mut start_accent_vector_array: NdArray, + mut end_accent_vector_array: NdArray, + mut start_accent_phrase_vector_array: NdArray, + mut end_accent_phrase_vector_array: NdArray, + mut speaker_id_array: NdArray, ) -> Result> { - if let Some(model_id) = self.id_relations.get(&style_id) { - if let Some(model) = self.models.predict_intonation.get(model_id) { - if let Ok(output_tensors) = model.lock().unwrap().run(inputs) { - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - } else { - Err(Error::InferenceFailed) - } - } else { - Err(Error::InvalidStyleId { style_id }) - } - } else { - Err(Error::InvalidStyleId { style_id }) - } + let light_sessions = self + .loaded_models + .lock() + .unwrap() + .light_sessions(style_id)?; + + tokio::task::spawn_blocking(move || { + let LightOrtSessions { + predict_intonation, .. + } = &mut *light_sessions.lock().unwrap(); + + let output_tensors = predict_intonation + .get_mut() + .run(vec![ + &mut length_array, + &mut vowel_phoneme_vector_array, + &mut consonant_phoneme_vector_array, + &mut start_accent_vector_array, + &mut end_accent_vector_array, + &mut start_accent_phrase_vector_array, + &mut end_accent_phrase_vector_array, + &mut speaker_id_array, + ]) + .map_err(|_| Error::InferenceFailed)?; + Ok(output_tensors[0].as_slice().unwrap().to_owned()) + }) + .await + .unwrap() } - pub fn decode_session_run( + pub async fn decode_session_run( &self, style_id: StyleId, - inputs: Vec<&mut dyn AnyArray>, + mut f0_array: NdArray, + mut phoneme_array: NdArray, + mut speaker_id_array: NdArray, ) -> Result> { - if let Some(model_id) = self.id_relations.get(&style_id) { - if let Some(model) = self.models.decode.get(model_id) { - if let Ok(output_tensors) = model.lock().unwrap().run(inputs) { - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - } else { - Err(Error::InferenceFailed) - } - } else { - Err(Error::InvalidStyleId { style_id }) - } - } else { - Err(Error::InvalidStyleId { style_id }) + let heavy_session = self.loaded_models.lock().unwrap().heavy_session(style_id)?; + + tokio::task::spawn_blocking(move || { + let HeavyOrtSession { decode } = &mut *heavy_session.lock().unwrap(); + + let output_tensors = decode + .get_mut() + .run(vec![ + &mut f0_array, + &mut phoneme_array, + &mut speaker_id_array, + ]) + .map_err(|_| Error::InferenceFailed)?; + Ok(output_tensors[0].as_slice().unwrap().to_owned()) + }) + .await + .unwrap() + } +} + +#[derive(Default)] +struct LoadedModels(BTreeMap); + +struct LoadedModel { + metas: VoiceModelMeta, + session_set: SessionSet, +} + +impl LoadedModels { + fn metas(&self) -> VoiceModelMeta { + self.0 + .values() + .flat_map(|LoadedModel { metas, .. }| metas) + .cloned() + .collect() + } + + fn light_sessions( + &mut self, + style_id: StyleId, + ) -> Result>> { + let LoadedModel { session_set, .. } = self.find_loaded_voice_model(style_id)?; + Ok(session_set.get_light()) + } + + fn heavy_session( + &mut self, + style_id: StyleId, + ) -> Result>> { + let LoadedModel { session_set, .. } = self.find_loaded_voice_model(style_id)?; + Ok(session_set.get_heavy()) + } + + fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { + self.0.contains_key(model_id) + } + + fn contains_style(&self, style_id: StyleId) -> bool { + self.styles().any(|style| *style.id() == style_id) + } + + fn ensure_not_contains(&self, model: &VoiceModel) -> Result<()> { + let loaded = self.styles(); + let external = model.metas().iter().flat_map(|speaker| speaker.styles()); + + if iproduct!(loaded, external).any(|(loaded, external)| loaded.id() == external.id()) { + return Err(Error::AlreadyLoadedModel { + path: model.path().clone(), + }); } + Ok(()) + } + + fn insert( + &mut self, + model: &VoiceModel, + light_sessions: LightOrtSessions, + heavy_session: impl IntoIterator, + ) -> Result<()> { + self.ensure_not_contains(model)?; + + let prev = self.0.insert( + model.id().clone(), + LoadedModel { + metas: model.metas().clone(), + session_set: SessionSet::new(light_sessions, heavy_session), + }, + ); + assert!(prev.is_none()); + Ok(()) } + + fn remove(&mut self, model_id: &VoiceModelId) -> Result<()> { + if self.0.remove(model_id).is_none() { + return Err(Error::UnloadedModel { + model_id: model_id.clone(), + }); + } + Ok(()) + } + + fn find_loaded_voice_model(&mut self, style_id: StyleId) -> Result<&mut LoadedModel> { + self.0 + .values_mut() + .find(|LoadedModel { metas, .. }| { + metas + .iter() + .flat_map(|speaker| speaker.styles()) + .any(|style| *style.id() == style_id) + }) + .ok_or(Error::InvalidStyleId { style_id }) + } + + fn styles(&self) -> impl Iterator { + self.0 + .values() + .flat_map(|LoadedModel { metas, .. }| metas) + .flat_map(|speaker| speaker.styles()) + } +} + +struct SessionSet { + // 不変条件: `self.heavy.len() >= 1` + light: Arc>, + heavy: VecDeque>>, +} + +impl SessionSet { + fn new(light: LightOrtSessions, heavy: impl IntoIterator) -> Self { + Self { + light: Arc::new(light.into()), + heavy: heavy.into_iter().map(Into::into).map(Arc::new).collect(), + } + } + + fn get_light(&self) -> Arc> { + self.light.clone() + } + + /// # Panics + /// + /// `self.heavy`が空のときパニックする。 + fn get_heavy(&mut self) -> Arc> { + self.heavy.rotate_left(1); + self.heavy.back().unwrap().clone() + } +} + +struct LightOrtSessions { + predict_duration: AssertSend>, + predict_intonation: AssertSend>, +} + +struct HeavyOrtSession { + decode: AssertSend>, +} + +// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 +// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 + +use self::assert_send::AssertSend; + +mod assert_send { + use onnxruntime::session::Session; + + pub(super) struct AssertSend(T); + + impl AssertSend { + pub(super) fn get_mut(&mut self) -> &mut T { + &mut self.0 + } + } + + impl From> for AssertSend> { + fn from(session: Session<'static>) -> Self { + Self(session) + } + } + + impl AsRef for AssertSend { + fn as_ref(&self) -> &T { + &self.0 + } + } + + // SAFETY: `Session` is probably "send"able. + #[allow(unsafe_code)] + unsafe impl Send for AssertSend {} } #[cfg(test)] @@ -314,33 +461,30 @@ mod tests { cpu_num_threads, status.heavy_session_options.cpu_num_threads ); - assert!(status.models.predict_duration.is_empty()); - assert!(status.models.predict_intonation.is_empty()); - assert!(status.models.decode.is_empty()); - assert!(status.id_relations.is_empty()); + assert!(status.loaded_models.lock().unwrap().0.is_empty()); } #[rstest] #[tokio::test] async fn status_load_model_works() { - let mut status = Status::new(false, 0); - let result = status.load_model(&open_default_vvm_file().await).await; + let status = Status::new(false, 0); + let result = status + .load_model(&open_default_vvm_file().await, NonZeroU16::new(1).unwrap()) + .await; assert_debug_fmt_eq!(Ok(()), result); - assert_eq!(1, status.models.predict_duration.len()); - assert_eq!(1, status.models.predict_intonation.len()); - assert_eq!(1, status.models.decode.len()); + assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); } #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let mut status = Status::new(false, 0); + let status = Status::new(false, 0); let vvm = open_default_vvm_file().await; assert!( !status.is_loaded_model(vvm.id()), "model should not be loaded" ); - let result = status.load_model(&vvm).await; + let result = status.load_model(&vvm, NonZeroU16::new(1).unwrap()).await; assert_debug_fmt_eq!(Ok(()), result); assert!(status.is_loaded_model(vvm.id()), "model should be loaded"); } diff --git a/crates/voicevox_core/src/voice_synthesizer.rs b/crates/voicevox_core/src/voice_synthesizer.rs index 86b91ce83..01b2e4790 100644 --- a/crates/voicevox_core/src/voice_synthesizer.rs +++ b/crates/voicevox_core/src/voice_synthesizer.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{num::NonZeroU16, sync::Arc}; use const_default::ConstDefault; use duplicate::duplicate_item; @@ -7,6 +7,18 @@ use crate::engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, Synth use super::*; +pub struct LoadVoiceModelOptions { + pub gpu_num_sessions: NonZeroU16, +} + +impl ConstDefault for LoadVoiceModelOptions { + const DEFAULT: Self = Self { + // SAFETY: 1 ≠ 0 + #[allow(unsafe_code)] + gpu_num_sessions: unsafe { NonZeroU16::new_unchecked(1) }, + }; +} + pub struct SynthesisOptions { pub enable_interrogative_upspeak: bool, } @@ -79,6 +91,7 @@ pub struct InitializeOptions { #[duplicate_item( T; + [ LoadVoiceModelOptions ]; [ AccentPhrasesOptions ]; [ AudioQueryOptions ]; [ TtsOptions ]; @@ -141,18 +154,22 @@ impl Synthesizer { } /// 音声モデルを読み込む - pub async fn load_voice_model(&mut self, model: &VoiceModel) -> Result<()> { + pub async fn load_voice_model( + &self, + model: &VoiceModel, + options: &LoadVoiceModelOptions, + ) -> Result<()> { self.synthesis_engine - .inference_core_mut() - .load_model(model) + .inference_core() + .load_model(model, options.gpu_num_sessions) .await?; Ok(()) } /// 指定したモデルIdの音声モデルを開放する - pub fn unload_voice_model(&mut self, voice_model_id: &VoiceModelId) -> Result<()> { + pub fn unload_voice_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { self.synthesis_engine - .inference_core_mut() + .inference_core() .unload_model(voice_model_id) } @@ -171,7 +188,7 @@ impl Synthesizer { } /// 今読み込んでいる音声モデルのメタ情報を返す - pub fn metas(&self) -> &VoiceModelMeta { + pub fn metas(&self) -> VoiceModelMeta { self.synthesis_engine.inference_core().metas() } @@ -378,7 +395,7 @@ mod tests { #[case(Ok(()))] #[tokio::test] async fn load_model_works(#[case] expected_result_at_initialized: Result<()>) { - let mut syntesizer = Synthesizer::new_with_initialize( + let syntesizer = Synthesizer::new_with_initialize( Arc::new(OpenJtalk::new_without_dic()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, @@ -389,7 +406,7 @@ mod tests { .unwrap(); let result = syntesizer - .load_voice_model(&open_default_vvm_file().await) + .load_voice_model(&open_default_vvm_file().await, &Default::default()) .await; assert_debug_fmt_eq!( @@ -419,7 +436,7 @@ mod tests { #[tokio::test] async fn is_loaded_model_by_style_id_works(#[case] style_id: u32, #[case] expected: bool) { let style_id = StyleId::new(style_id); - let mut syntesizer = Synthesizer::new_with_initialize( + let syntesizer = Synthesizer::new_with_initialize( Arc::new(OpenJtalk::new_without_dic()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, @@ -433,7 +450,7 @@ mod tests { "expected is_model_loaded to return false, but got true", ); syntesizer - .load_voice_model(&open_default_vvm_file().await) + .load_voice_model(&open_default_vvm_file().await, &Default::default()) .await .unwrap(); @@ -448,7 +465,7 @@ mod tests { #[rstest] #[tokio::test] async fn predict_duration_works() { - let mut syntesizer = Synthesizer::new_with_initialize( + let syntesizer = Synthesizer::new_with_initialize( Arc::new(OpenJtalk::new_without_dic()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, @@ -459,7 +476,7 @@ mod tests { .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await) + .load_voice_model(&open_default_vvm_file().await, &Default::default()) .await .unwrap(); @@ -480,7 +497,7 @@ mod tests { #[rstest] #[tokio::test] async fn predict_intonation_works() { - let mut syntesizer = Synthesizer::new_with_initialize( + let syntesizer = Synthesizer::new_with_initialize( Arc::new(OpenJtalk::new_without_dic()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, @@ -490,7 +507,7 @@ mod tests { .await .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await) + .load_voice_model(&open_default_vvm_file().await, &Default::default()) .await .unwrap(); @@ -522,7 +539,7 @@ mod tests { #[rstest] #[tokio::test] async fn decode_works() { - let mut syntesizer = Synthesizer::new_with_initialize( + let syntesizer = Synthesizer::new_with_initialize( Arc::new(OpenJtalk::new_without_dic()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, @@ -532,7 +549,7 @@ mod tests { .await .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await) + .load_voice_model(&open_default_vvm_file().await, &Default::default()) .await .unwrap(); diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index ecb358846..c36144b03 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -220,6 +220,10 @@ typedef struct VoicevoxInitializeOptions { bool load_all_models; } VoicevoxInitializeOptions; +typedef struct VoicevoxLoadVoiceModelOptions { + uint16_t gpu_num_sessions; +} VoicevoxLoadVoiceModelOptions; + /** * スタイルID */ @@ -303,6 +307,8 @@ extern const struct VoicevoxInitializeOptions voicevox_default_initialize_option extern const char *voicevox_version; +extern const struct VoicevoxLoadVoiceModelOptions voicevox_default_load_voice_model_options; + extern const struct VoicevoxAudioQueryOptions voicevox_default_audio_query_options; extern const struct VoicevoxAccentPhrasesOptions voicevox_default_accent_phrases_options; @@ -447,8 +453,9 @@ void voicevox_synthesizer_delete(struct VoicevoxSynthesizer *synthesizer); #ifdef _WIN32 __declspec(dllimport) #endif -VoicevoxResultCode voicevox_synthesizer_load_voice_model(struct VoicevoxSynthesizer *synthesizer, - const struct VoicevoxVoiceModel *model); +VoicevoxResultCode voicevox_synthesizer_load_voice_model(const struct VoicevoxSynthesizer *synthesizer, + const struct VoicevoxVoiceModel *model, + struct VoicevoxLoadVoiceModelOptions options); /** * モデルの読み込みを解除する @@ -463,7 +470,7 @@ VoicevoxResultCode voicevox_synthesizer_load_voice_model(struct VoicevoxSynthesi #ifdef _WIN32 __declspec(dllimport) #endif -VoicevoxResultCode voicevox_synthesizer_unload_voice_model(struct VoicevoxSynthesizer *synthesizer, +VoicevoxResultCode voicevox_synthesizer_unload_voice_model(const struct VoicevoxSynthesizer *synthesizer, VoicevoxVoiceModelId model_id); /** diff --git a/crates/voicevox_core_c_api/src/c_impls.rs b/crates/voicevox_core_c_api/src/c_impls.rs index f4c58bcc4..99314d7d2 100644 --- a/crates/voicevox_core_c_api/src/c_impls.rs +++ b/crates/voicevox_core_c_api/src/c_impls.rs @@ -1,12 +1,14 @@ use std::{ - ffi::{CStr, CString}, + ffi::{c_char, CString}, path::Path, sync::Arc, }; use voicevox_core::{InitializeOptions, OpenJtalk, Result, Synthesizer, VoiceModel, VoiceModelId}; -use crate::{OpenJtalkRc, VoicevoxSynthesizer, VoicevoxVoiceModel}; +use crate::{ + CApiResult, OpenJtalkRc, VoicevoxLoadVoiceModelOptions, VoicevoxSynthesizer, VoicevoxVoiceModel, +}; impl OpenJtalkRc { pub(crate) fn new_with_initialize(open_jtalk_dic_dir: impl AsRef) -> Result { @@ -24,26 +26,34 @@ impl VoicevoxSynthesizer { Ok(Self { synthesizer: Synthesizer::new_with_initialize(open_jtalk.open_jtalk.clone(), options) .await?, - metas_cstring: CString::default(), + metas_cstring: Default::default(), }) } - pub(crate) async fn load_voice_model(&mut self, model: &VoiceModel) -> Result<()> { - self.synthesizer.load_voice_model(model).await?; - let metas = self.synthesizer.metas(); - self.metas_cstring = CString::new(serde_json::to_string(metas).unwrap()).unwrap(); + pub(crate) async fn load_voice_model( + &self, + model: &VoiceModel, + options: VoicevoxLoadVoiceModelOptions, + ) -> CApiResult<()> { + self.synthesizer + .load_voice_model(model, &options.try_into()?) + .await?; + let metas = &self.synthesizer.metas(); + *self.metas_cstring.lock().unwrap() = + CString::new(serde_json::to_string(metas).unwrap()).unwrap(); Ok(()) } - pub(crate) fn unload_voice_model(&mut self, model_id: &VoiceModelId) -> Result<()> { + pub(crate) fn unload_voice_model(&self, model_id: &VoiceModelId) -> Result<()> { self.synthesizer.unload_voice_model(model_id)?; - let metas = self.synthesizer.metas(); - self.metas_cstring = CString::new(serde_json::to_string(metas).unwrap()).unwrap(); + let metas = &self.synthesizer.metas(); + *self.metas_cstring.lock().unwrap() = + CString::new(serde_json::to_string(metas).unwrap()).unwrap(); Ok(()) } - pub(crate) fn metas(&self) -> &CStr { - &self.metas_cstring + pub(crate) fn metas_ptr(&self) -> *const c_char { + self.metas_cstring.lock().unwrap().as_ptr() } } diff --git a/crates/voicevox_core_c_api/src/compatible_engine.rs b/crates/voicevox_core_c_api/src/compatible_engine.rs index e84040abc..3f455e7fc 100644 --- a/crates/voicevox_core_c_api/src/compatible_engine.rs +++ b/crates/voicevox_core_c_api/src/compatible_engine.rs @@ -100,7 +100,8 @@ pub extern "C" fn load_model(style_id: i64) -> bool { if let Some(model_id) = model_set.style_model_map.get(&style_id) { let vvm = model_set.model_map.get(model_id).unwrap(); let synthesizer = &mut *lock_synthesizer(); - let result = RUNTIME.block_on(ensure_initialized!(synthesizer).load_voice_model(vvm)); + let result = RUNTIME + .block_on(ensure_initialized!(synthesizer).load_voice_model(vvm, &Default::default())); if let Some(err) = result.err() { set_message(&format!("{err}")); false diff --git a/crates/voicevox_core_c_api/src/helpers.rs b/crates/voicevox_core_c_api/src/helpers.rs index 08d1116de..cc21301c6 100644 --- a/crates/voicevox_core_c_api/src/helpers.rs +++ b/crates/voicevox_core_c_api/src/helpers.rs @@ -53,10 +53,10 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxRes } } -type CApiResult = std::result::Result; +pub(crate) type CApiResult = std::result::Result; #[derive(Error, Debug)] -pub(crate) enum CApiError { +pub enum CApiError { #[error("{0}")] RustApi(#[from] voicevox_core::Error), #[error("UTF-8として不正な入力です")] @@ -81,6 +81,27 @@ pub(crate) fn ensure_utf8(s: &CStr) -> CApiResult<&str> { s.to_str().map_err(|_| CApiError::InvalidUtf8Input) } +impl ConstDefault for VoicevoxLoadVoiceModelOptions { + const DEFAULT: Self = { + let options = voicevox_core::LoadVoiceModelOptions::DEFAULT; + Self { + gpu_num_sessions: options.gpu_num_sessions.get(), + } + }; +} + +impl TryFrom for voicevox_core::LoadVoiceModelOptions { + type Error = CApiError; + + fn try_from(options: VoicevoxLoadVoiceModelOptions) -> std::result::Result { + let gpu_num_sessions = options + .gpu_num_sessions + .try_into() + .unwrap_or_else(|_| todo!()); + Ok(Self { gpu_num_sessions }) + } +} + impl ConstDefault for VoicevoxAudioQueryOptions { const DEFAULT: Self = { let options = voicevox_core::AudioQueryOptions::DEFAULT; diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index ef141b4ae..948dac9e7 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -247,7 +247,7 @@ pub extern "C" fn voicevox_voice_model_delete(model: Box) { #[derive(Getters)] pub struct VoicevoxSynthesizer { synthesizer: Synthesizer, - metas_cstring: CString, + metas_cstring: std::sync::Mutex, } /// 音声シンセサイザを生成して初期化する @@ -287,6 +287,15 @@ pub extern "C" fn voicevox_synthesizer_delete(synthesizer: Box VoicevoxResultCode { into_result_code_with_error( - RUNTIME - .block_on(synthesizer.load_voice_model(model.model())) - .map_err(Into::into), + RUNTIME.block_on(synthesizer.load_voice_model(model.model(), options)), ) } @@ -317,7 +325,7 @@ pub extern "C" fn voicevox_synthesizer_load_voice_model( /// @param model_id NULL終端文字列であること #[no_mangle] pub unsafe extern "C" fn voicevox_synthesizer_unload_voice_model( - synthesizer: &mut VoicevoxSynthesizer, + synthesizer: &VoicevoxSynthesizer, model_id: VoicevoxVoiceModelId, ) -> VoicevoxResultCode { into_result_code_with_error((|| { @@ -368,7 +376,7 @@ pub unsafe extern "C" fn voicevox_synthesizer_is_loaded_voice_model( pub extern "C" fn voicevox_synthesizer_get_metas_json( synthesizer: &VoicevoxSynthesizer, ) -> *const c_char { - synthesizer.metas().as_ptr() + synthesizer.metas_ptr() } /// サポートデバイス情報をjsonで取得する diff --git a/crates/voicevox_core_c_api/tests/e2e/symbols.rs b/crates/voicevox_core_c_api/tests/e2e/symbols.rs index 810fbf91b..ff1e300be 100644 --- a/crates/voicevox_core_c_api/tests/e2e/symbols.rs +++ b/crates/voicevox_core_c_api/tests/e2e/symbols.rs @@ -8,6 +8,8 @@ use voicevox_core::result_code::VoicevoxResultCode; pub(crate) struct Symbols<'lib> { pub(crate) voicevox_version: Symbol<'lib, &'lib &'lib c_char>, pub(crate) voicevox_default_initialize_options: Symbol<'lib, &'lib VoicevoxInitializeOptions>, + pub(crate) voicevox_default_load_voice_model_options: + Symbol<'lib, &'lib VoicevoxLoadVoiceModelOptions>, pub(crate) voicevox_default_audio_query_options: Symbol<'lib, &'lib VoicevoxAudioQueryOptions>, pub(crate) voicevox_default_synthesis_options: Symbol<'lib, &'lib VoicevoxSynthesisOptions>, pub(crate) voicevox_default_tts_options: Symbol<'lib, &'lib VoicevoxTtsOptions>, @@ -43,13 +45,17 @@ pub(crate) struct Symbols<'lib> { pub(crate) voicevox_synthesizer_load_voice_model: Symbol< 'lib, unsafe extern "C" fn( - *mut VoicevoxSynthesizer, + *const VoicevoxSynthesizer, *const VoicevoxVoiceModel, + VoicevoxLoadVoiceModelOptions, ) -> VoicevoxResultCode, >, pub(crate) voicevox_synthesizer_unload_voice_model: Symbol< 'lib, - unsafe extern "C" fn(*mut VoicevoxSynthesizer, VoicevoxVoiceModelId) -> VoicevoxResultCode, + unsafe extern "C" fn( + *const VoicevoxSynthesizer, + VoicevoxVoiceModelId, + ) -> VoicevoxResultCode, >, pub(crate) voicevox_synthesizer_is_gpu_mode: Symbol<'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer) -> bool>, @@ -186,6 +192,7 @@ impl<'lib> Symbols<'lib> { Ok(new!( voicevox_version, voicevox_default_initialize_options, + voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_default_synthesis_options, voicevox_default_tts_options, @@ -253,6 +260,12 @@ pub(crate) struct VoicevoxInitializeOptions { pub(crate) _load_all_models: bool, } +#[derive(Clone, Copy)] +#[repr(C)] +pub(crate) struct VoicevoxLoadVoiceModelOptions { + pub(crate) gpu_num_sessions: u16, +} + #[derive(Clone, Copy)] #[repr(C)] pub(crate) struct VoicevoxAudioQueryOptions { diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs index 2e691100f..55bfa9558 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs @@ -37,6 +37,7 @@ impl assert_cdylib::TestCase for TestCase { unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { let Symbols { voicevox_default_initialize_options, + voicevox_default_load_voice_model_options, voicevox_default_tts_options, voicevox_open_jtalk_rc_new, voicevox_open_jtalk_rc_delete, @@ -82,7 +83,11 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); + assert_ok(voicevox_synthesizer_load_voice_model( + synthesizer, + model, + **voicevox_default_load_voice_model_options, + )); let (wav_length, wav) = { let mut wav_length = MaybeUninit::uninit(); diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs index 024d68b0f..fa4ca05dc 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs @@ -37,6 +37,7 @@ impl assert_cdylib::TestCase for TestCase { unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { let Symbols { voicevox_default_initialize_options, + voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_default_synthesis_options, voicevox_open_jtalk_rc_new, @@ -85,7 +86,11 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); + assert_ok(voicevox_synthesizer_load_voice_model( + synthesizer, + model, + **voicevox_default_load_voice_model_options, + )); let audio_query = { let mut audio_query = MaybeUninit::uninit(); diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs index b9284c153..a9a545c8e 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs @@ -38,6 +38,7 @@ impl assert_cdylib::TestCase for TestCase { voicevox_user_dict_add_word, voicevox_user_dict_delete, voicevox_default_initialize_options, + voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_open_jtalk_rc_new, voicevox_open_jtalk_rc_use_user_dict, @@ -100,7 +101,11 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); + assert_ok(voicevox_synthesizer_load_voice_model( + synthesizer, + model, + **voicevox_default_load_voice_model_options, + )); let mut audio_query_without_dict = std::ptr::null_mut(); assert_ok(voicevox_synthesizer_audio_query( diff --git a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi index 020545528..2931b9ceb 100644 --- a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi +++ b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi @@ -90,13 +90,16 @@ class Synthesizer: def metas(self) -> SpeakerMeta: """メタ情報を取得する。""" ... - async def load_voice_model(self, model: VoiceModel) -> None: + async def load_voice_model( + self, model: VoiceModel, gpu_num_sessions: int = 1 + ) -> None: """モデルを読み込む。 Parameters ---------- style_id 読み込むモデルの話者ID。 + gpu_num_sessions """ ... def unload_voice_model(self, voice_model_id: str) -> None: diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs index 21346680f..2495f6adb 100644 --- a/crates/voicevox_core_python_api/src/lib.rs +++ b/crates/voicevox_core_python_api/src/lib.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, future::Future, path::PathBuf, sync::Arc}; +use std::{fmt::Display, future::Future, num::NonZeroU16, path::PathBuf, sync::Arc}; use easy_ext::ext; use log::debug; @@ -16,8 +16,8 @@ use tokio::{runtime::Runtime, sync::Mutex}; use uuid::Uuid; use voicevox_core::{ AccelerationMode, AccentPhraseModel, AccentPhrasesOptions, AudioQueryModel, AudioQueryOptions, - InitializeOptions, StyleId, SynthesisOptions, TtsOptions, UserDictWord, UserDictWordType, - VoiceModelId, VoiceModelMeta, + InitializeOptions, LoadVoiceModelOptions, StyleId, SynthesisOptions, TtsOptions, UserDictWord, + UserDictWordType, VoiceModelId, VoiceModelMeta, }; static RUNTIME: Lazy = Lazy::new(|| Runtime::new().unwrap()); @@ -160,12 +160,18 @@ impl Synthesizer { #[getter] fn metas<'py>(&self, py: Python<'py>) -> Vec<&'py PyAny> { - to_pydantic_voice_model_meta(RUNTIME.block_on(self.synthesizer.lock()).metas(), py).unwrap() + to_pydantic_voice_model_meta(&RUNTIME.block_on(self.synthesizer.lock()).metas(), py) + .unwrap() } + #[pyo3(signature =( + model, + gpu_num_sessions = LoadVoiceModelOptions::default().gpu_num_sessions, + ))] fn load_voice_model<'py>( &mut self, model: &'py PyAny, + gpu_num_sessions: NonZeroU16, py: Python<'py>, ) -> PyResult<&'py PyAny> { let model: VoiceModel = model.extract()?; @@ -174,7 +180,7 @@ impl Synthesizer { synthesizer .lock() .await - .load_voice_model(&model.model) + .load_voice_model(&model.model, &LoadVoiceModelOptions { gpu_num_sessions }) .await .into_py_result() }) From ce2f8c0b97a6ebcf49ba31fe69ea895ff0bf85b6 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 16:03:14 +0900 Subject: [PATCH 02/16] =?UTF-8?q?`decode-with-gpu`=E3=82=92=E5=89=A5?= =?UTF-8?q?=E3=81=8C=E3=81=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 222 ++++-------------- crates/voicevox_core/Cargo.toml | 5 - .../voicevox_core/benches/decode-with-gpu.rs | 78 ------ 3 files changed, 41 insertions(+), 264 deletions(-) delete mode 100644 crates/voicevox_core/benches/decode-with-gpu.rs diff --git a/Cargo.lock b/Cargo.lock index 3a6091d07..a783bb6d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,12 +101,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - [[package]] name = "anyhow" version = "1.0.65" @@ -358,7 +352,7 @@ version = "0.60.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "062dddbc1ba4aca46de6338e2bf87771414c335f7b2f2036e8f3e9befebf88e6" dependencies = [ - "bitflags 1.3.2", + "bitflags", "cexpr", "clang-sys", "clap 3.2.22", @@ -392,12 +386,6 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" -[[package]] -name = "bitflags" -version = "2.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" - [[package]] name = "block-buffer" version = "0.9.0" @@ -493,12 +481,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1db59621ec70f09c5e9b597b220c7a2b43611f4710dc03ceb8748637775692c" -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - [[package]] name = "cbindgen" version = "0.24.3" @@ -561,33 +543,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" -[[package]] -name = "ciborium" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" - -[[package]] -name = "ciborium-ll" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.2.5" @@ -624,7 +579,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" dependencies = [ "atty", - "bitflags 1.3.2", + "bitflags", "clap_lex 0.2.4", "indexmap 1.9.1", "strsim", @@ -639,7 +594,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b1a0a4208c6c483b952ad35c6eed505fc13b46f08f631b81e828084a9318d74" dependencies = [ "atty", - "bitflags 1.3.2", + "bitflags", "clap_derive", "clap_lex 0.3.0", "once_cell", @@ -730,7 +685,7 @@ version = "0.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "318d6c16e73b3a900eb212ad6a82fc7d298c5ab8184c7a9998646455bc474a16" dependencies = [ - "bitflags 1.3.2", + "bitflags", "concolor-query", "is-terminal", ] @@ -843,44 +798,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap 4.0.10", - "criterion-plot", - "futures", - "is-terminal", - "itertools", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "tokio", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools", -] - [[package]] name = "crossbeam-channel" version = "0.5.6" @@ -1208,6 +1125,17 @@ dependencies = [ "serde", ] +[[package]] +name = "errno" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +dependencies = [ + "errno-dragonfly", + "libc", + "winapi", +] + [[package]] name = "errno" version = "0.3.1" @@ -1518,12 +1446,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - [[package]] name = "hashbrown" version = "0.12.3" @@ -1841,13 +1763,14 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "is-terminal" -version = "0.4.9" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +checksum = "28dfb6c8100ccc63462345b67d1bbc3679177c75ee4bf59bf29c8b1d110b8189" dependencies = [ - "hermit-abi 0.3.2", - "rustix 0.38.4", - "windows-sys 0.48.0", + "hermit-abi 0.2.6", + "io-lifetimes", + "rustix 0.36.7", + "windows-sys 0.42.0", ] [[package]] @@ -1943,9 +1866,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.142" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" [[package]] name = "libloading" @@ -2007,15 +1930,15 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.3.8" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "lock_api" @@ -2353,12 +2276,6 @@ dependencies = [ "zip", ] -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - [[package]] name = "opaque-debug" version = "0.3.0" @@ -2557,34 +2474,6 @@ version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d7ddaed09e0eb771a79ab0fd64609ba0afb0a8366421957936ad14cbd13630" -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - [[package]] name = "polling" version = "2.3.0" @@ -2919,7 +2808,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] @@ -2928,7 +2817,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] @@ -3069,28 +2958,29 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.19" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +checksum = "d4fdebc4b395b7fbb9ab11e462e20ed9051e7b16e42d24042c776eca0ac81b03" dependencies = [ - "bitflags 1.3.2", - "errno", + "bitflags", + "errno 0.2.8", "io-lifetimes", "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", + "linux-raw-sys 0.1.4", + "windows-sys 0.42.0", ] [[package]] name = "rustix" -version = "0.38.4" +version = "0.37.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" +checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" dependencies = [ - "bitflags 2.3.3", - "errno", + "bitflags", + "errno 0.3.1", + "io-lifetimes", "libc", - "linux-raw-sys 0.4.3", + "linux-raw-sys 0.3.8", "windows-sys 0.48.0", ] @@ -3127,15 +3017,6 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - [[package]] name = "schannel" version = "0.1.20" @@ -3732,16 +3613,6 @@ dependencies = [ "syn 1.0.102", ] -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -4111,7 +3982,6 @@ dependencies = [ "async_zip", "cfg-if", "const-default", - "criterion", "derive-getters", "derive-new", "duplicate", @@ -4215,16 +4085,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" -[[package]] -name = "walkdir" -version = "2.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" -dependencies = [ - "same-file", - "winapi-util", -] - [[package]] name = "want" version = "0.3.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index a6f746d22..9e5febe22 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -4,10 +4,6 @@ version.workspace = true edition.workspace = true publish.workspace = true -[[bench]] -name = "decode-with-gpu" -harness = false - [features] default = [] directml = ["onnxruntime/directml"] @@ -47,7 +43,6 @@ git = "https://github.com/VOICEVOX/open_jtalk-rs.git" rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5" [dev-dependencies] -criterion = { version = "0.5.1", features = ["async_tokio"] } flate2 = "1.0.24" heck = "0.4.0" pretty_assertions = "1.3.0" diff --git a/crates/voicevox_core/benches/decode-with-gpu.rs b/crates/voicevox_core/benches/decode-with-gpu.rs deleted file mode 100644 index 85fe1952b..000000000 --- a/crates/voicevox_core/benches/decode-with-gpu.rs +++ /dev/null @@ -1,78 +0,0 @@ -use std::{num::NonZeroU16, sync::Arc}; - -use criterion::{criterion_group, criterion_main, Criterion}; -use test_util::OPEN_JTALK_DIC_DIR; -use tokio::{join, runtime::Runtime}; -use voicevox_core::{ - AccelerationMode, AudioQueryModel, InitializeOptions, LoadVoiceModelOptions, OpenJtalk, - StyleId, SynthesisOptions, Synthesizer, VoiceModel, -}; - -criterion_main!(benches); -criterion_group!(benches, benchmark); - -fn benchmark(c: &mut Criterion) { - let (synthesizer, aq) = &Runtime::new().unwrap().block_on(setup()).unwrap(); - - let decode = || async { - synthesizer - .synthesis( - aq, - StyleId::new(0), - &SynthesisOptions { - enable_interrogative_upspeak: true, - }, - ) - .await - .unwrap() - }; - - c.bench_function("decode_parallel", |b| { - b.to_async(Runtime::new().unwrap()) - .iter(|| async { join!(decode(), decode(), decode(), decode()) }) - }); - - c.bench_function("decode_sequential", |b| { - b.to_async(Runtime::new().unwrap()).iter(|| async { - for _ in 0..4 { - decode().await; - } - }) - }); -} - -async fn setup() -> voicevox_core::Result<(Synthesizer, AudioQueryModel)> { - let syntesizer = Synthesizer::new_with_initialize( - Arc::new(OpenJtalk::new_with_initialize(OPEN_JTALK_DIC_DIR).unwrap()), - &InitializeOptions { - acceleration_mode: AccelerationMode::Gpu, - cpu_num_threads: 4, - ..Default::default() - }, - ) - .await?; - - let model = &VoiceModel::from_path(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../../model/sample.vvm", - )) - .await?; - syntesizer - .load_voice_model( - model, - &LoadVoiceModelOptions { - gpu_num_sessions: NonZeroU16::new(4).unwrap(), - }, - ) - .await?; - - let aq = syntesizer - .audio_query( - "寿限無寿限無五劫の擦り切れ海砂利水魚の水行末雲来末", - StyleId::new(0), - &Default::default(), - ) - .await?; - - Ok((syntesizer, aq)) -} From 47a77ee4f156a52c8a8bf6cceeb92f5282ad5b93 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 16:13:16 +0900 Subject: [PATCH 03/16] =?UTF-8?q?`gpu=5Fnum=5Fsessions`=E3=82=AA=E3=83=97?= =?UTF-8?q?=E3=82=B7=E3=83=A7=E3=83=B3=E3=82=92=E5=89=A5=E3=81=8C=E3=81=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/inference_core.rs | 10 ++--- crates/voicevox_core/src/status.rs | 45 +++++++------------ crates/voicevox_core/src/voice_synthesizer.rs | 33 ++++---------- crates/voicevox_core_c_api/src/c_impls.rs | 14 ++---- .../src/compatible_engine.rs | 3 +- crates/voicevox_core_c_api/src/helpers.rs | 21 --------- crates/voicevox_core_c_api/src/lib.rs | 14 +----- crates/voicevox_core_python_api/src/lib.rs | 11 ++--- 8 files changed, 36 insertions(+), 115 deletions(-) diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index a2f6c0ce1..b78eadca6 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,5 +1,3 @@ -use std::num::NonZeroU16; - use self::status::*; use super::*; use onnxruntime::{ndarray, session::NdArray}; @@ -21,9 +19,7 @@ impl InferenceCore { if load_all_models { for model in &VoiceModel::get_all_models().await? { - status - .load_model(model, NonZeroU16::new(1).unwrap()) - .await?; + status.load_model(model).await?; } } Ok(Self { status }) @@ -44,8 +40,8 @@ impl InferenceCore { } } - pub async fn load_model(&self, model: &VoiceModel, gpu_num_sessions: NonZeroU16) -> Result<()> { - self.status.load_model(model, gpu_num_sessions).await + pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { + self.status.load_model(model).await } pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 3388d2375..adbf9ae47 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -7,7 +7,7 @@ use onnxruntime::{ session::{NdArray, Session}, GraphOptimizationLevel, LoggingLevel, }; -use std::{collections::VecDeque, iter, num::NonZeroU16, sync::Arc}; +use std::sync::Arc; use std::{env, path::Path}; use tracing::error; @@ -60,7 +60,7 @@ impl Status { } } - pub async fn load_model(&self, model: &VoiceModel, gpu_num_sessions: NonZeroU16) -> Result<()> { + pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { self.loaded_models .lock() .unwrap() @@ -78,15 +78,11 @@ impl Status { &self.light_session_options, model.path(), )?; - let decode_models = iter::repeat_with(|| { - self.new_session( - models.decode_model(), - &self.heavy_session_options, - model.path(), - ) - }) - .take(gpu_num_sessions.get().into()) - .collect::, _>>()?; + let decode_model = self.new_session( + models.decode_model(), + &self.heavy_session_options, + model.path(), + )?; self.loaded_models.lock().unwrap().insert( model, @@ -94,9 +90,9 @@ impl Status { predict_duration: predict_duration_session.into(), predict_intonation: predict_intonation_session.into(), }, - decode_models.into_iter().map(|decode| HeavyOrtSession { - decode: decode.into(), - }), + HeavyOrtSession { + decode: decode_model.into(), + }, )?; Ok(()) } @@ -321,7 +317,7 @@ impl LoadedModels { &mut self, model: &VoiceModel, light_sessions: LightOrtSessions, - heavy_session: impl IntoIterator, + heavy_session: HeavyOrtSession, ) -> Result<()> { self.ensure_not_contains(model)?; @@ -366,16 +362,15 @@ impl LoadedModels { } struct SessionSet { - // 不変条件: `self.heavy.len() >= 1` light: Arc>, - heavy: VecDeque>>, + heavy: Arc>, } impl SessionSet { - fn new(light: LightOrtSessions, heavy: impl IntoIterator) -> Self { + fn new(light: LightOrtSessions, heavy: HeavyOrtSession) -> Self { Self { light: Arc::new(light.into()), - heavy: heavy.into_iter().map(Into::into).map(Arc::new).collect(), + heavy: Arc::new(heavy.into()), } } @@ -383,12 +378,8 @@ impl SessionSet { self.light.clone() } - /// # Panics - /// - /// `self.heavy`が空のときパニックする。 fn get_heavy(&mut self) -> Arc> { - self.heavy.rotate_left(1); - self.heavy.back().unwrap().clone() + self.heavy.clone() } } @@ -468,9 +459,7 @@ mod tests { #[tokio::test] async fn status_load_model_works() { let status = Status::new(false, 0); - let result = status - .load_model(&open_default_vvm_file().await, NonZeroU16::new(1).unwrap()) - .await; + let result = status.load_model(&open_default_vvm_file().await).await; assert_debug_fmt_eq!(Ok(()), result); assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); } @@ -484,7 +473,7 @@ mod tests { !status.is_loaded_model(vvm.id()), "model should not be loaded" ); - let result = status.load_model(&vvm, NonZeroU16::new(1).unwrap()).await; + let result = status.load_model(&vvm).await; assert_debug_fmt_eq!(Ok(()), result); assert!(status.is_loaded_model(vvm.id()), "model should be loaded"); } diff --git a/crates/voicevox_core/src/voice_synthesizer.rs b/crates/voicevox_core/src/voice_synthesizer.rs index 01b2e4790..377022806 100644 --- a/crates/voicevox_core/src/voice_synthesizer.rs +++ b/crates/voicevox_core/src/voice_synthesizer.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroU16, sync::Arc}; +use std::sync::Arc; use const_default::ConstDefault; use duplicate::duplicate_item; @@ -7,18 +7,6 @@ use crate::engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, Synth use super::*; -pub struct LoadVoiceModelOptions { - pub gpu_num_sessions: NonZeroU16, -} - -impl ConstDefault for LoadVoiceModelOptions { - const DEFAULT: Self = Self { - // SAFETY: 1 ≠ 0 - #[allow(unsafe_code)] - gpu_num_sessions: unsafe { NonZeroU16::new_unchecked(1) }, - }; -} - pub struct SynthesisOptions { pub enable_interrogative_upspeak: bool, } @@ -91,7 +79,6 @@ pub struct InitializeOptions { #[duplicate_item( T; - [ LoadVoiceModelOptions ]; [ AccentPhrasesOptions ]; [ AudioQueryOptions ]; [ TtsOptions ]; @@ -154,14 +141,10 @@ impl Synthesizer { } /// 音声モデルを読み込む - pub async fn load_voice_model( - &self, - model: &VoiceModel, - options: &LoadVoiceModelOptions, - ) -> Result<()> { + pub async fn load_voice_model(&self, model: &VoiceModel) -> Result<()> { self.synthesis_engine .inference_core() - .load_model(model, options.gpu_num_sessions) + .load_model(model) .await?; Ok(()) } @@ -406,7 +389,7 @@ mod tests { .unwrap(); let result = syntesizer - .load_voice_model(&open_default_vvm_file().await, &Default::default()) + .load_voice_model(&open_default_vvm_file().await) .await; assert_debug_fmt_eq!( @@ -450,7 +433,7 @@ mod tests { "expected is_model_loaded to return false, but got true", ); syntesizer - .load_voice_model(&open_default_vvm_file().await, &Default::default()) + .load_voice_model(&open_default_vvm_file().await) .await .unwrap(); @@ -476,7 +459,7 @@ mod tests { .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await, &Default::default()) + .load_voice_model(&open_default_vvm_file().await) .await .unwrap(); @@ -507,7 +490,7 @@ mod tests { .await .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await, &Default::default()) + .load_voice_model(&open_default_vvm_file().await) .await .unwrap(); @@ -549,7 +532,7 @@ mod tests { .await .unwrap(); syntesizer - .load_voice_model(&open_default_vvm_file().await, &Default::default()) + .load_voice_model(&open_default_vvm_file().await) .await .unwrap(); diff --git a/crates/voicevox_core_c_api/src/c_impls.rs b/crates/voicevox_core_c_api/src/c_impls.rs index 99314d7d2..596d8bf75 100644 --- a/crates/voicevox_core_c_api/src/c_impls.rs +++ b/crates/voicevox_core_c_api/src/c_impls.rs @@ -6,9 +6,7 @@ use std::{ use voicevox_core::{InitializeOptions, OpenJtalk, Result, Synthesizer, VoiceModel, VoiceModelId}; -use crate::{ - CApiResult, OpenJtalkRc, VoicevoxLoadVoiceModelOptions, VoicevoxSynthesizer, VoicevoxVoiceModel, -}; +use crate::{CApiResult, OpenJtalkRc, VoicevoxSynthesizer, VoicevoxVoiceModel}; impl OpenJtalkRc { pub(crate) fn new_with_initialize(open_jtalk_dic_dir: impl AsRef) -> Result { @@ -30,14 +28,8 @@ impl VoicevoxSynthesizer { }) } - pub(crate) async fn load_voice_model( - &self, - model: &VoiceModel, - options: VoicevoxLoadVoiceModelOptions, - ) -> CApiResult<()> { - self.synthesizer - .load_voice_model(model, &options.try_into()?) - .await?; + pub(crate) async fn load_voice_model(&self, model: &VoiceModel) -> CApiResult<()> { + self.synthesizer.load_voice_model(model).await?; let metas = &self.synthesizer.metas(); *self.metas_cstring.lock().unwrap() = CString::new(serde_json::to_string(metas).unwrap()).unwrap(); diff --git a/crates/voicevox_core_c_api/src/compatible_engine.rs b/crates/voicevox_core_c_api/src/compatible_engine.rs index 3f455e7fc..e84040abc 100644 --- a/crates/voicevox_core_c_api/src/compatible_engine.rs +++ b/crates/voicevox_core_c_api/src/compatible_engine.rs @@ -100,8 +100,7 @@ pub extern "C" fn load_model(style_id: i64) -> bool { if let Some(model_id) = model_set.style_model_map.get(&style_id) { let vvm = model_set.model_map.get(model_id).unwrap(); let synthesizer = &mut *lock_synthesizer(); - let result = RUNTIME - .block_on(ensure_initialized!(synthesizer).load_voice_model(vvm, &Default::default())); + let result = RUNTIME.block_on(ensure_initialized!(synthesizer).load_voice_model(vvm)); if let Some(err) = result.err() { set_message(&format!("{err}")); false diff --git a/crates/voicevox_core_c_api/src/helpers.rs b/crates/voicevox_core_c_api/src/helpers.rs index cc21301c6..dc798a1c8 100644 --- a/crates/voicevox_core_c_api/src/helpers.rs +++ b/crates/voicevox_core_c_api/src/helpers.rs @@ -81,27 +81,6 @@ pub(crate) fn ensure_utf8(s: &CStr) -> CApiResult<&str> { s.to_str().map_err(|_| CApiError::InvalidUtf8Input) } -impl ConstDefault for VoicevoxLoadVoiceModelOptions { - const DEFAULT: Self = { - let options = voicevox_core::LoadVoiceModelOptions::DEFAULT; - Self { - gpu_num_sessions: options.gpu_num_sessions.get(), - } - }; -} - -impl TryFrom for voicevox_core::LoadVoiceModelOptions { - type Error = CApiError; - - fn try_from(options: VoicevoxLoadVoiceModelOptions) -> std::result::Result { - let gpu_num_sessions = options - .gpu_num_sessions - .try_into() - .unwrap_or_else(|_| todo!()); - Ok(Self { gpu_num_sessions }) - } -} - impl ConstDefault for VoicevoxAudioQueryOptions { const DEFAULT: Self = { let options = voicevox_core::AudioQueryOptions::DEFAULT; diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index 948dac9e7..50b1d8340 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -287,15 +287,6 @@ pub extern "C" fn voicevox_synthesizer_delete(synthesizer: Box VoicevoxResultCode { - into_result_code_with_error( - RUNTIME.block_on(synthesizer.load_voice_model(model.model(), options)), - ) + into_result_code_with_error(RUNTIME.block_on(synthesizer.load_voice_model(model.model()))) } /// モデルの読み込みを解除する diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs index bb6475107..c70b26363 100644 --- a/crates/voicevox_core_python_api/src/lib.rs +++ b/crates/voicevox_core_python_api/src/lib.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroU16, sync::Arc}; +use std::sync::Arc; mod convert; use convert::*; @@ -15,7 +15,7 @@ use tokio::{runtime::Runtime, sync::Mutex}; use uuid::Uuid; use voicevox_core::{ AccelerationMode, AccentPhrasesOptions, AudioQueryModel, AudioQueryOptions, InitializeOptions, - LoadVoiceModelOptions, StyleId, SynthesisOptions, TtsOptions, UserDictWord, VoiceModelId, + StyleId, SynthesisOptions, TtsOptions, UserDictWord, VoiceModelId, }; static RUNTIME: Lazy = Lazy::new(|| Runtime::new().unwrap()); @@ -165,14 +165,9 @@ impl Synthesizer { .unwrap() } - #[pyo3(signature =( - model, - gpu_num_sessions = LoadVoiceModelOptions::default().gpu_num_sessions, - ))] fn load_voice_model<'py>( &mut self, model: &'py PyAny, - gpu_num_sessions: NonZeroU16, py: Python<'py>, ) -> PyResult<&'py PyAny> { let model: VoiceModel = model.extract()?; @@ -181,7 +176,7 @@ impl Synthesizer { synthesizer .lock() .await - .load_voice_model(&model.model, &LoadVoiceModelOptions { gpu_num_sessions }) + .load_voice_model(&model.model) .await .into_py_result() }) From ee2dc76558ffdd9c45c06e17b1da12bb6b440248 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 16:40:38 +0900 Subject: [PATCH 04/16] =?UTF-8?q?`predict`=E3=82=82=E5=88=A5=E3=80=85?= =?UTF-8?q?=E3=81=AB`Mutex`=E3=81=AB=E5=8C=85=E3=82=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/VOICEVOX/voicevox_core/pull/553#discussion_r1278084797 --- crates/voicevox_core/src/status.rs | 140 ++++++++++++++--------------- 1 file changed, 67 insertions(+), 73 deletions(-) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index adbf9ae47..f27ed6d13 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -86,13 +86,9 @@ impl Status { self.loaded_models.lock().unwrap().insert( model, - LightOrtSessions { - predict_duration: predict_duration_session.into(), - predict_intonation: predict_intonation_session.into(), - }, - HeavyOrtSession { - decode: decode_model.into(), - }, + predict_duration_session, + predict_intonation_session, + decode_model, )?; Ok(()) } @@ -169,19 +165,16 @@ impl Status { mut phoneme_vector_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let light_sessions = self + let predict_duration = self .loaded_models .lock() .unwrap() - .light_sessions(style_id)?; + .predict_duration(style_id)?; tokio::task::spawn_blocking(move || { - let LightOrtSessions { - predict_duration, .. - } = &mut *light_sessions.lock().unwrap(); + let mut predict_duration = predict_duration.lock().unwrap(); let output_tensors = predict_duration - .get_mut() .run(vec![&mut phoneme_vector_array, &mut speaker_id_array]) .map_err(|_| Error::InferenceFailed)?; Ok(output_tensors[0].as_slice().unwrap().to_owned()) @@ -203,19 +196,16 @@ impl Status { mut end_accent_phrase_vector_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let light_sessions = self + let predict_intonation = self .loaded_models .lock() .unwrap() - .light_sessions(style_id)?; + .predict_intonation(style_id)?; tokio::task::spawn_blocking(move || { - let LightOrtSessions { - predict_intonation, .. - } = &mut *light_sessions.lock().unwrap(); + let mut predict_intonation = predict_intonation.lock().unwrap(); let output_tensors = predict_intonation - .get_mut() .run(vec![ &mut length_array, &mut vowel_phoneme_vector_array, @@ -240,13 +230,12 @@ impl Status { mut phoneme_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let heavy_session = self.loaded_models.lock().unwrap().heavy_session(style_id)?; + let decode = self.loaded_models.lock().unwrap().decode(style_id)?; tokio::task::spawn_blocking(move || { - let HeavyOrtSession { decode } = &mut *heavy_session.lock().unwrap(); + let mut decode = decode.lock().unwrap(); let output_tensors = decode - .get_mut() .run(vec![ &mut f0_array, &mut phoneme_array, @@ -277,20 +266,41 @@ impl LoadedModels { .collect() } - fn light_sessions( - &mut self, + fn predict_duration( + &self, style_id: StyleId, - ) -> Result>> { - let LoadedModel { session_set, .. } = self.find_loaded_voice_model(style_id)?; - Ok(session_set.get_light()) + ) -> Result>>>> { + let LoadedModel { + session_set: SessionSet { + predict_duration, .. + }, + .. + } = self.find_loaded_voice_model(style_id)?; + Ok(predict_duration.clone()) } - fn heavy_session( - &mut self, + fn predict_intonation( + &self, style_id: StyleId, - ) -> Result>> { - let LoadedModel { session_set, .. } = self.find_loaded_voice_model(style_id)?; - Ok(session_set.get_heavy()) + ) -> Result>>>> { + let LoadedModel { + session_set: SessionSet { + predict_intonation, .. + }, + .. + } = self.find_loaded_voice_model(style_id)?; + Ok(predict_intonation.clone()) + } + + fn decode( + &self, + style_id: StyleId, + ) -> Result>>>> { + let LoadedModel { + session_set: SessionSet { decode, .. }, + .. + } = self.find_loaded_voice_model(style_id)?; + Ok(decode.clone()) } fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { @@ -316,8 +326,9 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - light_sessions: LightOrtSessions, - heavy_session: HeavyOrtSession, + predict_duration: Session<'static>, + predict_intonation: Session<'static>, + decode: Session<'static>, ) -> Result<()> { self.ensure_not_contains(model)?; @@ -325,7 +336,11 @@ impl LoadedModels { model.id().clone(), LoadedModel { metas: model.metas().clone(), - session_set: SessionSet::new(light_sessions, heavy_session), + session_set: SessionSet { + predict_duration: Arc::new(std::sync::Mutex::new(predict_duration.into())), + predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation.into())), + decode: Arc::new(std::sync::Mutex::new(decode.into())), + }, }, ); assert!(prev.is_none()); @@ -341,9 +356,9 @@ impl LoadedModels { Ok(()) } - fn find_loaded_voice_model(&mut self, style_id: StyleId) -> Result<&mut LoadedModel> { + fn find_loaded_voice_model(&self, style_id: StyleId) -> Result<&LoadedModel> { self.0 - .values_mut() + .values() .find(|LoadedModel { metas, .. }| { metas .iter() @@ -362,34 +377,9 @@ impl LoadedModels { } struct SessionSet { - light: Arc>, - heavy: Arc>, -} - -impl SessionSet { - fn new(light: LightOrtSessions, heavy: HeavyOrtSession) -> Self { - Self { - light: Arc::new(light.into()), - heavy: Arc::new(heavy.into()), - } - } - - fn get_light(&self) -> Arc> { - self.light.clone() - } - - fn get_heavy(&mut self) -> Arc> { - self.heavy.clone() - } -} - -struct LightOrtSessions { - predict_duration: AssertSend>, - predict_intonation: AssertSend>, -} - -struct HeavyOrtSession { - decode: AssertSend>, + predict_duration: Arc>>>, + predict_intonation: Arc>>>, + decode: Arc>>>, } // FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 @@ -398,28 +388,32 @@ struct HeavyOrtSession { use self::assert_send::AssertSend; mod assert_send { + use std::ops::{Deref, DerefMut}; + use onnxruntime::session::Session; pub(super) struct AssertSend(T); - impl AssertSend { - pub(super) fn get_mut(&mut self) -> &mut T { - &mut self.0 - } - } - impl From> for AssertSend> { fn from(session: Session<'static>) -> Self { Self(session) } } - impl AsRef for AssertSend { - fn as_ref(&self) -> &T { + impl Deref for AssertSend { + type Target = T; + + fn deref(&self) -> &Self::Target { &self.0 } } + impl DerefMut for AssertSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + // SAFETY: `Session` is probably "send"able. #[allow(unsafe_code)] unsafe impl Send for AssertSend {} From cf92b69d6226a663a613678c6fe81e23697d35d2 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 16:52:11 +0900 Subject: [PATCH 05/16] =?UTF-8?q?`LoadedModels`=E3=81=AEdoc=E3=82=92?= =?UTF-8?q?=E6=9B=B8=E3=81=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/VOICEVOX/voicevox_core/pull/553#discussion_r1277932897 --- crates/voicevox_core/src/status.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index f27ed6d13..26a1abd6e 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -249,6 +249,9 @@ impl Status { } } +/// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 +/// +/// この構造体のメソッドは、すべて一瞬で完了すべきである。 #[derive(Default)] struct LoadedModels(BTreeMap); From b893fed34f69b3a70dd0a82919e1481e2adf4a30 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 22:32:19 +0900 Subject: [PATCH 06/16] `cargo xtask update-c-header` --- crates/voicevox_core_c_api/include/voicevox_core.h | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index c36144b03..d2747f309 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -220,10 +220,6 @@ typedef struct VoicevoxInitializeOptions { bool load_all_models; } VoicevoxInitializeOptions; -typedef struct VoicevoxLoadVoiceModelOptions { - uint16_t gpu_num_sessions; -} VoicevoxLoadVoiceModelOptions; - /** * スタイルID */ @@ -307,8 +303,6 @@ extern const struct VoicevoxInitializeOptions voicevox_default_initialize_option extern const char *voicevox_version; -extern const struct VoicevoxLoadVoiceModelOptions voicevox_default_load_voice_model_options; - extern const struct VoicevoxAudioQueryOptions voicevox_default_audio_query_options; extern const struct VoicevoxAccentPhrasesOptions voicevox_default_accent_phrases_options; @@ -454,8 +448,7 @@ void voicevox_synthesizer_delete(struct VoicevoxSynthesizer *synthesizer); __declspec(dllimport) #endif VoicevoxResultCode voicevox_synthesizer_load_voice_model(const struct VoicevoxSynthesizer *synthesizer, - const struct VoicevoxVoiceModel *model, - struct VoicevoxLoadVoiceModelOptions options); + const struct VoicevoxVoiceModel *model); /** * モデルの読み込みを解除する From 780d8807580a05679b0af6ec47c61d0ae65b58ad Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 29 Jul 2023 22:33:53 +0900 Subject: [PATCH 07/16] `git restore -s main -- crates/voicevox_core_c_api/tests/` --- crates/voicevox_core_c_api/tests/e2e/symbols.rs | 17 ++--------------- .../tests/e2e/testcases/simple_tts.rs | 7 +------ .../tests/e2e/testcases/tts_via_audio_query.rs | 7 +------ .../tests/e2e/testcases/user_dict_load.rs | 7 +------ 4 files changed, 5 insertions(+), 33 deletions(-) diff --git a/crates/voicevox_core_c_api/tests/e2e/symbols.rs b/crates/voicevox_core_c_api/tests/e2e/symbols.rs index ff1e300be..810fbf91b 100644 --- a/crates/voicevox_core_c_api/tests/e2e/symbols.rs +++ b/crates/voicevox_core_c_api/tests/e2e/symbols.rs @@ -8,8 +8,6 @@ use voicevox_core::result_code::VoicevoxResultCode; pub(crate) struct Symbols<'lib> { pub(crate) voicevox_version: Symbol<'lib, &'lib &'lib c_char>, pub(crate) voicevox_default_initialize_options: Symbol<'lib, &'lib VoicevoxInitializeOptions>, - pub(crate) voicevox_default_load_voice_model_options: - Symbol<'lib, &'lib VoicevoxLoadVoiceModelOptions>, pub(crate) voicevox_default_audio_query_options: Symbol<'lib, &'lib VoicevoxAudioQueryOptions>, pub(crate) voicevox_default_synthesis_options: Symbol<'lib, &'lib VoicevoxSynthesisOptions>, pub(crate) voicevox_default_tts_options: Symbol<'lib, &'lib VoicevoxTtsOptions>, @@ -45,17 +43,13 @@ pub(crate) struct Symbols<'lib> { pub(crate) voicevox_synthesizer_load_voice_model: Symbol< 'lib, unsafe extern "C" fn( - *const VoicevoxSynthesizer, + *mut VoicevoxSynthesizer, *const VoicevoxVoiceModel, - VoicevoxLoadVoiceModelOptions, ) -> VoicevoxResultCode, >, pub(crate) voicevox_synthesizer_unload_voice_model: Symbol< 'lib, - unsafe extern "C" fn( - *const VoicevoxSynthesizer, - VoicevoxVoiceModelId, - ) -> VoicevoxResultCode, + unsafe extern "C" fn(*mut VoicevoxSynthesizer, VoicevoxVoiceModelId) -> VoicevoxResultCode, >, pub(crate) voicevox_synthesizer_is_gpu_mode: Symbol<'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer) -> bool>, @@ -192,7 +186,6 @@ impl<'lib> Symbols<'lib> { Ok(new!( voicevox_version, voicevox_default_initialize_options, - voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_default_synthesis_options, voicevox_default_tts_options, @@ -260,12 +253,6 @@ pub(crate) struct VoicevoxInitializeOptions { pub(crate) _load_all_models: bool, } -#[derive(Clone, Copy)] -#[repr(C)] -pub(crate) struct VoicevoxLoadVoiceModelOptions { - pub(crate) gpu_num_sessions: u16, -} - #[derive(Clone, Copy)] #[repr(C)] pub(crate) struct VoicevoxAudioQueryOptions { diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs index 55bfa9558..2e691100f 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/simple_tts.rs @@ -37,7 +37,6 @@ impl assert_cdylib::TestCase for TestCase { unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { let Symbols { voicevox_default_initialize_options, - voicevox_default_load_voice_model_options, voicevox_default_tts_options, voicevox_open_jtalk_rc_new, voicevox_open_jtalk_rc_delete, @@ -83,11 +82,7 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model( - synthesizer, - model, - **voicevox_default_load_voice_model_options, - )); + assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); let (wav_length, wav) = { let mut wav_length = MaybeUninit::uninit(); diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs index fa4ca05dc..024d68b0f 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/tts_via_audio_query.rs @@ -37,7 +37,6 @@ impl assert_cdylib::TestCase for TestCase { unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { let Symbols { voicevox_default_initialize_options, - voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_default_synthesis_options, voicevox_open_jtalk_rc_new, @@ -86,11 +85,7 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model( - synthesizer, - model, - **voicevox_default_load_voice_model_options, - )); + assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); let audio_query = { let mut audio_query = MaybeUninit::uninit(); diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs index a9a545c8e..b9284c153 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs @@ -38,7 +38,6 @@ impl assert_cdylib::TestCase for TestCase { voicevox_user_dict_add_word, voicevox_user_dict_delete, voicevox_default_initialize_options, - voicevox_default_load_voice_model_options, voicevox_default_audio_query_options, voicevox_open_jtalk_rc_new, voicevox_open_jtalk_rc_use_user_dict, @@ -101,11 +100,7 @@ impl assert_cdylib::TestCase for TestCase { synthesizer.assume_init() }; - assert_ok(voicevox_synthesizer_load_voice_model( - synthesizer, - model, - **voicevox_default_load_voice_model_options, - )); + assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); let mut audio_query_without_dict = std::ptr::null_mut(); assert_ok(voicevox_synthesizer_audio_query( From c460e48db62c510665b9572617434b2f2e888c25 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 12 Aug 2023 01:42:02 +0900 Subject: [PATCH 08/16] =?UTF-8?q?`synthesizer=5Fget=5Fmetas=5Fjson`=20?= =?UTF-8?q?=E2=86=92=20`synthesizer=5Fcreate=5Fmetas=5Fjson`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../include/voicevox_core.h | 6 ++++-- crates/voicevox_core_c_api/src/c_impls.rs | 18 ++++-------------- crates/voicevox_core_c_api/src/lib.rs | 10 ++++++---- .../voicevox_core_c_api/tests/e2e/symbols.rs | 4 ++-- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index 1e06d3e90..acea64ed3 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -623,19 +623,20 @@ bool voicevox_synthesizer_is_loaded_voice_model(const struct VoicevoxSynthesizer /** * 今読み込んでいる音声モデルのメタ情報を、JSONで取得する。 * + * JSONの解放は ::voicevox_json_free で行う。 + * * @param [in] synthesizer 音声シンセサイザ * * @return メタ情報のJSON文字列 * * \safety{ * - `synthesizer`は ::voicevox_synthesizer_new_with_initialize で得たものでなければならず、また ::voicevox_synthesizer_delete で解放されていてはいけない。 - * - 戻り値の文字列の生存期間(_lifetime_)は次にこの関数が呼ばれるか、`synthesizer`が破棄されるまでである。この生存期間を越えて文字列にアクセスしてはならない。 * } */ #ifdef _WIN32 __declspec(dllimport) #endif -const char *voicevox_synthesizer_get_metas_json(const struct VoicevoxSynthesizer *synthesizer); +const char *voicevox_synthesizer_create_metas_json(const struct VoicevoxSynthesizer *synthesizer); /** * このライブラリで利用可能なデバイスの情報を、JSONで取得する。 @@ -909,6 +910,7 @@ VoicevoxResultCode voicevox_synthesizer_tts(const struct VoicevoxSynthesizer *sy * \safety{ * - `json`は以下のAPIで得られたポインタでなくてはいけない。 * - ::voicevox_create_supported_devices_json + * - ::voicevox_synthesizer_create_metas_json * - ::voicevox_synthesizer_audio_query * - ::voicevox_synthesizer_create_accent_phrases * - ::voicevox_synthesizer_replace_mora_data diff --git a/crates/voicevox_core_c_api/src/c_impls.rs b/crates/voicevox_core_c_api/src/c_impls.rs index 596d8bf75..410a46260 100644 --- a/crates/voicevox_core_c_api/src/c_impls.rs +++ b/crates/voicevox_core_c_api/src/c_impls.rs @@ -1,8 +1,4 @@ -use std::{ - ffi::{c_char, CString}, - path::Path, - sync::Arc, -}; +use std::{ffi::CString, path::Path, sync::Arc}; use voicevox_core::{InitializeOptions, OpenJtalk, Result, Synthesizer, VoiceModel, VoiceModelId}; @@ -24,28 +20,22 @@ impl VoicevoxSynthesizer { Ok(Self { synthesizer: Synthesizer::new_with_initialize(open_jtalk.open_jtalk.clone(), options) .await?, - metas_cstring: Default::default(), }) } pub(crate) async fn load_voice_model(&self, model: &VoiceModel) -> CApiResult<()> { self.synthesizer.load_voice_model(model).await?; - let metas = &self.synthesizer.metas(); - *self.metas_cstring.lock().unwrap() = - CString::new(serde_json::to_string(metas).unwrap()).unwrap(); Ok(()) } pub(crate) fn unload_voice_model(&self, model_id: &VoiceModelId) -> Result<()> { self.synthesizer.unload_voice_model(model_id)?; - let metas = &self.synthesizer.metas(); - *self.metas_cstring.lock().unwrap() = - CString::new(serde_json::to_string(metas).unwrap()).unwrap(); Ok(()) } - pub(crate) fn metas_ptr(&self) -> *const c_char { - self.metas_cstring.lock().unwrap().as_ptr() + pub(crate) fn metas(&self) -> CString { + let metas = &self.synthesizer.metas(); + CString::new(serde_json::to_string(metas).unwrap()).unwrap() } } diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index 1a55ee8bf..a01047789 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -314,7 +314,6 @@ pub extern "C" fn voicevox_voice_model_delete(model: Box) { #[derive(Getters)] pub struct VoicevoxSynthesizer { synthesizer: Synthesizer, - metas_cstring: std::sync::Mutex, } /// ::VoicevoxSynthesizer を構築(_construct_)する。 @@ -442,19 +441,21 @@ pub unsafe extern "C" fn voicevox_synthesizer_is_loaded_voice_model( /// 今読み込んでいる音声モデルのメタ情報を、JSONで取得する。 /// +/// JSONの解放は ::voicevox_json_free で行う。 +/// /// @param [in] synthesizer 音声シンセサイザ /// /// @return メタ情報のJSON文字列 /// /// \safety{ /// - `synthesizer`は ::voicevox_synthesizer_new_with_initialize で得たものでなければならず、また ::voicevox_synthesizer_delete で解放されていてはいけない。 -/// - 戻り値の文字列の生存期間(_lifetime_)は次にこの関数が呼ばれるか、`synthesizer`が破棄されるまでである。この生存期間を越えて文字列にアクセスしてはならない。 /// } #[no_mangle] -pub extern "C" fn voicevox_synthesizer_get_metas_json( +pub extern "C" fn voicevox_synthesizer_create_metas_json( synthesizer: &VoicevoxSynthesizer, ) -> *const c_char { - synthesizer.metas_ptr() + let metas = synthesizer.metas(); + C_STRING_DROP_CHECKER.whitelist(metas).into_raw() } /// このライブラリで利用可能なデバイスの情報を、JSONで取得する。 @@ -878,6 +879,7 @@ pub unsafe extern "C" fn voicevox_synthesizer_tts( /// \safety{ /// - `json`は以下のAPIで得られたポインタでなくてはいけない。 /// - ::voicevox_create_supported_devices_json +/// - ::voicevox_synthesizer_create_metas_json /// - ::voicevox_synthesizer_audio_query /// - ::voicevox_synthesizer_create_accent_phrases /// - ::voicevox_synthesizer_replace_mora_data diff --git a/crates/voicevox_core_c_api/tests/e2e/symbols.rs b/crates/voicevox_core_c_api/tests/e2e/symbols.rs index 810fbf91b..7eb0b5650 100644 --- a/crates/voicevox_core_c_api/tests/e2e/symbols.rs +++ b/crates/voicevox_core_c_api/tests/e2e/symbols.rs @@ -57,7 +57,7 @@ pub(crate) struct Symbols<'lib> { 'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer, VoicevoxVoiceModelId) -> bool, >, - pub(crate) voicevox_synthesizer_get_metas_json: + pub(crate) voicevox_synthesizer_create_metas_json: Symbol<'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer) -> *const c_char>, pub(crate) voicevox_create_supported_devices_json: Symbol<'lib, unsafe extern "C" fn(*mut *mut c_char) -> VoicevoxResultCode>, @@ -202,7 +202,7 @@ impl<'lib> Symbols<'lib> { voicevox_synthesizer_unload_voice_model, voicevox_synthesizer_is_gpu_mode, voicevox_synthesizer_is_loaded_voice_model, - voicevox_synthesizer_get_metas_json, + voicevox_synthesizer_create_metas_json, voicevox_create_supported_devices_json, voicevox_synthesizer_audio_query, voicevox_synthesizer_synthesis, From 06fef88cb5db8d44dad8e7d5ffcb5c45aed2088c Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 12 Aug 2023 23:13:24 +0900 Subject: [PATCH 09/16] =?UTF-8?q?#575=20=E3=81=A7=E8=BF=BD=E5=8A=A0?= =?UTF-8?q?=E3=81=95=E3=82=8C=E3=81=9F=E3=83=86=E3=82=B9=E3=83=88=E3=82=92?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../synthesizer_new_with_initialize_output_json.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs index 0edd73300..a03bc8377 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs @@ -31,7 +31,8 @@ impl assert_cdylib::TestCase for TestCase { voicevox_open_jtalk_rc_delete, voicevox_synthesizer_new_with_initialize, voicevox_synthesizer_delete, - voicevox_synthesizer_get_metas_json, + voicevox_synthesizer_create_metas_json, + voicevox_json_free, .. } = Symbols::new(lib)?; @@ -60,9 +61,11 @@ impl assert_cdylib::TestCase for TestCase { }; let metas_json = { - let metas_json = - CStr::from_ptr(voicevox_synthesizer_get_metas_json(synthesizer)).to_str()?; - serde_json::to_string_pretty(&metas_json.parse::()?).unwrap() + let raw = voicevox_synthesizer_create_metas_json(synthesizer) as *mut std::ffi::c_char; + let metas_json = &CStr::from_ptr(raw).to_str()?.parse::()?; + let metas_json = serde_json::to_string_pretty(metas_json).unwrap(); + voicevox_json_free(raw); + metas_json }; std::assert_eq!(SNAPSHOTS.metas, metas_json); From 4891a9c08d46d5420c87a29b4944c657517ac9ef Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 12 Aug 2023 23:13:50 +0900 Subject: [PATCH 10/16] =?UTF-8?q?warning=E3=82=92=E8=A7=A3=E6=B6=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit コンフリクト解消が適当だったので同じ宣言が二つできていてwarningが出てい た。 --- crates/voicevox_core_c_api/src/c_impls.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/voicevox_core_c_api/src/c_impls.rs b/crates/voicevox_core_c_api/src/c_impls.rs index a5a87a25b..a891593a4 100644 --- a/crates/voicevox_core_c_api/src/c_impls.rs +++ b/crates/voicevox_core_c_api/src/c_impls.rs @@ -19,10 +19,7 @@ impl VoicevoxSynthesizer { ) -> Result { let synthesizer = Synthesizer::new_with_initialize(open_jtalk.open_jtalk.clone(), options).await?; - Ok(Self { - synthesizer: Synthesizer::new_with_initialize(open_jtalk.open_jtalk.clone(), options) - .await?, - }) + Ok(Self { synthesizer }) } pub(crate) async fn load_voice_model(&self, model: &VoiceModel) -> CApiResult<()> { From e971c2b18c0a821d9cbfaa7957d2c2b0670b40c8 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 12 Aug 2023 23:19:48 +0900 Subject: [PATCH 11/16] =?UTF-8?q?`create=5Fmetas=5Fjson`=E3=81=AE=E8=BF=94?= =?UTF-8?q?=E3=82=8A=E5=80=A4=E3=82=92`*mut=20c=5Fchar`=E3=81=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core_c_api/include/voicevox_core.h | 2 +- crates/voicevox_core_c_api/src/lib.rs | 2 +- crates/voicevox_core_c_api/tests/e2e/symbols.rs | 2 +- .../testcases/synthesizer_new_with_initialize_output_json.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index 724c353d3..3bcfc1cf7 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -636,7 +636,7 @@ bool voicevox_synthesizer_is_loaded_voice_model(const struct VoicevoxSynthesizer #ifdef _WIN32 __declspec(dllimport) #endif -const char *voicevox_synthesizer_create_metas_json(const struct VoicevoxSynthesizer *synthesizer); +char *voicevox_synthesizer_create_metas_json(const struct VoicevoxSynthesizer *synthesizer); /** * このライブラリで利用可能なデバイスの情報を、JSONで取得する。 diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index 88ce9cfb3..d948b7a66 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -453,7 +453,7 @@ pub unsafe extern "C" fn voicevox_synthesizer_is_loaded_voice_model( #[no_mangle] pub extern "C" fn voicevox_synthesizer_create_metas_json( synthesizer: &VoicevoxSynthesizer, -) -> *const c_char { +) -> *mut c_char { let metas = synthesizer.metas(); C_STRING_DROP_CHECKER.whitelist(metas).into_raw() } diff --git a/crates/voicevox_core_c_api/tests/e2e/symbols.rs b/crates/voicevox_core_c_api/tests/e2e/symbols.rs index 292c5d881..3b6279101 100644 --- a/crates/voicevox_core_c_api/tests/e2e/symbols.rs +++ b/crates/voicevox_core_c_api/tests/e2e/symbols.rs @@ -58,7 +58,7 @@ pub(crate) struct Symbols<'lib> { unsafe extern "C" fn(*const VoicevoxSynthesizer, VoicevoxVoiceModelId) -> bool, >, pub(crate) voicevox_synthesizer_create_metas_json: - Symbol<'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer) -> *const c_char>, + Symbol<'lib, unsafe extern "C" fn(*const VoicevoxSynthesizer) -> *mut c_char>, pub(crate) voicevox_create_supported_devices_json: Symbol<'lib, unsafe extern "C" fn(*mut *mut c_char) -> VoicevoxResultCode>, pub(crate) voicevox_synthesizer_create_audio_query: Symbol< diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs index a03bc8377..7238f56ef 100644 --- a/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/synthesizer_new_with_initialize_output_json.rs @@ -61,7 +61,7 @@ impl assert_cdylib::TestCase for TestCase { }; let metas_json = { - let raw = voicevox_synthesizer_create_metas_json(synthesizer) as *mut std::ffi::c_char; + let raw = voicevox_synthesizer_create_metas_json(synthesizer); let metas_json = &CStr::from_ptr(raw).to_str()?.parse::()?; let metas_json = serde_json::to_string_pretty(metas_json).unwrap(); voicevox_json_free(raw); From 24f5e0f24066ee143329e4faa364f04305695f66 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Wed, 16 Aug 2023 22:25:38 +0900 Subject: [PATCH 12/16] Rework `ensure_not_contains` --- crates/voicevox_core/src/status.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 88d954583..91667eacc 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -64,7 +64,7 @@ impl Status { self.loaded_models .lock() .unwrap() - .ensure_not_contains(model)?; + .ensure_acceptable(model)?; let models = model.read_inference_models().await?; @@ -361,11 +361,18 @@ impl LoadedModels { self.styles().any(|style| *style.id() == style_id) } - fn ensure_not_contains(&self, model: &VoiceModel) -> Result<()> { + /// 与えられた`VoiceModel`を受け入れ可能かをチェックする。 + /// + /// # Errors + /// + /// 音声モデルIDかスタイルIDが`model`と重複するとき、エラーを返す。 + fn ensure_acceptable(&self, model: &VoiceModel) -> Result<()> { let loaded = self.styles(); let external = model.metas().iter().flat_map(|speaker| speaker.styles()); - if iproduct!(loaded, external).any(|(loaded, external)| loaded.id() == external.id()) { + if self.0.contains_key(model.id()) + || iproduct!(loaded, external).any(|(loaded, external)| loaded.id() == external.id()) + { return Err(Error::AlreadyLoadedModel { path: model.path().clone(), }); @@ -380,7 +387,7 @@ impl LoadedModels { predict_intonation: Session<'static>, decode: Session<'static>, ) -> Result<()> { - self.ensure_not_contains(model)?; + self.ensure_acceptable(model)?; let prev = self.0.insert( model.id().clone(), From e66168e3e47f43a8aef25377138e7a8c55c6b6e1 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Wed, 16 Aug 2023 22:52:26 +0900 Subject: [PATCH 13/16] =?UTF-8?q?`LoadedModels`=E3=81=8B=E3=82=89=E3=81=AE?= =?UTF-8?q?session=E5=8F=96=E5=BE=97=E3=82=92`get`=E3=81=A8=E3=81=84?= =?UTF-8?q?=E3=81=86=E5=8D=98=E4=B8=80=E3=81=AE=E3=83=A1=E3=82=BD=E3=83=83?= =?UTF-8?q?=E3=83=89=E3=81=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/status.rs | 69 +++++++++--------------------- 1 file changed, 20 insertions(+), 49 deletions(-) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 91667eacc..6863af84a 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -172,11 +172,12 @@ impl Status { mut phoneme_vector_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let predict_duration = self - .loaded_models - .lock() - .unwrap() - .predict_duration(model_id); + let predict_duration = self.loaded_models.lock().unwrap().get( + model_id, + |SessionSet { + predict_duration, .. + }| predict_duration, + ); tokio::task::spawn_blocking(move || { let mut predict_duration = predict_duration.lock().unwrap(); @@ -206,11 +207,12 @@ impl Status { mut end_accent_phrase_vector_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let predict_intonation = self - .loaded_models - .lock() - .unwrap() - .predict_intonation(model_id); + let predict_intonation = self.loaded_models.lock().unwrap().get( + model_id, + |SessionSet { + predict_intonation, .. + }| predict_intonation, + ); tokio::task::spawn_blocking(move || { let mut predict_intonation = predict_intonation.lock().unwrap(); @@ -243,7 +245,11 @@ impl Status { mut phoneme_array: NdArray, mut speaker_id_array: NdArray, ) -> Result> { - let decode = self.loaded_models.lock().unwrap().decode(model_id); + let decode = self + .loaded_models + .lock() + .unwrap() + .get(model_id, |SessionSet { decode, .. }| decode); tokio::task::spawn_blocking(move || { let mut decode = decode.lock().unwrap(); @@ -310,47 +316,12 @@ impl LoadedModels { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - fn predict_duration( - &self, - model_id: &VoiceModelId, - ) -> Arc>>> { - let LoadedModel { - session_set: SessionSet { - predict_duration, .. - }, - .. - } = &self.0[model_id]; - predict_duration.clone() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - fn predict_intonation( - &self, - model_id: &VoiceModelId, - ) -> Arc>>> { - let LoadedModel { - session_set: SessionSet { - predict_intonation, .. - }, - .. - } = &self.0[model_id]; - predict_intonation.clone() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - fn decode( + fn get( &self, model_id: &VoiceModelId, + which: fn(&SessionSet) -> &Arc>>>, ) -> Arc>>> { - let LoadedModel { - session_set: SessionSet { decode, .. }, - .. - } = &self.0[model_id]; - decode.clone() + which(&self.0[model_id].session_set).clone() } fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { From 1101624e033f3b853c895ed465950ddbf1757e1f Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Fri, 18 Aug 2023 01:28:24 +0900 Subject: [PATCH 14/16] =?UTF-8?q?`Error::LoadModel`=E3=81=AB=E8=89=B2?= =?UTF-8?q?=E3=80=85=E7=B5=B1=E5=90=88=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 20 ++++ crates/voicevox_core/Cargo.toml | 1 + crates/voicevox_core/src/error.rs | 75 +++++++----- crates/voicevox_core/src/result_code.rs | 32 +++-- crates/voicevox_core/src/status.rs | 32 +++-- crates/voicevox_core/src/voice_model.rs | 112 +++++++++--------- .../include/voicevox_core.h | 28 ++--- crates/voicevox_core_c_api/src/helpers.rs | 14 ++- crates/voicevox_core_c_api/src/lib.rs | 7 -- 9 files changed, 181 insertions(+), 140 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 595e0e30f..a4460dae4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -730,6 +730,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "cookie" version = "0.14.4" @@ -956,6 +962,19 @@ dependencies = [ "syn 1.0.102", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version 0.4.0", + "syn 1.0.102", +] + [[package]] name = "diff" version = "0.1.13" @@ -3942,6 +3961,7 @@ dependencies = [ "cfg-if", "derive-getters", "derive-new", + "derive_more", "easy-ext", "flate2", "fs-err", diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 05949fa90..c542a9860 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -14,6 +14,7 @@ async_zip.workspace = true cfg-if = "1.0.0" derive-getters.workspace = true derive-new = "0.5.9" +derive_more = "0.99.17" easy-ext.workspace = true fs-err.workspace = true futures = "0.3.26" diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 7ad258a34..12883f2ab 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -21,14 +21,8 @@ pub enum Error { #[error("{}", base_error_message(VOICEVOX_RESULT_GPU_SUPPORT_ERROR))] GpuSupport, - #[error("{} ({}): {source}", base_error_message(VOICEVOX_RESULT_LOAD_MODEL_ERROR), path.display())] - LoadModel { - path: PathBuf, - #[source] - source: anyhow::Error, - }, - #[error("{} ({})", base_error_message(VOICEVOX_RESULT_ALREADY_LOADED_MODEL_ERROR), path.display())] - AlreadyLoadedModel { path: PathBuf }, + #[error(transparent)] + LoadModel(#[from] LoadModelError), #[error( "{} ({model_id:?})", @@ -36,29 +30,6 @@ pub enum Error { )] UnloadedModel { model_id: VoiceModelId }, - #[error( - "{}({path}):{source}", - base_error_message(VOICEVOX_RESULT_OPEN_FILE_ERROR) - )] - OpenFile { - path: PathBuf, - #[source] - source: anyhow::Error, - }, - - #[error( - "{}({path}):{source}", - base_error_message(VOICEVOX_RESULT_VVM_MODEL_READ_ERROR) - )] - VvmRead { - path: PathBuf, - #[source] - source: anyhow::Error, - }, - - #[error("{},{0}", base_error_message(VOICEVOX_RESULT_LOAD_METAS_ERROR))] - LoadMetas(#[source] anyhow::Error), - #[error( "{},{0}", base_error_message(VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR) @@ -111,6 +82,48 @@ pub enum Error { InvalidWord(InvalidWordError), } +pub(crate) type LoadModelResult = std::result::Result; + +/// 音声モデル読み込みのエラー。 +#[derive(Error, Debug)] +#[error( + "`{path}`の読み込みに失敗しました: {context}{}", + source.as_ref().map(|e| format!(": {e}")).unwrap_or_default()) +] +pub struct LoadModelError { + pub(crate) path: PathBuf, + pub(crate) context: LoadModelErrorKind, + #[source] + pub(crate) source: Option, +} + +impl LoadModelError { + pub fn context(&self) -> &LoadModelErrorKind { + &self.context + } +} + +#[derive(derive_more::Display, Debug)] +pub enum LoadModelErrorKind { + //#[display(fmt = "{}", "base_error_message(VOICEVOX_RESULT_OPEN_ZIP_FILE_ERROR)")] + #[display(fmt = "ZIPファイルとして開くことができませんでした")] + OpenZipFile, + //#[display(fmt = "{}", "base_error_message(VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR)")] + #[display(fmt = "`{filename}`を読み取れませんでした")] + ReadZipEntry { filename: String }, + //#[display(fmt = "{}", "base_error_message(VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR)")] + #[display(fmt = "モデル`{id}`は既に読み込まれています")] + ModelAlreadyLoaded { id: VoiceModelId }, + //#[display(fmt = "{}", "base_error_message(VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR)")] + #[display(fmt = "スタイル`{id}`は既に読み込まれています")] + StyleAlreadyLoaded { id: StyleId }, + #[display( + fmt = "{}", + "base_error_message(VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR)" + )] + InvalidModelData, +} + fn base_error_message(result_code: VoicevoxResultCode) -> &'static str { let c_message: &'static str = crate::result_code::error_result_to_message(result_code); &c_message[..(c_message.len() - 1)] diff --git a/crates/voicevox_core/src/result_code.rs b/crates/voicevox_core/src/result_code.rs index 07cf05e18..24d02a0f0 100644 --- a/crates/voicevox_core/src/result_code.rs +++ b/crates/voicevox_core/src/result_code.rs @@ -11,14 +11,10 @@ pub enum VoicevoxResultCode { VOICEVOX_RESULT_OK = 0, /// open_jtalk辞書ファイルが読み込まれていない VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT_ERROR = 1, - /// modelの読み込みに失敗した - VOICEVOX_RESULT_LOAD_MODEL_ERROR = 2, /// サポートされているデバイス情報取得に失敗した VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR = 3, /// GPUモードがサポートされていない VOICEVOX_RESULT_GPU_SUPPORT_ERROR = 4, - /// メタ情報読み込みに失敗した - VOICEVOX_RESULT_LOAD_METAS_ERROR = 5, /// 無効なstyle_idが指定された VOICEVOX_RESULT_INVALID_STYLE_ID_ERROR = 6, /// 無効なmodel_idが指定された @@ -35,12 +31,16 @@ pub enum VoicevoxResultCode { VOICEVOX_RESULT_INVALID_AUDIO_QUERY_ERROR = 14, /// 無効なAccentPhrase VOICEVOX_RESULT_INVALID_ACCENT_PHRASE_ERROR = 15, - /// ファイルオープンエラー - VOICEVOX_RESULT_OPEN_FILE_ERROR = 16, - /// Modelを読み込めなかった - VOICEVOX_RESULT_VVM_MODEL_READ_ERROR = 17, - /// すでに読み込まれているModelを読み込もうとした - VOICEVOX_RESULT_ALREADY_LOADED_MODEL_ERROR = 18, + /// ZIPファイルを開くことに失敗した + VOICEVOX_RESULT_OPEN_ZIP_FILE_ERROR = 16, + /// ZIP内のファイルが読めなかった + VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR = 17, + /// すでに読み込まれている音声モデルを読み込もうとした + VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR = 18, + /// すでに読み込まれているスタイルを読み込もうとした + VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 2, + /// 無効なモデルデータ + VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 5, /// Modelが読み込まれていない VOICEVOX_RESULT_UNLOADED_MODEL_ERROR = 19, /// ユーザー辞書を読み込めなかった @@ -64,8 +64,6 @@ pub const fn error_result_to_message(result_code: VoicevoxResultCode) -> &'stati VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT_ERROR => { "OpenJTalkの辞書が読み込まれていません\0" } - VOICEVOX_RESULT_LOAD_MODEL_ERROR => "modelデータ読み込みに失敗しました\0", - VOICEVOX_RESULT_LOAD_METAS_ERROR => "メタデータ読み込みに失敗しました\0", VOICEVOX_RESULT_GPU_SUPPORT_ERROR => "GPU機能をサポートすることができません\0", VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR => { @@ -85,11 +83,11 @@ pub const fn error_result_to_message(result_code: VoicevoxResultCode) -> &'stati } VOICEVOX_RESULT_INVALID_AUDIO_QUERY_ERROR => "無効なaudio_queryです\0", VOICEVOX_RESULT_INVALID_ACCENT_PHRASE_ERROR => "無効なaccent_phraseです\0", - VOICEVOX_RESULT_OPEN_FILE_ERROR => "ファイルオープンに失敗しました\0", - VOICEVOX_RESULT_VVM_MODEL_READ_ERROR => "Modelを読み込めませんでした\0", - VOICEVOX_RESULT_ALREADY_LOADED_MODEL_ERROR => { - "すでに読み込まれているModelを読み込もうとしました\0" - } + VOICEVOX_RESULT_OPEN_ZIP_FILE_ERROR => "ZIPファイルのオープンに失敗しました\0", + VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR => "ZIP内のファイルを読むことができませんでした\0", + VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR => "同じIDのモデルを読むことはできません\0", + VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR => "同じIDのスタイルを読むことはできません\0", + VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR => "モデルデータを読むことができませんでした\0", VOICEVOX_RESULT_UNLOADED_MODEL_ERROR => "Modelが読み込まれていません\0", VOICEVOX_RESULT_LOAD_USER_DICT_ERROR => "ユーザー辞書を読み込めませんでした\0", VOICEVOX_RESULT_SAVE_USER_DICT_ERROR => "ユーザー辞書を書き込めませんでした\0", diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 6863af84a..a92ffae8e 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -121,11 +121,12 @@ impl Status { model: &[u8], session_options: &SessionOptions, path: impl AsRef, - ) -> Result> { + ) -> LoadModelResult> { self.new_session_from_bytes(|| model_file::decrypt(model), session_options) - .map_err(|source| Error::LoadModel { - path: path.as_ref().into(), - source, + .map_err(|source| LoadModelError { + path: path.as_ref().to_owned(), + context: LoadModelErrorKind::InvalidModelData, + source: Some(source), }) } @@ -337,16 +338,27 @@ impl LoadedModels { /// # Errors /// /// 音声モデルIDかスタイルIDが`model`と重複するとき、エラーを返す。 - fn ensure_acceptable(&self, model: &VoiceModel) -> Result<()> { + fn ensure_acceptable(&self, model: &VoiceModel) -> LoadModelResult<()> { let loaded = self.styles(); let external = model.metas().iter().flat_map(|speaker| speaker.styles()); - if self.0.contains_key(model.id()) - || iproduct!(loaded, external).any(|(loaded, external)| loaded.id() == external.id()) + let error = |context| LoadModelError { + path: model.path().clone(), + context, + source: None, + }; + + if self.0.contains_key(model.id()) { + return Err(error(LoadModelErrorKind::ModelAlreadyLoaded { + id: model.id().clone(), + })); + } + if let Some((style, _)) = + iproduct!(loaded, external).find(|(loaded, external)| loaded.id() == external.id()) { - return Err(Error::AlreadyLoadedModel { - path: model.path().clone(), - }); + return Err(error(LoadModelErrorKind::StyleAlreadyLoaded { + id: *style.id(), + })); } Ok(()) } diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index f56ed9ed7..2eb114b6a 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use async_zip::{read::fs::ZipFileReader, ZipEntry}; use futures::future::{join3, join_all}; use serde::{de::DeserializeOwned, Deserialize}; @@ -6,7 +5,7 @@ use serde::{de::DeserializeOwned, Deserialize}; use super::*; use std::{ collections::{BTreeMap, HashMap}, - env, + env, io, path::{Path, PathBuf}, }; @@ -16,7 +15,9 @@ use std::{ pub type RawVoiceModelId = String; /// 音声モデルID。 -#[derive(PartialEq, Eq, Clone, Ord, PartialOrd, Deserialize, new, Getters, Debug)] +#[derive( + PartialEq, Eq, Clone, Ord, PartialOrd, Deserialize, new, Getters, derive_more::Display, Debug, +)] pub struct VoiceModelId { raw_voice_model_id: RawVoiceModelId, } @@ -42,7 +43,7 @@ pub(crate) struct InferenceModels { } impl VoiceModel { - pub(crate) async fn read_inference_models(&self) -> Result { + pub(crate) async fn read_inference_models(&self) -> LoadModelResult { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( @@ -53,39 +54,18 @@ impl VoiceModel { .await; Ok(InferenceModels { - predict_duration_model: predict_duration_model_result.map_err(|e| Error::VvmRead { - path: self.path.clone(), - source: e, - })?, - predict_intonation_model: predict_intonation_model_result.map_err(|e| { - Error::VvmRead { - path: self.path.clone(), - source: e, - } - })?, - decode_model: decode_model_result.map_err(|e| Error::VvmRead { - path: self.path.clone(), - source: e, - })?, + predict_duration_model: predict_duration_model_result?, + predict_intonation_model: predict_intonation_model_result?, + decode_model: decode_model_result?, }) } /// VVMファイルから`VoiceModel`をコンストラクトする。 - pub async fn from_path(path: impl AsRef) -> Result { - let reader = VvmEntryReader::open(&path).await?; - let manifest = reader - .read_vvm_json::("manifest.json") - .await - .map_err(|e| Error::VvmRead { - path: path.as_ref().into(), - source: e, - })?; + pub async fn from_path(path: impl AsRef) -> LoadModelResult { + let reader = VvmEntryReader::open(path.as_ref()).await?; + let manifest = reader.read_vvm_json::("manifest.json").await?; let metas = reader .read_vvm_json::(manifest.metas_filename()) - .await - .map_err(|e| Error::VvmRead { - path: path.as_ref().into(), - source: e, - })?; + .await?; let id = VoiceModelId::new(nanoid!()); Ok(Self { @@ -96,6 +76,10 @@ impl VoiceModel { }) } + // FIXME: `load_all_models`自体を廃止し、これはENGINE専用とする + /// # Panics + /// + /// 目的のディレクトリが読めなかったらパニックする pub async fn get_all_models() -> Result> { let root_dir = if cfg!(test) { Path::new(env!("CARGO_WORKSPACE_DIR")).join("model") @@ -113,15 +97,16 @@ impl VoiceModel { let vvm_paths = root_dir .read_dir() .and_then(|entries| entries.collect::, _>>()) - .map_err(|e| Error::LoadModel { - path: root_dir.clone(), - source: e.into(), - })? + .unwrap_or_else(|e| panic!("{}が読めませんでした: {e}", root_dir.display())) .into_iter() .filter(|entry| entry.path().extension().map_or(false, |ext| ext == "vvm")) .map(|entry| Self::from_path(entry.path())); - join_all(vvm_paths).await.into_iter().collect() + join_all(vvm_paths) + .await + .into_iter() + .collect::>() + .map_err(Into::into) } const ROOT_DIR_ENV_NAME: &str = "VV_MODELS_ROOT_DIR"; @@ -158,12 +143,13 @@ struct VvmEntryReader { } impl VvmEntryReader { - async fn open(path: impl AsRef) -> Result { - let reader = ZipFileReader::new(path.as_ref()) + async fn open(path: &Path) -> LoadModelResult { + let reader = ZipFileReader::new(path) .await - .map_err(|e| Error::OpenFile { - path: path.as_ref().into(), - source: e.into(), + .map_err(|source| LoadModelError { + path: path.to_owned(), + context: LoadModelErrorKind::OpenZipFile, + source: Some(source.into()), })?; let entry_map: HashMap<_, _> = reader .file() @@ -183,22 +169,38 @@ impl VvmEntryReader { .collect(); Ok(VvmEntryReader::new(reader, entry_map)) } - async fn read_vvm_json(&self, filename: &str) -> anyhow::Result { + async fn read_vvm_json(&self, filename: &str) -> LoadModelResult { let bytes = self.read_vvm_entry(filename).await?; - serde_json::from_slice(&bytes).map_err(|e| e.into()) + serde_json::from_slice(&bytes).map_err(|source| LoadModelError { + path: self.reader.path().to_owned(), + context: LoadModelErrorKind::ReadZipEntry { + filename: filename.to_owned(), + }, + source: Some(source.into()), + }) } - async fn read_vvm_entry(&self, filename: &str) -> anyhow::Result> { - let me = self - .entry_map - .get(filename) - .ok_or_else(|| anyhow!("Not found in vvm entries: {}", filename))?; - let mut manifest_reader = self.reader.entry(me.index).await?; - let mut buf = Vec::with_capacity(me.entry.uncompressed_size() as usize); - manifest_reader - .read_to_end_checked(&mut buf, &me.entry) - .await?; - Ok(buf) + async fn read_vvm_entry(&self, filename: &str) -> LoadModelResult> { + (|| async { + let me = self + .entry_map + .get(filename) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let mut manifest_reader = self.reader.entry(me.index).await?; + let mut buf = Vec::with_capacity(me.entry.uncompressed_size() as usize); + manifest_reader + .read_to_end_checked(&mut buf, &me.entry) + .await?; + Ok::<_, anyhow::Error>(buf) + })() + .await + .map_err(|source| LoadModelError { + path: self.reader.path().to_owned(), + context: LoadModelErrorKind::ReadZipEntry { + filename: filename.to_owned(), + }, + source: Some(source), + }) } } diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index 4a5bcdf65..d26470df3 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -94,10 +94,6 @@ enum VoicevoxResultCode * open_jtalk辞書ファイルが読み込まれていない */ VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT_ERROR = 1, - /** - * modelの読み込みに失敗した - */ - VOICEVOX_RESULT_LOAD_MODEL_ERROR = 2, /** * サポートされているデバイス情報取得に失敗した */ @@ -106,10 +102,6 @@ enum VoicevoxResultCode * GPUモードがサポートされていない */ VOICEVOX_RESULT_GPU_SUPPORT_ERROR = 4, - /** - * メタ情報読み込みに失敗した - */ - VOICEVOX_RESULT_LOAD_METAS_ERROR = 5, /** * 無効なstyle_idが指定された */ @@ -143,17 +135,25 @@ enum VoicevoxResultCode */ VOICEVOX_RESULT_INVALID_ACCENT_PHRASE_ERROR = 15, /** - * ファイルオープンエラー + * ZIPファイルを開くことに失敗した + */ + VOICEVOX_RESULT_OPEN_ZIP_FILE_ERROR = 16, + /** + * ZIP内のファイルが読めなかった + */ + VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR = 17, + /** + * すでに読み込まれている音声モデルを読み込もうとした */ - VOICEVOX_RESULT_OPEN_FILE_ERROR = 16, + VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR = 18, /** - * Modelを読み込めなかった + * すでに読み込まれているスタイルを読み込もうとした */ - VOICEVOX_RESULT_VVM_MODEL_READ_ERROR = 17, + VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 2, /** - * すでに読み込まれているModelを読み込もうとした + * 無効なモデルデータ */ - VOICEVOX_RESULT_ALREADY_LOADED_MODEL_ERROR = 18, + VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 5, /** * Modelが読み込まれていない */ diff --git a/crates/voicevox_core_c_api/src/helpers.rs b/crates/voicevox_core_c_api/src/helpers.rs index 6b7ecf484..0e9cfd279 100644 --- a/crates/voicevox_core_c_api/src/helpers.rs +++ b/crates/voicevox_core_c_api/src/helpers.rs @@ -18,15 +18,20 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxRes } fn into_result_code(result: CApiResult<()>) -> VoicevoxResultCode { - use voicevox_core::{result_code::VoicevoxResultCode::*, Error::*}; + use voicevox_core::{result_code::VoicevoxResultCode::*, Error::*, LoadModelErrorKind::*}; use CApiError::*; match result { Ok(()) => VOICEVOX_RESULT_OK, Err(RustApi(NotLoadedOpenjtalkDict)) => VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT_ERROR, Err(RustApi(GpuSupport)) => VOICEVOX_RESULT_GPU_SUPPORT_ERROR, - Err(RustApi(LoadModel { .. })) => VOICEVOX_RESULT_LOAD_MODEL_ERROR, - Err(RustApi(LoadMetas(_))) => VOICEVOX_RESULT_LOAD_METAS_ERROR, + Err(RustApi(LoadModel(err))) => match err.context() { + OpenZipFile => VOICEVOX_RESULT_OPEN_ZIP_FILE_ERROR, + ReadZipEntry { .. } => VOICEVOX_RESULT_READ_ZIP_ENTRY_ERROR, + ModelAlreadyLoaded { .. } => VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR, + StyleAlreadyLoaded { .. } => VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR, + InvalidModelData => VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR, + }, Err(RustApi(GetSupportedDevices(_))) => VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR, Err(RustApi(InvalidStyleId { .. })) => VOICEVOX_RESULT_INVALID_STYLE_ID_ERROR, Err(RustApi(InvalidModelId { .. })) => VOICEVOX_RESULT_INVALID_MODEL_ID_ERROR, @@ -35,9 +40,6 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxRes VOICEVOX_RESULT_EXTRACT_FULL_CONTEXT_LABEL_ERROR } Err(RustApi(UnloadedModel { .. })) => VOICEVOX_RESULT_UNLOADED_MODEL_ERROR, - Err(RustApi(AlreadyLoadedModel { .. })) => VOICEVOX_RESULT_ALREADY_LOADED_MODEL_ERROR, - Err(RustApi(OpenFile { .. })) => VOICEVOX_RESULT_OPEN_FILE_ERROR, - Err(RustApi(VvmRead { .. })) => VOICEVOX_RESULT_VVM_MODEL_READ_ERROR, Err(RustApi(ParseKana(_))) => VOICEVOX_RESULT_PARSE_KANA_ERROR, Err(RustApi(LoadUserDict(_))) => VOICEVOX_RESULT_LOAD_USER_DICT_ERROR, Err(RustApi(SaveUserDict(_))) => VOICEVOX_RESULT_SAVE_USER_DICT_ERROR, diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index c848217ca..e97b0e5bb 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -1236,13 +1236,6 @@ mod tests { Err(Error::NotLoadedOpenjtalkDict), VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT_ERROR )] - #[case( - Err(Error::LoadModel { - path: "path/to/model.onnx".into(), - source: anyhow!("some load model error"), - }), - VoicevoxResultCode::VOICEVOX_RESULT_LOAD_MODEL_ERROR - )] #[case( Err(Error::GetSupportedDevices(anyhow!("some get supported devices error"))), VoicevoxResultCode::VOICEVOX_RESULT_GET_SUPPORTED_DEVICES_ERROR From 36700d8a457e13142184948d3dec4ac89047f4ac Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Fri, 18 Aug 2023 07:09:47 +0900 Subject: [PATCH 15/16] =?UTF-8?q?=E6=AC=A0=E7=95=AA=E3=82=92=E4=BD=BF?= =?UTF-8?q?=E3=82=8F=E3=81=AA=E3=81=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/result_code.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/voicevox_core/src/result_code.rs b/crates/voicevox_core/src/result_code.rs index 24d02a0f0..541c65a79 100644 --- a/crates/voicevox_core/src/result_code.rs +++ b/crates/voicevox_core/src/result_code.rs @@ -38,9 +38,9 @@ pub enum VoicevoxResultCode { /// すでに読み込まれている音声モデルを読み込もうとした VOICEVOX_RESULT_MODEL_ALREADY_LOADED_ERROR = 18, /// すでに読み込まれているスタイルを読み込もうとした - VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 2, + VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 26, /// 無効なモデルデータ - VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 5, + VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 27, /// Modelが読み込まれていない VOICEVOX_RESULT_UNLOADED_MODEL_ERROR = 19, /// ユーザー辞書を読み込めなかった From 60cad17906df426488869c99ec60c06ec3b0267f Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Fri, 18 Aug 2023 07:27:11 +0900 Subject: [PATCH 16/16] `cargo xtask update-c-header` --- crates/voicevox_core_c_api/include/voicevox_core.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index d26470df3..8b012e901 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -149,11 +149,11 @@ enum VoicevoxResultCode /** * すでに読み込まれているスタイルを読み込もうとした */ - VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 2, + VOICEVOX_RESULT_STYLE_ALREADY_LOADED_ERROR = 26, /** * 無効なモデルデータ */ - VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 5, + VOICEVOX_RESULT_INVALID_MODEL_DATA_ERROR = 27, /** * Modelが読み込まれていない */