Skip to content

Commit

Permalink
Merge branch 'main' into onnxruntime-rs-to-ort
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed May 21, 2024
2 parents 93f7885 + 18aec9f commit d6a7bb0
Show file tree
Hide file tree
Showing 27 changed files with 116 additions and 119 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ jobs:
- name: Exampleを実行
run: |
for file in ../../example/python/run{,-asyncio}.py; do
poetry run python "$file" ../../model/sample.vvm --dict-dir ../test_util/data/open_jtalk_dic_utf_8-1.11
poetry run python "$file" ../test_util/data/model/sample.vvm --dict-dir ../test_util/data/open_jtalk_dic_utf_8-1.11
done
build-and-test-java-api:
strategy:
Expand Down
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ heck = "0.4.1"
humansize = "2.1.2"
indexmap = "2.0.0"
indicatif = "0.17.3"
indoc = "2.0.4"
inventory = "0.3.4"
itertools = "0.10.5"
jlabel = "0.1.2"
Expand Down
5 changes: 2 additions & 3 deletions crates/test_util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ name = "test_util"
edition.workspace = true

[dependencies]
async_zip = { workspace = true, features = ["deflate"] }
futures-lite.workspace = true
libloading.workspace = true
once_cell.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
tokio = { workspace = true, features = ["fs", "io-util", "sync"] }

[build-dependencies]
anyhow.workspace = true
Expand All @@ -18,10 +15,12 @@ bindgen.workspace = true
camino.workspace = true
flate2.workspace = true
fs-err.workspace = true
indoc.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["preserve_order"] }
reqwest = { workspace = true, features = ["rustls-tls"] }
tar.workspace = true
zip.workspace = true

[lints.rust]
unsafe_code = "allow" # C APIのbindgen
Expand Down
81 changes: 70 additions & 11 deletions crates/test_util/build.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::{
env,
path::{Path, PathBuf},
io::{self, Cursor, Write as _},
path::Path,
};

use anyhow::ensure;
use camino::Utf8PathBuf;
use anyhow::{anyhow, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use flate2::read::GzDecoder;
use indoc::formatdoc;
use tar::Archive;
use zip::{write::FileOptions, ZipWriter};

#[path = "src/typing.rs"]
mod typing;
Expand All @@ -15,21 +18,78 @@ const DIC_DIR_NAME: &str = "open_jtalk_dic_utf_8-1.11";

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut dist = PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").unwrap());
dist.push("data");
let out_dir = &Utf8PathBuf::from(env::var("OUT_DIR").unwrap());
let dist = &Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("data");

let dic_dir = dist.join(DIC_DIR_NAME);
if !dic_dir.try_exists()? {
download_open_jtalk_dict(&dist).await?;
ensure!(dic_dir.exists(), "`{}` does not exist", dic_dir.display());
download_open_jtalk_dict(dist.as_ref()).await?;
ensure!(dic_dir.exists(), "`{dic_dir}` does not exist");
}

generate_example_data_json(&dist)?;
create_sample_voice_model_file(out_dir, dist)?;

generate_example_data_json(dist.as_ref())?;

println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/typing.rs");

generate_c_api_rs_bindings()
generate_c_api_rs_bindings(out_dir)
}

fn create_sample_voice_model_file(out_dir: &Utf8Path, dist: &Utf8Path) -> anyhow::Result<()> {
const SRC: &str = "../../model/sample.vvm";

let files = fs_err::read_dir(SRC)?
.map(|entry| {
let entry = entry?;
let md = entry.metadata()?;
ensure!(!md.is_dir(), "directory in {SRC}");
let mtime = md.modified()?;
let name = entry
.file_name()
.into_string()
.map_err(|name| anyhow!("{name:?}"))?;
Ok((name, entry.path(), mtime))
})
.collect::<anyhow::Result<Vec<_>>>()?;

let output_dir = &dist.join("model");
let output_file = &output_dir.join("sample.vvm");

let up_to_date = fs_err::metadata(output_file)
.and_then(|md| md.modified())
.map(|t1| files.iter().all(|&(_, _, t2)| t1 >= t2));
let up_to_date = match up_to_date {
Ok(p) => p,
Err(e) if e.kind() == io::ErrorKind::NotFound => false,
Err(e) => return Err(e.into()),
};

if !up_to_date {
let mut zip = ZipWriter::new(Cursor::new(vec![]));
for (name, path, _) in files {
let content = &fs_err::read(path)?;
zip.start_file(name, FileOptions::default().compression_level(Some(0)))?;
zip.write_all(content)?;
}
let zip = zip.finish()?;
fs_err::create_dir_all(output_dir)?;
fs_err::write(output_file, zip.get_ref())?;
}

fs_err::write(
out_dir.join("sample_voice_model_file.rs"),
formatdoc! {"
pub const SAMPLE_VOICE_MODEL_FILE_PATH: &::std::primitive::str = {output_file:?};
const SAMPLE_VOICE_MODEL_FILE_C_PATH: &::std::ffi::CStr = c{output_file:?};
const VV_MODELS_ROOT_DIR: &::std::primitive::str = {output_dir:?};
",
},
)?;
println!("cargo:rerun-if-changed={SRC}");
Ok(())
}

