From 93d10aee7f308511c3dacf2537b372781374aeeb Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 24 Sep 2024 10:42:34 +0900 Subject: [PATCH] =?UTF-8?q?#830=20=E3=81=AE=E8=A8=AD=E8=A8=88=E3=82=92`Use?= =?UTF-8?q?rDict`=E3=81=AB=E3=82=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/Cargo.toml | 2 +- crates/voicevox_core/src/asyncs.rs | 20 ++ crates/voicevox_core/src/user_dict/dict.rs | 204 +++++++++++++-------- 3 files changed, 153 insertions(+), 73 deletions(-) diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index d5fb28321..8390e1de3 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -16,7 +16,7 @@ link-onnxruntime = [] [dependencies] anyhow.workspace = true -async-fs.workspace = true # 今これを使っている箇所はどこにも無いが、`UserDict`にはこれを使った方がよいはず +async-fs.workspace = true async-lock.workspace = true async_zip = { workspace = true, features = ["deflate"] } blocking.workspace = true diff --git a/crates/voicevox_core/src/asyncs.rs b/crates/voicevox_core/src/asyncs.rs index d89aa7d4b..24fdd82b4 100644 --- a/crates/voicevox_core/src/asyncs.rs +++ b/crates/voicevox_core/src/asyncs.rs @@ -34,6 +34,10 @@ pub(crate) trait Async: 'static { /// `io::Error`は素(`i32`相当)のままにしておき、この関数を呼び出す側でfs-err風のメッセージを付 /// ける。 async fn open_file_ro(path: impl AsRef) -> io::Result; + + async fn read(path: impl AsRef) -> io::Result>; + + async fn write(path: impl AsRef, content: impl AsRef<[u8]>) -> io::Result<()>; } pub(crate) trait Mutex: From + Send + Sync + Unpin { @@ -59,6 +63,14 @@ impl Async for SingleTasked { async fn open_file_ro(path: impl AsRef) -> io::Result { std::fs::File::open(path).map(StdFile) } + + async fn read(path: impl AsRef) -> io::Result> { + std::fs::read(path) + } + + async fn write(path: impl AsRef, content: impl AsRef<[u8]>) -> io::Result<()> { + std::fs::write(path, content) + } } pub(crate) struct StdMutex(std::sync::Mutex); @@ -111,6 +123,14 @@ impl Async for BlockingThreadPool { async fn open_file_ro(path: impl AsRef) -> io::Result { AsyncRoFile::open(path).await } + + async fn read(path: impl AsRef) -> io::Result> { + async_fs::read(path).await + } + + async fn write(path: impl AsRef, content: impl AsRef<[u8]>) -> io::Result<()> { + async_fs::write(path, content).await + } } impl Mutex for async_lock::Mutex { diff --git a/crates/voicevox_core/src/user_dict/dict.rs b/crates/voicevox_core/src/user_dict/dict.rs index 0e1c89ca2..2f851b138 100644 --- a/crates/voicevox_core/src/user_dict/dict.rs +++ b/crates/voicevox_core/src/user_dict/dict.rs @@ -1,33 +1,128 @@ -// TODO: `VoiceModelFile`のように、次のような設計にする。 -// -// ``` -// pub(crate) mod blocking { -// pub struct UserDict(Inner); -// // … -// } -// pub(crate) mod nonblocking { -// pub struct UserDict(Inner); -// // … -// } -// ``` +use std::{marker::PhantomData, path::Path}; + +use anyhow::Context as _; +use easy_ext::ext; +use educe::Educe; +use indexmap::IndexMap; +use itertools::Itertools as _; +use uuid::Uuid; + +use crate::{asyncs::Async, error::ErrorRepr}; + +use super::UserDictWord; + +#[derive(Educe)] +#[educe(Default(bound = "A:"))] +#[educe(Debug(bound = "A:"))] +struct Inner { + words: std::sync::Mutex>, + _marker: PhantomData, +} + +impl Inner { + fn to_json(&self) -> String { + self.with_words(|words| serde_json::to_string(words).expect("should not fail")) + } + + fn with_words(&self, f: F) -> R + where + F: FnOnce(&mut IndexMap) -> R, + { + f(&mut self.words.lock().unwrap_or_else(|e| panic!("{e}"))) + } + + async fn load(&self, store_path: &str) -> crate::Result<()> { + let words = async { + let words = &A::fs_err_read(store_path).await?; + let words = serde_json::from_slice::>(words)?; + Ok(words) + } + .await + .map_err(ErrorRepr::LoadUserDict)?; + + self.with_words(|words_| words_.extend(words)); + Ok(()) + } + + fn add_word(&self, word: UserDictWord) -> crate::Result { + let word_uuid = Uuid::new_v4(); + self.with_words(|word_| word_.insert(word_uuid, word)); + Ok(word_uuid) + } + + fn update_word(&self, word_uuid: Uuid, new_word: UserDictWord) -> crate::Result<()> { + self.with_words(|words| { + if !words.contains_key(&word_uuid) { + return Err(ErrorRepr::WordNotFound(word_uuid).into()); + } + words.insert(word_uuid, new_word); + Ok(()) + }) + } + + fn remove_word(&self, word_uuid: Uuid) -> crate::Result { + let Some(word) = self.with_words(|words| words.remove(&word_uuid)) else { + return Err(ErrorRepr::WordNotFound(word_uuid).into()); + }; + Ok(word) + } + + fn import(&self, other: &Self) -> crate::Result<()> { + self.with_words(|self_words| { + other.with_words(|other_words| { + for (word_uuid, word) in other_words { + self_words.insert(*word_uuid, word.clone()); + } + Ok(()) + }) + }) + } + + async fn save(&self, store_path: &str) -> crate::Result<()> { + A::fs_err_write( + store_path, + serde_json::to_vec(&self.words).expect("should not fail"), + ) + .await + .map_err(ErrorRepr::SaveUserDict) + .map_err(Into::into) + } + + fn to_mecab_format(&self) -> String { + self.with_words(|words| words.values().map(UserDictWord::to_mecab_format).join("\n")) + } +} + +#[ext] +impl A { + async fn fs_err_read(path: impl AsRef) -> anyhow::Result> { + Self::read(&path) + .await + .with_context(|| format!("failed to read from file `{}`", path.as_ref().display())) + } + + async fn fs_err_write(path: impl AsRef, content: impl AsRef<[u8]>) -> anyhow::Result<()> { + Self::write(&path, content) + .await + .with_context(|| format!("failed to write to file `{}`", path.as_ref().display())) + } +} pub(crate) mod blocking { use indexmap::IndexMap; - use itertools::join; use uuid::Uuid; - use crate::{error::ErrorRepr, Result}; + use crate::{asyncs::SingleTasked, future::FutureExt as _, Result}; - use super::super::word::UserDictWord; + use super::{super::word::UserDictWord, Inner}; /// ユーザー辞書。 /// /// 単語はJSONとの相互変換のために挿入された順序を保つ。 #[derive(Debug, Default)] - pub struct UserDict { - words: std::sync::Mutex>, - } + pub struct UserDict(Inner); + // TODO: 引数の`path`は全部`AsRef`にする impl self::UserDict { /// ユーザー辞書を作成する。 pub fn new() -> Self { @@ -35,11 +130,12 @@ pub(crate) mod blocking { } pub fn to_json(&self) -> String { - serde_json::to_string(&*self.words.lock().unwrap()).expect("should not fail") + self.0.to_json() } + // TODO: `&mut IndexMap<_>`を取れるようにする pub fn with_words(&self, f: impl FnOnce(&IndexMap) -> R) -> R { - f(&self.words.lock().unwrap()) + self.0.with_words(|words| f(words)) } /// ユーザー辞書をファイルから読み込む。 @@ -48,82 +144,48 @@ pub(crate) mod blocking { /// /// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。 pub fn load(&self, store_path: &str) -> Result<()> { - let words = (|| { - let words = &fs_err::read(store_path)?; - let words = serde_json::from_slice::>(words)?; - Ok(words) - })() - .map_err(ErrorRepr::LoadUserDict)?; - - self.words.lock().unwrap().extend(words); - Ok(()) + self.0.load(store_path).block_on() } /// ユーザー辞書に単語を追加する。 pub fn add_word(&self, word: UserDictWord) -> Result { - let word_uuid = Uuid::new_v4(); - self.words.lock().unwrap().insert(word_uuid, word); - Ok(word_uuid) + self.0.add_word(word) } /// ユーザー辞書の単語を変更する。 pub fn update_word(&self, word_uuid: Uuid, new_word: UserDictWord) -> Result<()> { - let mut words = self.words.lock().unwrap(); - if !words.contains_key(&word_uuid) { - return Err(ErrorRepr::WordNotFound(word_uuid).into()); - } - words.insert(word_uuid, new_word); - Ok(()) + self.0.update_word(word_uuid, new_word) } /// ユーザー辞書から単語を削除する。 pub fn remove_word(&self, word_uuid: Uuid) -> Result { - let Some(word) = self.words.lock().unwrap().remove(&word_uuid) else { - return Err(ErrorRepr::WordNotFound(word_uuid).into()); - }; - Ok(word) + self.0.remove_word(word_uuid) } /// 他のユーザー辞書をインポートする。 pub fn import(&self, other: &Self) -> Result<()> { - for (word_uuid, word) in &*other.words.lock().unwrap() { - self.words.lock().unwrap().insert(*word_uuid, word.clone()); - } - Ok(()) + self.0.import(&other.0) } /// ユーザー辞書を保存する。 pub fn save(&self, store_path: &str) -> Result<()> { - fs_err::write( - store_path, - serde_json::to_vec(&self.words).expect("should not fail"), - ) - .map_err(|e| ErrorRepr::SaveUserDict(e.into()).into()) + self.0.save(store_path).block_on() } /// MeCabで使用する形式に変換する。 pub(crate) fn to_mecab_format(&self) -> String { - join( - self.words - .lock() - .unwrap() - .values() - .map(UserDictWord::to_mecab_format), - "\n", - ) + self.0.to_mecab_format() } } } pub(crate) mod nonblocking { - use std::sync::Arc; - use indexmap::IndexMap; use uuid::Uuid; - use crate::Result; + use crate::{asyncs::BlockingThreadPool, Result}; - use super::super::word::UserDictWord; + use super::{super::word::UserDictWord, Inner}; /// ユーザー辞書。 /// @@ -136,20 +198,22 @@ pub(crate) mod nonblocking { /// [blocking]: https://docs.rs/crate/blocking /// [`nonblocking`モジュールのドキュメント]: crate::nonblocking #[derive(Debug, Default)] - pub struct UserDict(Arc); + pub struct UserDict(Inner); + // TODO: 引数の`path`は全部`AsRef`にする impl self::UserDict { /// ユーザー辞書を作成する。 pub fn new() -> Self { - Self(super::blocking::UserDict::new().into()) + Default::default() } pub fn to_json(&self) -> String { self.0.to_json() } + // TODO: `&mut IndexMap<_>`を取れるようにする pub fn with_words(&self, f: impl FnOnce(&IndexMap) -> R) -> R { - self.0.with_words(f) + self.0.with_words(|words| f(words)) } /// ユーザー辞書をファイルから読み込む。 @@ -158,9 +222,7 @@ pub(crate) mod nonblocking { /// /// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。 pub async fn load(&self, store_path: &str) -> Result<()> { - let blocking = self.0.clone(); - let store_path = store_path.to_owned(); - crate::task::asyncify(move || blocking.load(&store_path)).await + self.0.load(store_path).await } /// ユーザー辞書に単語を追加する。 @@ -185,9 +247,7 @@ pub(crate) mod nonblocking { /// ユーザー辞書を保存する。 pub async fn save(&self, store_path: &str) -> Result<()> { - let blocking = self.0.clone(); - let store_path = store_path.to_owned(); - crate::task::asyncify(move || blocking.save(&store_path)).await + self.0.save(store_path).await } /// MeCabで使用する形式に変換する。