Skip to content

Commit

Permalink
Python APIのSynthesizerclose()可能にする (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Aug 15, 2023
1 parent 263beb4 commit 2ffd87e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
``Synthesizer`` について、(広義の)RAIIができることをテストする。
"""

import conftest
import pytest
import pytest_asyncio
from voicevox_core import OpenJtalk, Synthesizer, VoicevoxError


def test_enter_returns_workable_self(synthesizer: Synthesizer) -> None:
with synthesizer as ctx:
assert ctx is synthesizer
_ = synthesizer.metas


def test_closing_multiple_times_is_allowed(synthesizer: Synthesizer) -> None:
with synthesizer:
with synthesizer:
pass
synthesizer.close()
synthesizer.close()


def test_access_after_close_denied(synthesizer: Synthesizer) -> None:
synthesizer.close()
with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
_ = synthesizer.metas


def test_access_after_exit_denied(synthesizer: Synthesizer) -> None:
with synthesizer:
pass
with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
_ = synthesizer.metas


@pytest_asyncio.fixture
async def synthesizer(open_jtalk: OpenJtalk) -> Synthesizer:
return await Synthesizer.new_with_initialize(open_jtalk)


@pytest.fixture(scope="module")
def open_jtalk() -> OpenJtalk:
return OpenJtalk(conftest.open_jtalk_dic_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Synthesizer:
"""
...
def __repr__(self) -> str: ...
def __enter__(self) -> "Synthesizer": ...
def __exit__(self, exc_type, exc_value, traceback) -> None: ...
@property
def is_gpu_mode(self) -> bool:
"""ハードウェアアクセラレーションがGPUモードかどうか。"""
Expand Down Expand Up @@ -219,6 +221,7 @@ class Synthesizer:
:returns: WAVデータ。
"""
...
def close(self) -> None: ...

class UserDict:
"""ユーザー辞書。
Expand Down
113 changes: 85 additions & 28 deletions crates/voicevox_core_python_api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};

mod convert;
use convert::*;
Expand All @@ -9,7 +9,7 @@ use pyo3::{
exceptions::PyException,
pyclass, pyfunction, pymethods, pymodule,
types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule},
wrap_pyfunction, PyAny, PyObject, PyResult, Python, ToPyObject,
wrap_pyfunction, PyAny, PyObject, PyRef, PyResult, PyTypeInfo, Python, ToPyObject,
};
use tokio::{runtime::Runtime, sync::Mutex};
use uuid::Uuid;
Expand Down Expand Up @@ -114,7 +114,7 @@ impl OpenJtalk {

#[pyclass]
struct Synthesizer {
synthesizer: Arc<Mutex<voicevox_core::Synthesizer>>,
synthesizer: Closable<Arc<Mutex<voicevox_core::Synthesizer>>, Self>,
}

#[pymethods]
Expand Down Expand Up @@ -143,9 +143,10 @@ impl Synthesizer {
},
)
.await
.into_py_result()?;
.into_py_result()?
.into();
Ok(Self {
synthesizer: Arc::new(Mutex::new(synthesizer)),
synthesizer: Closable::new(Arc::new(synthesizer)),
})
})
}
Expand All @@ -154,14 +155,30 @@ impl Synthesizer {
"Synthesizer { .. }"
}

fn __enter__(slf: PyRef<'_, Self>) -> PyResult<PyRef<'_, Self>> {
slf.synthesizer.get()?;
Ok(slf)
}

fn __exit__(
&mut self,
#[allow(unused_variables)] exc_type: &PyAny,
#[allow(unused_variables)] exc_value: &PyAny,
#[allow(unused_variables)] traceback: &PyAny,
) {
self.close();
}

#[getter]
fn is_gpu_mode(&self) -> bool {
RUNTIME.block_on(self.synthesizer.lock()).is_gpu_mode()
fn is_gpu_mode(&self) -> PyResult<bool> {
let synthesizer = self.synthesizer.get()?;
Ok(RUNTIME.block_on(synthesizer.lock()).is_gpu_mode())
}

#[getter]
fn metas<'py>(&self, py: Python<'py>) -> Vec<&'py PyAny> {
to_pydantic_voice_model_meta(RUNTIME.block_on(self.synthesizer.lock()).metas(), py).unwrap()
fn metas<'py>(&self, py: Python<'py>) -> PyResult<Vec<&'py PyAny>> {
let synthesizer = self.synthesizer.get()?;
to_pydantic_voice_model_meta(RUNTIME.block_on(synthesizer.lock()).metas(), py)
}

