Skip to content

Commit

Permalink
Java APIでもRwLock管理にする
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Sep 17, 2024
1 parent 695c0a9 commit d69571a
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 102 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/voicevox_core/src/__internal/interop.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod raii;

pub use crate::{
metas::merge as merge_metas, synthesizer::blocking::PerformInference,
voice_model::blocking::IdRef,
Expand Down
43 changes: 43 additions & 0 deletions crates/voicevox_core/src/__internal/interop/raii.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{marker::PhantomData, ops::Deref};

use ouroboros::self_referencing;

pub enum MaybeClosed<T> {
Open(T),
Closed,
}

// [`mapped_lock_guards`]のようなことをやるためのユーティリティ。
//
// [`mapped_lock_guards`]: https://github.com/rust-lang/rust/issues/117108
pub fn try_map_guard<'a, G, F, T, E>(guard: G, f: F) -> Result<impl Deref<Target = T> + 'a, E>
where
G: 'a,
F: FnOnce(&G) -> Result<&T, E>,
T: 'static,
{
return MappedLockTryBuilder {
guard,
content_builder: f,
marker: PhantomData,
}
.try_build();

#[self_referencing]
struct MappedLock<'a, G: 'a, T: 'static> {
guard: G,

#[borrows(guard)]
content: &'this T,

marker: PhantomData<&'a ()>,
}

impl<'a, G: 'a, T: 'static> Deref for MappedLock<'a, G, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
self.borrow_content()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ public VoiceModelFile.SpeakerMeta[] metas() {
* @throws InvalidModelDataException 無効なモデルデータの場合。
*/
public void loadVoiceModel(VoiceModelFile voiceModel) throws InvalidModelDataException {
synchronized (voiceModel) {
rsLoadVoiceModel(voiceModel.opened());
}
rsLoadVoiceModel(voiceModel);
}

/**
Expand Down Expand Up @@ -287,8 +285,7 @@ public TtsConfigurator tts(String text, int styleId) {
@Nonnull
private native String rsGetMetasJson();

private native void rsLoadVoiceModel(VoiceModelFile.Opened voiceModel)
throws InvalidModelDataException;
private native void rsLoadVoiceModel(VoiceModelFile voiceModel) throws InvalidModelDataException;

private native void rsUnloadVoiceModel(UUID voiceModelId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import java.io.Closeable;
import java.util.Optional;
import java.util.UUID;

/** 音声モデルファイル。 */
public class VoiceModelFile extends Dll implements Closeable {
private long handle;

/** ID。 */
@Nonnull public final UUID id;

/** メタ情報。 */
@Nonnull public final SpeakerMeta[] metas;

@Nullable private Opened inner;

public VoiceModelFile(String modelPath) {
inner = new Opened(modelPath);
id = inner.rsGetId();
String metasJson = inner.rsGetMetasJson();
rsOpen(modelPath);
id = rsGetId();
String metasJson = rsGetMetasJson();
Gson gson = new Gson();
SpeakerMeta[] rawMetas = gson.fromJson(metasJson, SpeakerMeta[].class);
if (rawMetas == null) {
Expand All @@ -31,48 +30,28 @@ public VoiceModelFile(String modelPath) {
metas = rawMetas;
}

// `synchronized`は`Synthesizer`側でやる
Opened opened() {
if (inner == null) {
throw new IllegalStateException("this `VoiceModelFile` is closed");
}
return inner;
@Override
public void close() {
rsClose();
}

@Override
public synchronized void close() {
Optional<Opened> inner = Optional.ofNullable(this.inner);
this.inner = null;
if (inner.isPresent()) {
inner.get().rsDrop();
}
protected void finalize() throws Throwable {
rsDrop();
super.finalize();
}

static class Opened {
private long handle;

private Opened(String modelPath) {
rsOpen(modelPath);
}

@Override
protected void finalize() throws Throwable {
if (handle != 0) {
rsDrop();
}
super.finalize();
}
private native void rsOpen(String modelPath);

private native void rsOpen(String modelPath);
@Nonnull
private native UUID rsGetId();

@Nonnull
private native UUID rsGetId();
@Nonnull
private native String rsGetMetasJson();

@Nonnull
private native String rsGetMetasJson();
private native void rsClose();

private native void rsDrop();
}
private native void rsDrop();

/** 話者(speaker)のメタ情報。 */
public static class SpeakerMeta {
Expand Down
72 changes: 71 additions & 1 deletion crates/voicevox_core_java_api/src/common.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::{error::Error as _, iter};
use std::{error::Error as _, iter, mem, ops::Deref};

use derive_more::From;
use easy_ext::ext;
use jni::{
objects::{JObject, JThrowable},
JNIEnv,
};
use tracing::{debug, warn};
use uuid::Uuid;
use voicevox_core::__internal::interop::raii::MaybeClosed;

#[macro_export]
macro_rules! object {
Expand Down Expand Up @@ -154,13 +156,18 @@ where
env.throw_new("java/lang/IllegalArgumentException", error.to_string())
)
}
JavaApiError::IllegalState(msg) => {
or_panic!(env.throw_new("java/lang/IllegalStateException", msg))
}
};
}
fallback
}
}
}

type JavaApiResult<T> = Result<T, JavaApiError>;

#[derive(From, Debug)]
pub(crate) enum JavaApiError {
#[from]
Expand All @@ -173,6 +180,69 @@ pub(crate) enum JavaApiError {
Uuid(uuid::Error),

DeJson(serde_json::Error),

IllegalState(String),
}

pub(crate) struct Closable<T: HasJavaClassIdent>(std::sync::RwLock<MaybeClosed<T>>);

impl<T: HasJavaClassIdent + 'static> Closable<T> {
pub(crate) fn new(content: T) -> Self {
Self(MaybeClosed::Open(content).into())
}

pub(crate) fn read(&self) -> JavaApiResult<impl Deref<Target = T> + '_> {
let lock = self.0.try_read().map_err(|e| match e {
std::sync::TryLockError::Poisoned(e) => panic!("{e}"),
std::sync::TryLockError::WouldBlock => {
JavaApiError::IllegalState(format!("The `{}` is being closed", T::JAVA_CLASS_IDENT))
}
})?;

voicevox_core::__internal::interop::raii::try_map_guard(lock, |lock| match &**lock {
MaybeClosed::Open(content) => Ok(content),
MaybeClosed::Closed => Err(JavaApiError::IllegalState(format!(
"The `{}` is closed",
T::JAVA_CLASS_IDENT,
))),
})
}

pub(crate) fn close(&self) {
let lock = &mut *match self.0.try_write() {
Ok(lock) => lock,
Err(std::sync::TryLockError::Poisoned(e)) => panic!("{e}"),
Err(std::sync::TryLockError::WouldBlock) => {
self.0.write().unwrap_or_else(|e| panic!("{e}"))
}
};

if matches!(*lock, MaybeClosed::Open(_)) {
debug!("Closing a `{}`", T::JAVA_CLASS_IDENT);
}
drop(mem::replace(lock, MaybeClosed::Closed));
}
}

impl<T: HasJavaClassIdent> Drop for Closable<T> {
fn drop(&mut self) {
let content = mem::replace(
&mut *self.0.write().unwrap_or_else(|e| panic!("{e}")),
MaybeClosed::Closed,
);
if let MaybeClosed::Open(content) = content {
warn!(
"デストラクタにより`{}`のクローズを行います。通常は、可能な限り`close`でクローズす\
るようにして下さい",
T::JAVA_CLASS_IDENT,
);
drop(content);
}
}
}

pub(crate) trait HasJavaClassIdent {
const JAVA_CLASS_IDENT: &str;
}

#[ext(JNIEnvExt)]
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core_java_api/src/synthesizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_Synthesizer_rsLoadVoice
) {
throw_if_err(env, (), |env| {
let model = env
.get_rust_field::<_, _, Arc<voicevox_core::blocking::VoiceModelFile>>(&model, "handle")?
.get_rust_field::<_, _, Arc<crate::voice_model::VoiceModelFile>>(&model, "handle")?
.clone();
let model = model.read()?;
let internal = env
.get_rust_field::<_, _, Arc<voicevox_core::blocking::Synthesizer<voicevox_core::blocking::OpenJtalk>>>(
&this, "handle",
Expand Down
46 changes: 29 additions & 17 deletions crates/voicevox_core_java_api/src/voice_model.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::{borrow::Cow, sync::Arc};

use crate::common::{throw_if_err, JNIEnvExt as _};
use crate::common::{throw_if_err, Closable, HasJavaClassIdent, JNIEnvExt as _};
use jni::{
objects::{JObject, JString},
sys::jobject,
JNIEnv,
};

pub(crate) type VoiceModelFile = Closable<voicevox_core::blocking::VoiceModelFile>;

impl HasJavaClassIdent for voicevox_core::blocking::VoiceModelFile {
const JAVA_CLASS_IDENT: &str = "VoiceModelFile";
}

#[no_mangle]
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Opened_rsOpen<
'local,
>(
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_rsOpen<'local>(
env: JNIEnv<'local>,
this: JObject<'local>,
model_path: JString<'local>,
Expand All @@ -20,24 +24,23 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Ope
let model_path = &*Cow::from(&model_path);

let internal = voicevox_core::blocking::VoiceModelFile::open(model_path)?;

env.set_rust_field(&this, "handle", Arc::new(internal))?;
let internal = Arc::new(Closable::new(internal));
env.set_rust_field(&this, "handle", internal)?;

Ok(())
})
}

#[no_mangle]
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Opened_rsGetId<
'local,
>(
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_rsGetId<'local>(
env: JNIEnv<'local>,
this: JObject<'local>,
) -> jobject {
throw_if_err(env, std::ptr::null_mut(), |env| {
let internal = env
.get_rust_field::<_, _, Arc<voicevox_core::blocking::VoiceModelFile>>(&this, "handle")?
.get_rust_field::<_, _, Arc<VoiceModelFile>>(&this, "handle")?
.clone();
let internal = internal.read()?;

let id = env.new_uuid(internal.id().raw_voice_model_id())?;

Expand All @@ -46,16 +49,15 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Ope
}

#[no_mangle]
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Opened_rsGetMetasJson<
'local,
>(
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_rsGetMetasJson<'local>(
env: JNIEnv<'local>,
this: JObject<'local>,
) -> jobject {
throw_if_err(env, std::ptr::null_mut(), |env| {
let internal = env
.get_rust_field::<_, _, Arc<voicevox_core::blocking::VoiceModelFile>>(&this, "handle")?
.get_rust_field::<_, _, Arc<VoiceModelFile>>(&this, "handle")?
.clone();
let internal = internal.read()?;

let metas = internal.metas();
let metas_json = serde_json::to_string(&metas).expect("should not fail");
Expand All @@ -64,9 +66,19 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Ope
}

#[no_mangle]
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_00024Opened_rsDrop<
'local,
>(
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_rsClose<'local>(
env: JNIEnv<'local>,
this: JObject<'local>,
) {
throw_if_err(env, (), |env| {
env.take_rust_field::<_, _, Arc<VoiceModelFile>>(&this, "handle")?
.close();
Ok(())
})
}

#[no_mangle]
unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_VoiceModelFile_rsDrop<'local>(
env: JNIEnv<'local>,
this: JObject<'local>,
) {
Expand Down
1 change: 0 additions & 1 deletion crates/voicevox_core_python_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ easy-ext.workspace = true
futures-lite.workspace = true
log.workspace = true
once_cell.workspace = true
ouroboros.workspace = true
pyo3 = { workspace = true, features = ["abi3-py38", "extension-module"] }
pyo3-asyncio = { workspace = true, features = ["tokio-runtime"] }
pyo3-log.workspace = true
Expand Down
Loading

0 comments on commit d69571a

Please sign in to comment.