Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#830 の設計をUserDictにも #834

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ link-onnxruntime = []

[dependencies]
anyhow.workspace = true
async-fs.workspace = true # 今これを使っている箇所はどこにも無いが、`UserDict`にはこれを使った方がよいはず
async-fs.workspace = true
async-lock.workspace = true
async_zip = { workspace = true, features = ["deflate"] }
blocking.workspace = true
Expand Down
20 changes: 20 additions & 0 deletions crates/voicevox_core/src/asyncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ pub(crate) trait Async: 'static {
/// `io::Error`は素(`i32`相当)のままにしておき、この関数を呼び出す側でfs-err風のメッセージを付
/// ける。
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile>;

/// Reads the entire contents of the file at `path` into a byte vector.
///
/// The `io::Error` is returned unmodified; a caller that wants the path in the error message has
/// to attach it itself.
async fn read(path: impl AsRef<Path>) -> io::Result<Vec<u8>>;

/// Writes `content` to the file at `path`.
///
/// The `io::Error` is returned unmodified; a caller that wants the path in the error message has
/// to attach it itself.
async fn write(path: impl AsRef<Path>, content: impl AsRef<[u8]>) -> io::Result<()>;
}

pub(crate) trait Mutex<T>: From<T> + Send + Sync + Unpin {
Expand All @@ -59,6 +63,14 @@ impl Async for SingleTasked {
/// Opens the file at `path` with `std::fs::File::open` and wraps the handle in `StdFile`.
///
/// The body contains no `.await`: the open happens synchronously when the future is polled.
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
std::fs::File::open(path).map(StdFile)
}

/// Reads the whole file at `path` with `std::fs::read`.
///
/// The body contains no `.await`: the read happens synchronously when the future is polled.
async fn read(path: impl AsRef<Path>) -> io::Result<Vec<u8>> {
std::fs::read(path)
}

/// Writes `content` to the file at `path` with `std::fs::write`.
///
/// The body contains no `.await`: the write happens synchronously when the future is polled.
async fn write(path: impl AsRef<Path>, content: impl AsRef<[u8]>) -> io::Result<()> {
std::fs::write(path, content)
}
}

/// Newtype wrapper around [`std::sync::Mutex`].
// NOTE(review): presumably this is the `Mutex` trait implementation paired with `SingleTasked`;
// the impl is outside the visible part of the file, so confirm there.
pub(crate) struct StdMutex<T>(std::sync::Mutex<T>);
Expand Down Expand Up @@ -111,6 +123,14 @@ impl Async for BlockingThreadPool {
/// Opens the file at `path` by awaiting `AsyncRoFile::open`.
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
AsyncRoFile::open(path).await
}

/// Reads the whole file at `path` by awaiting `async_fs::read`; the error is passed through as-is.
async fn read(path: impl AsRef<Path>) -> io::Result<Vec<u8>> {
async_fs::read(path).await
}

/// Writes `content` to the file at `path` by awaiting `async_fs::write`; the error is passed
/// through as-is.
async fn write(path: impl AsRef<Path>, content: impl AsRef<[u8]>) -> io::Result<()> {
async_fs::write(path, content).await
}
}

impl<T: Send + Sync + Unpin> Mutex<T> for async_lock::Mutex<T> {
Expand Down
204 changes: 132 additions & 72 deletions crates/voicevox_core/src/user_dict/dict.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,141 @@
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
// pub struct UserDict(Inner<SingleTasked>);
// // …
// }
// pub(crate) mod nonblocking {
// pub struct UserDict(Inner<BlockingThreadPool>);
// // …
// }
// ```
use std::{marker::PhantomData, path::Path};

use anyhow::Context as _;
use easy_ext::ext;
use educe::Educe;
use indexmap::IndexMap;
use itertools::Itertools as _;
use uuid::Uuid;

use crate::{asyncs::Async, error::ErrorRepr};

use super::UserDictWord;

/// State and logic shared by the `blocking` and `nonblocking` `UserDict` front ends.
///
/// `A` selects the [`Async`] implementation used for file I/O in `load` and `save`; no value of
/// type `A` is ever stored.
// `bound = "A:"` replaces the bound the derive would otherwise put on `A` with an empty one, so
// neither `Default` nor `Debug` requires anything of `A`.
#[derive(Educe)]
#[educe(Default(bound = "A:"))]
#[educe(Debug(bound = "A:"))]
struct Inner<A> {
/// Registered words keyed by their UUID, guarded by a standard-library mutex.
words: std::sync::Mutex<IndexMap<Uuid, UserDictWord>>,
/// Ties the struct to `A` without holding a value of it.
_marker: PhantomData<A>,
}

impl<A: Async> Inner<A> {
    /// Serializes every word to a JSON string.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails or if the lock is poisoned.
    fn to_json(&self) -> String {
        self.with_words(|words| serde_json::to_string(words).expect("should not fail"))
    }

