Skip to content

Commit

Permalink
Merge branch 'main' into onnxruntime-rs-to-ort
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed May 6, 2024
2 parents ef9c216 + e55ea58 commit 3c7529d
Show file tree
Hide file tree
Showing 30 changed files with 1,112 additions and 582 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ open_jtalk.workspace = true
ouroboros.workspace = true
rayon.workspace = true
regex.workspace = true
serde = { workspace = true, features = ["derive"] }
serde = { workspace = true, features = ["derive", "rc"] }
serde_json = { workspace = true, features = ["preserve_order"] }
serde_with.workspace = true
smallvec.workspace = true
strum = { workspace = true, features = ["derive"] }
tempfile.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt"] } # FIXME: feature-gateする
Expand Down
20 changes: 15 additions & 5 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::{
engine::{FullContextLabelError, KanaParseError},
user_dict::InvalidWordError,
StyleId, VoiceModelId,
StyleId, StyleType, VoiceModelId,
};
//use engine::
use duplicate::duplicate_item;
use std::path::PathBuf;
use itertools::Itertools as _;
use std::{collections::BTreeSet, path::PathBuf};
use thiserror::Error;
use uuid::Uuid;

Expand Down Expand Up @@ -38,6 +39,7 @@ impl Error {
LoadModelErrorKind::ReadZipEntry { .. } => ErrorKind::ReadZipEntry,
LoadModelErrorKind::ModelAlreadyLoaded { .. } => ErrorKind::ModelAlreadyLoaded,
LoadModelErrorKind::StyleAlreadyLoaded { .. } => ErrorKind::StyleAlreadyLoaded,
LoadModelErrorKind::InvalidModelFormat { .. } => ErrorKind::InvalidModelFormat,
LoadModelErrorKind::InvalidModelData => ErrorKind::InvalidModelData,
},
ErrorRepr::GetSupportedDevices(_) => ErrorKind::GetSupportedDevices,
Expand Down Expand Up @@ -70,10 +72,14 @@ pub(crate) enum ErrorRepr {
GetSupportedDevices(#[source] anyhow::Error),

#[error(
"`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\
み込みが解除されています"
"`{style_id}` ([{style_types}])に対するスタイルが見つかりませんでした。音声モデルが\
読み込まれていないか、読み込みが解除されています",
style_types = style_types.iter().format(", ")
)]
StyleNotFound { style_id: StyleId },
StyleNotFound {
style_id: StyleId,
style_types: &'static BTreeSet<StyleType>,
},

#[error(
"`{model_id}`に対する音声モデルが見つかりませんでした。読み込まれていないか、読み込みが既\
Expand Down Expand Up @@ -117,6 +123,8 @@ pub enum ErrorKind {
OpenZipFile,
/// ZIP内のファイルが読めなかった。
ReadZipEntry,
/// モデルの形式が不正。
InvalidModelFormat,
/// すでに読み込まれている音声モデルを読み込もうとした。
ModelAlreadyLoaded,
/// すでに読み込まれているスタイルを読み込もうとした。
Expand Down Expand Up @@ -165,6 +173,8 @@ pub(crate) enum LoadModelErrorKind {
OpenZipFile,
#[display(fmt = "`{filename}`を読み取れませんでした")]
ReadZipEntry { filename: String },
#[display(fmt = "モデルの形式が不正です")]
InvalidModelFormat,
#[display(fmt = "モデル`{id}`は既に読み込まれています")]
ModelAlreadyLoaded { id: VoiceModelId },
#[display(fmt = "スタイル`{id}`は既に読み込まれています")]
Expand Down
21 changes: 15 additions & 6 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
pub(crate) mod domain;
pub(crate) mod domains;
mod model_file;
pub(crate) mod runtimes;
pub(crate) mod status;
pub(crate) mod session_set;

use std::{borrow::Cow, fmt::Debug};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug};

use derive_new::new;
use duplicate::duplicate_item;
use enum_map::{Enum, EnumMap};
use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;