fn load_voice_model<'py>(
Expand All @@ -170,7 +187,7 @@ impl Synthesizer {
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let model: VoiceModel = model.extract()?;
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
synthesizer
.lock()
Expand All @@ -183,15 +200,15 @@ impl Synthesizer {

fn unload_voice_model(&mut self, voice_model_id: &str) -> PyResult<()> {
RUNTIME
.block_on(self.synthesizer.lock())
.block_on(self.synthesizer.get()?.lock())
.unload_voice_model(&VoiceModelId::new(voice_model_id.to_string()))
.into_py_result()
}

fn is_loaded_voice_model(&self, voice_model_id: &str) -> bool {
RUNTIME
.block_on(self.synthesizer.lock())
.is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string()))
fn is_loaded_voice_model(&self, voice_model_id: &str) -> PyResult<bool> {
Ok(RUNTIME
.block_on(self.synthesizer.get()?.lock())
.is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string())))
}

#[pyo3(signature=(text,style_id,kana = AudioQueryOptions::default().kana))]
Expand All @@ -202,7 +219,7 @@ impl Synthesizer {
kana: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand Down Expand Up @@ -232,7 +249,7 @@ impl Synthesizer {
kana: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand Down Expand Up @@ -267,7 +284,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -282,7 +299,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -297,7 +314,7 @@ impl Synthesizer {
style_id: u32,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
modify_accent_phrases(
accent_phrases,
StyleId::new(style_id),
Expand All @@ -314,7 +331,7 @@ impl Synthesizer {
enable_interrogative_upspeak: bool,
py: Python<'py>,
) -> PyResult<&'py PyAny> {
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
pyo3_asyncio::tokio::get_current_locals(py)?,
Expand Down Expand Up @@ -355,7 +372,7 @@ impl Synthesizer {
kana,
enable_interrogative_upspeak,
};
let synthesizer = self.synthesizer.clone();
let synthesizer = self.synthesizer.get()?.clone();
let text = text.to_owned();
pyo3_asyncio::tokio::future_into_py_with_locals(
py,
Expand All @@ -371,6 +388,52 @@ impl Synthesizer {
},
)
}

fn close(&mut self) {
self.synthesizer.close()
}
}

struct Closable<T, C: PyTypeInfo> {
content: MaybeClosed<T>,
marker: PhantomData<C>,
}

enum MaybeClosed<T> {
Open(T),
Closed,
}

impl<T, C: PyTypeInfo> Closable<T, C> {
fn new(content: T) -> Self {
Self {
content: MaybeClosed::Open(content),
marker: PhantomData,
}
}

fn get(&self) -> PyResult<&T> {
match &self.content {
MaybeClosed::Open(content) => Ok(content),
MaybeClosed::Closed => Err(VoicevoxError::new_err(format!(
"The `{}` is closed",
C::NAME,
))),
}
}

fn close(&mut self) {
if matches!(self.content, MaybeClosed::Open(_)) {
debug!("Closing a {}", C::NAME);
}
self.content = MaybeClosed::Closed;
}
}

impl<T, C: PyTypeInfo> Drop for Closable<T, C> {
fn drop(&mut self) {
self.close();
}
}

#[pyfunction]
Expand Down Expand Up @@ -451,9 +514,3 @@ impl UserDict {
Ok(words.into_py_dict(py))
}
}

impl Drop for Synthesizer {
fn drop(&mut self) {
debug!("Destructing a VoicevoxCore");
}
}

0 comments on commit 2ffd87e

Please sign in to comment.