    /// Runs `f` with exclusive access to the word map and returns whatever `f` returns.
    ///
    /// The lock is held exactly for the duration of `f`, so `f` must not lock the same dictionary
    /// again.
    ///
    /// # Panics
    ///
    /// Panics with the poison error's message if the lock is poisoned.
    fn with_words<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut IndexMap<Uuid, UserDictWord>) -> R,
    {
        let mut guard = self.words.lock().unwrap_or_else(|e| panic!("{e}"));
        f(&mut guard)
    }

    /// Reads the JSON file at `store_path` and merges its words into this dictionary.
    ///
    /// Words whose UUID is already present are overwritten in place; new ones are appended.
    ///
    /// # Errors
    ///
    /// Returns `ErrorRepr::LoadUserDict` if the file cannot be read or is not valid JSON for the
    /// word map. The dictionary is left untouched in that case.
    async fn load(&self, store_path: &str) -> crate::Result<()> {
        // Read and parse before taking the lock so that it is never held across an `.await`.
        let loaded = async {
            let bytes = A::fs_err_read(store_path).await?;
            let parsed = serde_json::from_slice::<IndexMap<_, _>>(&bytes)?;
            Ok(parsed)
        }
        .await
        .map_err(ErrorRepr::LoadUserDict)?;

        self.with_words(|words| words.extend(loaded));
        Ok(())
    }

    /// Registers `word` under a freshly generated v4 UUID and returns that UUID.
    fn add_word(&self, word: UserDictWord) -> crate::Result<Uuid> {
        let word_uuid = Uuid::new_v4();
        self.with_words(|words| words.insert(word_uuid, word));
        Ok(word_uuid)
    }

    /// Replaces the word registered under `word_uuid` with `new_word`, keeping its position.
    ///
    /// # Errors
    ///
    /// Returns `ErrorRepr::WordNotFound` if no word is registered under `word_uuid`.
    fn update_word(&self, word_uuid: Uuid, new_word: UserDictWord) -> crate::Result<()> {
        // A single `get_mut` lookup instead of `contains_key` followed by `insert`.
        self.with_words(|words| match words.get_mut(&word_uuid) {
            Some(slot) => {
                *slot = new_word;
                Ok(())
            }
            None => Err(ErrorRepr::WordNotFound(word_uuid).into()),
        })
    }

    /// Removes the word registered under `word_uuid` and returns it.
    ///
    /// The relative order of the remaining words is preserved.
    ///
    /// # Errors
    ///
    /// Returns `ErrorRepr::WordNotFound` if no word is registered under `word_uuid`.
    fn remove_word(&self, word_uuid: Uuid) -> crate::Result<UserDictWord> {
        // `IndexMap::remove` is `swap_remove`, which moves the last entry into the vacated slot
        // and would therefore scramble the insertion order that the JSON round trip relies on.
        let Some(word) = self.with_words(|words| words.shift_remove(&word_uuid)) else {
            return Err(ErrorRepr::WordNotFound(word_uuid).into());
        };
        Ok(word)
    }