use crate::SupportedDevices;
use crate::{StyleType, SupportedDevices};

pub(crate) trait InferenceRuntime: 'static {
// TODO: "session"とは何なのかを定め、ドキュメントを書く。`InferenceSessionSet`も同様。
type Session: Sized + Send + 'static;
type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor;

Expand All @@ -32,9 +33,17 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

/// ある`VoiceModel`が提供する推論操作の集合を示す
pub(crate) trait InferenceDomain {
/// 共に扱われるべき推論操作の集合を示す
pub(crate) trait InferenceDomain: Sized {
type Operation: InferenceOperation;

/// 対応する`StyleType`。
///
/// 複数の`InferenceDomain`に対応する`StyleType`があってもよい。
///
/// また、どの`InferenceDomain`にも属さない`StyleType`があってもよい。そのような`StyleType`は
/// 音声モデルのロード時に単に拒否されるべきである。
fn style_types() -> &'static BTreeSet<StyleType>;
}

/// `InferenceDomain`の推論操作を表す列挙型。
Expand Down
22 changes: 22 additions & 0 deletions crates/voicevox_core/src/infer/domains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
mod talk;

pub(crate) use self::talk::{
DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput, TalkDomain, TalkOperation,
};

/// A record that stores one value per inference domain.
///
/// The concrete type of each field is selected through `V`, an implementor of
/// [`InferenceDomainMapValues`]. The only field at the moment is `talk`.
pub(crate) struct InferenceDomainMap<V>
where
    V: InferenceDomainMapValues + ?Sized,
{
    /// The value associated with the talk domain.
    pub(crate) talk: V::Talk,
}

/// Type-level description of the field types of an [`InferenceDomainMap`].
pub(crate) trait InferenceDomainMapValues {
    /// The type stored in [`InferenceDomainMap::talk`].
    type Talk;
}

/// A one-element tuple `(T,)` assigns its element type to the talk field.
impl<TalkValue> InferenceDomainMapValues for (TalkValue,) {
    type Talk = TalkValue;
}

/// A slice type `[A]` uses its element type `A` for the talk field.
impl<Element> InferenceDomainMapValues for [Element] {
    type Talk = Element;
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
use std::collections::BTreeSet;

use enum_map::Enum;
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};
use once_cell::sync::Lazy;

use crate::StyleType;

use super::{
use super::super::{
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
};

pub(crate) enum InferenceDomainImpl {}
pub(crate) enum TalkDomain {}

impl InferenceDomain for TalkDomain {
type Operation = TalkOperation;

impl InferenceDomain for InferenceDomainImpl {
type Operation = InferenceOperationImpl;
fn style_types() -> &'static BTreeSet<StyleType> {
static STYLE_TYPES: Lazy<BTreeSet<StyleType>> = Lazy::new(|| [StyleType::Talk].into());
&STYLE_TYPES
}
}

#[derive(Clone, Copy, Enum, InferenceOperation)]
#[inference_operation(
type Domain = InferenceDomainImpl;
type Domain = TalkDomain;
)]
pub(crate) enum InferenceOperationImpl {
pub(crate) enum TalkOperation {
#[inference_operation(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
Expand Down
101 changes: 101 additions & 0 deletions crates/voicevox_core/src/infer/session_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use std::{collections::HashMap, fmt::Display, marker::PhantomData, sync::Arc};

use anyhow::bail;
use enum_map::{Enum as _, EnumMap};
use itertools::Itertools as _;

use crate::error::ErrorRepr;

use super::{
model_file, InferenceDomain, InferenceInputSignature, InferenceOperation, InferenceRuntime,
InferenceSessionOptions, InferenceSignature, ParamInfo,
};

/// A collection of inference sessions of the runtime `R`, holding exactly one
/// session for each operation (`D::Operation`) of the inference domain `D`.
pub(crate) struct InferenceSessionSet<R: InferenceRuntime, D: InferenceDomain>(
// Each session is wrapped in `Arc<Mutex<_>>`, so a handle to it can be cloned
// out of the map while access to the session itself stays exclusive.
EnumMap<D::Operation, Arc<std::sync::Mutex<R::Session>>>,
);

impl<R: InferenceRuntime, D: InferenceDomain> InferenceSessionSet<R, D> {
pub(crate) fn new(
model_bytes: &EnumMap<D::Operation, Vec<u8>>,
options: &EnumMap<D::Operation, InferenceSessionOptions>,
) -> anyhow::Result<Self> {
let mut sessions = model_bytes
.iter()
.map(|(op, model_bytes)| {
let (expected_input_param_infos, expected_output_param_infos) =
<D::Operation as InferenceOperation>::PARAM_INFOS[op];

let (sess, actual_input_param_infos, actual_output_param_infos) =
R::new_session(|| model_file::decrypt(model_bytes), options[op])?;

check_param_infos(expected_input_param_infos, &actual_input_param_infos)?;
check_param_infos(expected_output_param_infos, &actual_output_param_infos)?;

Ok((op.into_usize(), std::sync::Mutex::new(sess).into()))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;

return Ok(Self(EnumMap::<D::Operation, _>::from_fn(|k| {
sessions.remove(&k.into_usize()).expect("should exist")
})));

fn check_param_infos<D: PartialEq + Display>(
expected: &[ParamInfo<D>],
actual: &[ParamInfo<D>],
) -> anyhow::Result<()> {
if !(expected.len() == actual.len()
&& itertools::zip_eq(expected, actual)
.all(|(expected, actual)| expected.accepts(actual)))
{
let expected = display_param_infos(expected);
let actual = display_param_infos(actual);
bail!("expected {{{expected}}}, got {{{actual}}}")
}
Ok(())
}

fn display_param_infos(infos: &[ParamInfo<impl Display>]) -> impl Display {
infos
.iter()
.map(|ParamInfo { name, dt, ndim }| {
let brackets = match *ndim {
Some(ndim) => "[]".repeat(ndim),
None => "[]...".to_owned(),
};
format!("{name}: {dt}{brackets}")
})
.join(", ")
}
}
}

impl<R: InferenceRuntime, D: InferenceDomain> InferenceSessionSet<R, D> {
pub(crate) fn get<I>(&self) -> InferenceSessionCell<R, I>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<Domain = D>,
{
InferenceSessionCell {
inner: self.0[I::Signature::OPERATION].clone(),
marker: PhantomData,
}
}
}

/// A handle to a single inference session of the runtime `R`, tagged with the
/// input signature type `I`.
pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
// The session, shared through `Arc` and guarded by a mutex.
inner: Arc<std::sync::Mutex<R::Session>>,
// Carries `I` at the type level only; no value of type `I` is stored.
marker: PhantomData<fn(I)>,
}

impl<R: InferenceRuntime, I: InferenceInputSignature> InferenceSessionCell<R, I> {
    /// Runs the session with `input` and converts the output tensors into the
    /// output type of `I`'s signature.
    ///
    /// The session's mutex is held for the whole call.
    ///
    /// # Errors
    ///
    /// A failure while building the run context, running the session, or
    /// converting the outputs is wrapped in `ErrorRepr::InferenceFailed` and
    /// then converted into the crate's error type.
    ///
    /// # Panics
    ///
    /// Panics if the session's mutex is poisoned.
    pub(crate) fn run(
        self,
        input: I,
    ) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
        let mut guard = self.inner.lock().unwrap();
        let session = &mut *guard;
        // The immediately-invoked closure scopes the `?` operators so that
        // every failure can be wrapped uniformly below.
        (|| {
            let ctx = input.make_run_context::<R>(session)?;
            let outputs = R::run(ctx)?;
            outputs.try_into()
        })()
        .map_err(ErrorRepr::InferenceFailed)
        .map_err(Into::into)
    }
}
Loading

0 comments on commit 3c7529d

Please sign in to comment.