/// OpenJTalkの辞書をダウンロードして展開する。
Expand Down Expand Up @@ -120,11 +180,10 @@ fn generate_example_data_json(dist: &Path) -> anyhow::Result<()> {
Ok(())
}

fn generate_c_api_rs_bindings() -> anyhow::Result<()> {
fn generate_c_api_rs_bindings(out_dir: &Utf8Path) -> anyhow::Result<()> {
static C_BINDINGS_PATH: &str = "../voicevox_core_c_api/include/voicevox_core.h";
static ADDITIONAL_C_BINDINGS_PATH: &str = "./compatible_engine.h";

let out_dir = Utf8PathBuf::from(env::var("OUT_DIR").unwrap());
bindgen::Builder::default()
.header(C_BINDINGS_PATH)
.header(ADDITIONAL_C_BINDINGS_PATH)
Expand Down
56 changes: 5 additions & 51 deletions crates/test_util/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mod typing;

include!(concat!(env!("OUT_DIR"), "/sample_voice_model_file.rs"));

#[allow(
non_camel_case_types,
non_snake_case,
Expand All @@ -10,20 +12,12 @@ mod typing;
)]
pub mod c_api {
include!(concat!(env!("OUT_DIR"), "/c_api.rs"));

pub const SAMPLE_VOICE_MODEL_FILE_PATH: &std::ffi::CStr = super::SAMPLE_VOICE_MODEL_FILE_C_PATH;
pub const VV_MODELS_ROOT_DIR: &str = super::VV_MODELS_ROOT_DIR;
}

use async_zip::{base::write::ZipFileWriter, Compression, ZipEntryBuilder};
use futures_lite::AsyncWriteExt as _;
use once_cell::sync::Lazy;
use std::{
collections::HashMap,
path::{Path, PathBuf},
};
use tokio::{
fs::{self, File},
io::AsyncReadExt,
sync::Mutex,
};