    /// Copies every word of `other` into this dictionary.
    ///
    /// Words whose UUID is already present are overwritten in place; new ones are appended in
    /// `other`'s order. Importing a dictionary into itself is a no-op.
    fn import(&self, other: &Self) -> crate::Result<()> {
        // Snapshot `other` first so that the two locks are never held at the same time. Nesting
        // them would lock the same mutex twice when `other` is `self`, and would allow a
        // lock-order deadlock between two threads importing in opposite directions.
        let imported = other.with_words(|words| words.clone());
        self.with_words(|words| words.extend(imported));
        Ok(())
    }

    /// Serializes every word to JSON and writes the result to `store_path`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorRepr::SaveUserDict` if the file cannot be written.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails or if the lock is poisoned.
    async fn save(&self, store_path: &str) -> crate::Result<()> {
        // Serialize through `with_words` like every other accessor; the guard is released before
        // the write is awaited.
        let json = self.with_words(|words| serde_json::to_vec(words).expect("should not fail"));
        A::fs_err_write(store_path, json)
            .await
            .map_err(ErrorRepr::SaveUserDict)
            .map_err(Into::into)
    }

    /// Converts every word with `UserDictWord::to_mecab_format` and joins the results with `\n`.
    fn to_mecab_format(&self) -> String {
        self.with_words(|words| words.values().map(UserDictWord::to_mecab_format).join("\n"))
    }
}

#[ext]
impl<A: Async> A {
    /// Calls `Self::read` on `path` and, on failure, wraps the error with a message that names
    /// the file.
    async fn fs_err_read(path: impl AsRef<Path>) -> anyhow::Result<Vec<u8>> {
        let path = path.as_ref();
        let outcome = Self::read(path).await;
        outcome.with_context(|| format!("failed to read from file `{}`", path.display()))
    }

    /// Calls `Self::write` with `path` and `content` and, on failure, wraps the error with a
    /// message that names the file.
    async fn fs_err_write(path: impl AsRef<Path>, content: impl AsRef<[u8]>) -> anyhow::Result<()> {
        let path = path.as_ref();
        let outcome = Self::write(path, content).await;
        outcome.with_context(|| format!("failed to write to file `{}`", path.display()))
    }
}

