Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python APIのSynthesizerclose()可能にする #555

Merged
merged 10 commits into from
Aug 15, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
``Synthesizer`` について、(広義の)RAIIができることをテストする。
"""

import conftest
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
import pytest
import pytest_asyncio
from voicevox_core import OpenJtalk, Synthesizer, VoicevoxError


def test_enter_returns_workable_self(synthesizer: Synthesizer) -> None:
with synthesizer as ctx:
assert ctx is synthesizer
_ = synthesizer.metas


def test_closing_multiple_times_is_allowed(synthesizer: Synthesizer) -> None:
with synthesizer:
with synthesizer:
pass
synthesizer.close()
synthesizer.close()


def test_access_after_close_denied(synthesizer: Synthesizer) -> None:
synthesizer.close()
with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
_ = synthesizer.metas


def test_access_after_exit_denied(synthesizer: Synthesizer) -> None:
with synthesizer:
pass
with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
_ = synthesizer.metas


@pytest_asyncio.fixture
async def synthesizer(open_jtalk: OpenJtalk) -> Synthesizer:
return await Synthesizer.new_with_initialize(open_jtalk)


@pytest.fixture(scope="module")
def open_jtalk() -> OpenJtalk:
return OpenJtalk(conftest.open_jtalk_dic_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Synthesizer:
"""
...
def __repr__(self) -> str: ...
def __enter__(self) -> "Synthesizer": ...
def __exit__(self, exc_type, exc_value, traceback) -> None: ...
@property
def is_gpu_mode(self) -> bool:
"""ハードウェアアクセラレーションがGPUモードかどうか。"""
Expand Down Expand Up @@ -219,6 +221,7 @@ class Synthesizer:
:returns: WAVデータ。
"""
...
def close(self) -> None: ...

class UserDict:
"""ユーザー辞書。
Expand Down
113 changes: 85 additions & 28 deletions crates/voicevox_core_python_api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};

mod convert;
use convert::*;
Expand All @@ -9,7 +9,7 @@ use pyo3::{
exceptions::PyException,
pyclass, pyfunction, pymethods, pymodule,
types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule},
wrap_pyfunction, PyAny, PyObject, PyResult, Python, ToPyObject,
wrap_pyfunction, PyAny, PyObject, PyRef, PyResult, PyTypeInfo, Python, ToPyObject,
};
use tokio::{runtime::Runtime, sync::Mutex};
use uuid::Uuid;
Expand Down Expand Up @@ -114,7 +114,7 @@ impl OpenJtalk {

#[pyclass]
struct Synthesizer {
synthesizer: Arc<Mutex<voicevox_core::Synthesizer>>,
synthesizer: Closable<Arc<Mutex<voicevox_core::Synthesizer>>, Self>,
}

#[pymethods]
Expand Down Expand Up @@ -143,9 +143,10 @@ impl Synthesizer {
},
)
.await
.into_py_result()?;
.into_py_result()?
.into();
Ok(Self {
synthesizer: Arc::new(Mutex::new(synthesizer)),
synthesizer: Closable::new(Arc::new(synthesizer)),
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
})
})
}
Expand All @@ -154,14 +155,30 @@ impl Synthesizer {
"Synthesizer { .. }"
}

fn __enter__(slf: PyRef<'_, Self>) -> PyResult<PyRef<'_, Self>> {
slf.synthesizer.get()?;
Ok(slf)
}

fn __exit__(
&mut self,
#[allow(unused_variables)] exc_type: &PyAny,
#[allow(unused_variables)] exc_value: &PyAny,
#[allow(unused_variables)] traceback: &PyAny,
) {
self.close();
}

#[getter]
fn is_gpu_mode(&self) -> bool {
RUNTIME.block_on(self.synthesizer.lock()).is_gpu_mode()
fn is_gpu_mode(&self) -> PyResult<bool> {
let synthesizer = self.synthesizer.get()?;
Ok(RUNTIME.block_on(synthesizer.lock()).is_gpu_mode())
}

#[getter]
fn metas<'py>(&self, py: Python<'py>) -> Vec<&'py PyAny> {
to_pydantic_voice_model_meta(RUNTIME.block_on(self.synthesizer.lock()).metas(), py).unwrap()
fn metas<'py>(&self, py: Python<'py>) -> PyResult<Vec<&'py PyAny>> {
let synthesizer = self.synthesizer.get()?;
to_pydantic_voice_model_meta(RUNTIME.block_on(synthesizer.lock()).metas(), py)
}

fn load_voice_model<'py>(
Expand All @@ -170,7 +187,7 @@ impl Synthesizer {
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let model: VoiceModel = model.extract()?;
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
synthesizer
.lock()
Expand All @@ -183,15 +200,15 @@ impl Synthesizer {

fn unload_voice_model(&mut self, voice_model_id: &str) -> PyResult<()> {
RUNTIME
.block_on(self.synthesizer.lock())
.block_on(self.synthesizer.get()?.lock())
.unload_voice_model(&VoiceModelId::new(voice_model_id.to_string()))
.into_py_result()
}

fn is_loaded_voice_model(&self, voice_model_id: &str) -> bool {
RUNTIME
.block_on(self.synthesizer.lock())
.is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string()))
fn is_loaded_voice_model(&self, voice_model_id: &str) -> PyResult<bool> {
Ok(RUNTIME
.block_on(self.synthesizer.get()?.lock())
.is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string())))
}

#[pyo3(signature=(text,style_id,kana = AudioQueryOptions::default().kana))]
Expand All @@ -202,7 +219,7 @@ impl Synthesizer {
kana: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand Down Expand Up @@ -232,7 +249,7 @@ impl Synthesizer {
kana: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand Down Expand Up @@ -267,7 +284,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -282,7 +299,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -297,7 +314,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -314,7 +331,7 @@ impl Synthesizer {
enable_interrogative_upspeak: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
pyo3_asyncio::tokio::get_current_locals(py)?,
Expand Down Expand Up @@ -355,7 +372,7 @@ impl Synthesizer {
kana,
enable_interrogative_upspeak,
};
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand All @@ -371,6 +388,52 @@ impl Synthesizer {
},
)
}

fn close(&mut self) {
self.synthesizer.close()
}
}

struct Closable<T, C: PyTypeInfo> {
content: MaybeClosed<T>,
marker: PhantomData<C>,
}
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved

enum MaybeClosed<T> {
Open(T),
Closed,
}

impl<T, C: PyTypeInfo> Closable<T, C> {
fn new(content: T) -> Self {
Self {
content: MaybeClosed::Open(content),
marker: PhantomData,
}
}

fn get(&self) -> PyResult<&T> {
match &self.content {
MaybeClosed::Open(content) => Ok(content),
MaybeClosed::Closed => Err(VoicevoxError::new_err(format!(
"The `{}` is closed",
C::NAME,
))),
}
}

fn close(&mut self) {
if matches!(self.content, MaybeClosed::Open(_)) {
debug!("Closing a {}", C::NAME);
}
self.content = MaybeClosed::Closed;
}
}

impl<T, C: PyTypeInfo> Drop for Closable<T, C> {
fn drop(&mut self) {
self.close();
}
}

#[pyfunction]
Expand Down Expand Up @@ -451,9 +514,3 @@ impl UserDict {
Ok(words.into_py_dict(py))
}
}

impl Drop for Synthesizer {
fn drop(&mut self) {
debug!("Destructing a VoicevoxCore");
}
}