pub use self::typing::{
DecodeExampleData, DurationExampleData, ExampleData, IntonationExampleData,
Expand All @@ -42,43 +36,3 @@ const EXAMPLE_DATA_JSON: &str = include_str!(concat!(
pub static EXAMPLE_DATA: Lazy<ExampleData> = Lazy::new(|| {
serde_json::from_str(EXAMPLE_DATA_JSON).expect("failed to parse example_data.json")
});

static PATH_MUTEX: Lazy<Mutex<HashMap<PathBuf, Mutex<()>>>> =
Lazy::new(|| Mutex::new(HashMap::default()));

pub async fn convert_zip_vvm(dir: impl AsRef<Path>) -> PathBuf {
let dir = dir.as_ref();
let output_file_name = dir.file_name().unwrap().to_str().unwrap().to_owned() + ".vvm";

let out_file_path = PathBuf::from(env!("OUT_DIR"))
.join("test_data/models/")
.join(output_file_name);
let mut path_map = PATH_MUTEX.lock().await;
if !path_map.contains_key(&out_file_path) {
path_map.insert(out_file_path.clone(), Mutex::new(()));
}
let _m = path_map.get(&out_file_path).unwrap().lock().await;

if !out_file_path.exists() {
fs::create_dir_all(out_file_path.parent().unwrap())
.await
.unwrap();
let mut writer = ZipFileWriter::new(vec![]);

for entry in dir.read_dir().unwrap().flatten() {
let entry_builder = ZipEntryBuilder::new(
entry.path().file_name().unwrap().to_str().unwrap().into(),
Compression::Deflate,
);
let mut entry_writer = writer.write_entry_stream(entry_builder).await.unwrap();
let mut file = File::open(entry.path()).await.unwrap();
let mut buf = Vec::with_capacity(entry.metadata().unwrap().len() as usize);
file.read_to_end(&mut buf).await.unwrap();
entry_writer.write_all(&buf).await.unwrap();
entry_writer.close().await.unwrap();
}
let zip = writer.close().await.unwrap();
fs::write(&out_file_path, zip).await.unwrap();
}
out_file_path
}
9 changes: 4 additions & 5 deletions crates/voicevox_core/src/__internal/doctest_fixtures.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::path::Path;

use camino::Utf8Path;

use crate::{AccelerationMode, InitializeOptions};

pub async fn synthesizer_with_sample_voice_model(
voice_model_path: impl AsRef<Path>,
open_jtalk_dic_dir: impl AsRef<Utf8Path>,
) -> anyhow::Result<crate::tokio::Synthesizer<crate::tokio::OpenJtalk>> {
let syntesizer = crate::tokio::Synthesizer::new(
Expand All @@ -13,11 +16,7 @@ pub async fn synthesizer_with_sample_voice_model(
},
)?;

let model = &crate::tokio::VoiceModel::from_path(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../model/sample.vvm",
))
.await?;
let model = &crate::tokio::VoiceModel::from_path(voice_model_path).await?;
syntesizer.load_voice_model(model).await?;

Ok(syntesizer)
Expand Down
5 changes: 2 additions & 3 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ mod tests {
},
macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl,
test_util::open_default_vvm_file,
};

use super::Status;
Expand Down Expand Up @@ -399,7 +398,7 @@ mod tests {
let status = Status::<InferenceRuntimeImpl>::new(InferenceDomainMap {
talk: enum_map!(_ => InferenceSessionOptions::new(0, false)),
});
let model = &open_default_vvm_file().await;
let model = &crate::tokio::VoiceModel::sample().await.unwrap();
let model_contents = &model.read_inference_models().await.unwrap();
let result = status.insert_model(model.header(), model_contents);
assert_debug_fmt_eq!(Ok(()), result);
Expand All @@ -412,7 +411,7 @@ mod tests {
let status = Status::<InferenceRuntimeImpl>::new(InferenceDomainMap {
talk: enum_map!(_ => InferenceSessionOptions::new(0, false)),
});
let vvm = open_default_vvm_file().await;
let vvm = &crate::tokio::VoiceModel::sample().await.unwrap();
let model_header = vvm.header();
let model_contents = &vvm.read_inference_models().await.unwrap();
assert!(
Expand Down
17 changes: 10 additions & 7 deletions crates/voicevox_core/src/synthesizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ pub(crate) mod blocking {
/// # async fn main() -> anyhow::Result<()> {
/// # let synthesizer =
/// # voicevox_core::__internal::doctest_fixtures::synthesizer_with_sample_voice_model(
/// # test_util::SAMPLE_VOICE_MODEL_FILE_PATH,
/// # test_util::OPEN_JTALK_DIC_DIR,
/// # )
/// # .await?;
Expand Down Expand Up @@ -682,6 +683,7 @@ pub(crate) mod blocking {
/// # async fn main() -> anyhow::Result<()> {
/// # let synthesizer =
/// # voicevox_core::__internal::doctest_fixtures::synthesizer_with_sample_voice_model(
/// # test_util::SAMPLE_VOICE_MODEL_FILE_PATH,
/// # test_util::OPEN_JTALK_DIC_DIR,
/// # )
/// # .await?;
Expand Down Expand Up @@ -730,6 +732,7 @@ pub(crate) mod blocking {
/// # async fn main() -> anyhow::Result<()> {
/// # let synthesizer =
/// # voicevox_core::__internal::doctest_fixtures::synthesizer_with_sample_voice_model(
/// # test_util::SAMPLE_VOICE_MODEL_FILE_PATH,
/// # test_util::OPEN_JTALK_DIC_DIR,
/// # )
/// # .await?;
Expand Down Expand Up @@ -762,6 +765,7 @@ pub(crate) mod blocking {
/// # async fn main() -> anyhow::Result<()> {
/// # let synthesizer =
/// # voicevox_core::__internal::doctest_fixtures::synthesizer_with_sample_voice_model(
/// # test_util::SAMPLE_VOICE_MODEL_FILE_PATH,
/// # test_util::OPEN_JTALK_DIC_DIR,
/// # )
/// # .await?;
Expand Down Expand Up @@ -1291,8 +1295,7 @@ mod tests {

use super::{blocking::PerformInference as _, AccelerationMode, InitializeOptions};
use crate::{
engine::MoraModel, macros::tests::assert_debug_fmt_eq, test_util::open_default_vvm_file,
AccentPhraseModel, Result, StyleId,
engine::MoraModel, macros::tests::assert_debug_fmt_eq, AccentPhraseModel, Result, StyleId,
};
use ::test_util::OPEN_JTALK_DIC_DIR;
use rstest::rstest;
Expand All @@ -1311,7 +1314,7 @@ mod tests {
.unwrap();

let result = syntesizer
.load_voice_model(&open_default_vvm_file().await)
.load_voice_model(&crate::tokio::VoiceModel::sample().await.unwrap())
.await;

assert_debug_fmt_eq!(
Expand Down Expand Up @@ -1353,7 +1356,7 @@ mod tests {
"expected is_model_loaded to return false, but got true",
);
syntesizer
.load_voice_model(&open_default_vvm_file().await)
.load_voice_model(&crate::tokio::VoiceModel::sample().await.unwrap())
.await
.unwrap();

Expand All @@ -1378,7 +1381,7 @@ mod tests {
.unwrap();

syntesizer
.load_voice_model(&open_default_vvm_file().await)
.load_voice_model(&crate::tokio::VoiceModel::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1408,7 +1411,7 @@ mod tests {
)
.unwrap();
syntesizer
.load_voice_model(&open_default_vvm_file().await)
.load_voice_model(&crate::tokio::VoiceModel::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1447,7 +1450,7 @@ mod tests {
)
.unwrap();
syntesizer
.load_voice_model(&open_default_vvm_file().await)
.load_voice_model(&crate::tokio::VoiceModel::sample().await.unwrap())
.await
.unwrap();

Expand Down
Loading

0 comments on commit d6a7bb0

Please sign in to comment.