pub(crate) mod blocking {
use indexmap::IndexMap;
use itertools::join;
use uuid::Uuid;

use crate::{error::ErrorRepr, Result};
use crate::{asyncs::SingleTasked, future::FutureExt as _, Result};

use super::super::word::UserDictWord;
use super::{super::word::UserDictWord, Inner};

/// ユーザー辞書。
///
/// 単語はJSONとの相互変換のために挿入された順序を保つ。
#[derive(Debug, Default)]
pub struct UserDict {
words: std::sync::Mutex<IndexMap<Uuid, UserDictWord>>,
}
pub struct UserDict(Inner<SingleTasked>);

// TODO: 引数の`path`は全部`AsRef<Path>`にする
impl self::UserDict {
/// ユーザー辞書を作成する。
pub fn new() -> Self {
Default::default()
}

pub fn to_json(&self) -> String {
serde_json::to_string(&*self.words.lock().unwrap()).expect("should not fail")
self.0.to_json()
}

// TODO: `&mut IndexMap<_>`を取れるようにする
pub fn with_words<R>(&self, f: impl FnOnce(&IndexMap<Uuid, UserDictWord>) -> R) -> R {
f(&self.words.lock().unwrap())
self.0.with_words(|words| f(words))
}

/// ユーザー辞書をファイルから読み込む。
Expand All @@ -48,82 +144,48 @@ pub(crate) mod blocking {
///
/// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。
pub fn load(&self, store_path: &str) -> Result<()> {
let words = (|| {
let words = &fs_err::read(store_path)?;
let words = serde_json::from_slice::<IndexMap<_, _>>(words)?;
Ok(words)
})()
.map_err(ErrorRepr::LoadUserDict)?;

self.words.lock().unwrap().extend(words);
Ok(())
self.0.load(store_path).block_on()
}

/// ユーザー辞書に単語を追加する。
pub fn add_word(&self, word: UserDictWord) -> Result<Uuid> {
let word_uuid = Uuid::new_v4();
self.words.lock().unwrap().insert(word_uuid, word);
Ok(word_uuid)
self.0.add_word(word)
}

/// ユーザー辞書の単語を変更する。
pub fn update_word(&self, word_uuid: Uuid, new_word: UserDictWord) -> Result<()> {
let mut words = self.words.lock().unwrap();
if !words.contains_key(&word_uuid) {
return Err(ErrorRepr::WordNotFound(word_uuid).into());
}
words.insert(word_uuid, new_word);
Ok(())
self.0.update_word(word_uuid, new_word)
}

/// ユーザー辞書から単語を削除する。
pub fn remove_word(&self, word_uuid: Uuid) -> Result<UserDictWord> {
let Some(word) = self.words.lock().unwrap().remove(&word_uuid) else {
return Err(ErrorRepr::WordNotFound(word_uuid).into());
};
Ok(word)
self.0.remove_word(word_uuid)
}

/// 他のユーザー辞書をインポートする。
pub fn import(&self, other: &Self) -> Result<()> {
for (word_uuid, word) in &*other.words.lock().unwrap() {
self.words.lock().unwrap().insert(*word_uuid, word.clone());
}
Ok(())
self.0.import(&other.0)
}

/// ユーザー辞書を保存する。
pub fn save(&self, store_path: &str) -> Result<()> {
fs_err::write(
store_path,
serde_json::to_vec(&self.words).expect("should not fail"),
)
.map_err(|e| ErrorRepr::SaveUserDict(e.into()).into())
self.0.save(store_path).block_on()
}

/// MeCabで使用する形式に変換する。
pub(crate) fn to_mecab_format(&self) -> String {
join(
self.words
.lock()
.unwrap()
.values()
.map(UserDictWord::to_mecab_format),
"\n",
)
self.0.to_mecab_format()
}
}
}

pub(crate) mod nonblocking {
use std::sync::Arc;

use indexmap::IndexMap;
use uuid::Uuid;

use crate::Result;
use crate::{asyncs::BlockingThreadPool, Result};

use super::super::word::UserDictWord;
use super::{super::word::UserDictWord, Inner};

/// ユーザー辞書。
///
Expand All @@ -136,20 +198,22 @@ pub(crate) mod nonblocking {
/// [blocking]: https://docs.rs/crate/blocking
/// [`nonblocking`モジュールのドキュメント]: crate::nonblocking
#[derive(Debug, Default)]
pub struct UserDict(Arc<super::blocking::UserDict>);
pub struct UserDict(Inner<BlockingThreadPool>);

// TODO: 引数の`path`は全部`AsRef<Path>`にする
impl self::UserDict {
/// ユーザー辞書を作成する。
pub fn new() -> Self {
Self(super::blocking::UserDict::new().into())
Default::default()
}

pub fn to_json(&self) -> String {
self.0.to_json()
}

// TODO: `&mut IndexMap<_>`を取れるようにする
pub fn with_words<R>(&self, f: impl FnOnce(&IndexMap<Uuid, UserDictWord>) -> R) -> R {
self.0.with_words(f)
self.0.with_words(|words| f(words))
}

/// ユーザー辞書をファイルから読み込む。
Expand All @@ -158,9 +222,7 @@ pub(crate) mod nonblocking {
///
/// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。
pub async fn load(&self, store_path: &str) -> Result<()> {
let blocking = self.0.clone();
let store_path = store_path.to_owned();
crate::task::asyncify(move || blocking.load(&store_path)).await
self.0.load(store_path).await
}

/// ユーザー辞書に単語を追加する。
Expand All @@ -185,9 +247,7 @@ pub(crate) mod nonblocking {

/// ユーザー辞書を保存する。
pub async fn save(&self, store_path: &str) -> Result<()> {
let blocking = self.0.clone();
let store_path = store_path.to_owned();
crate::task::asyncify(move || blocking.save(&store_path)).await
self.0.save(store_path).await
}

/// MeCabで使用する形式に変換する。
Expand Down
Loading