Skip to content

Commit

Permalink
load{,_with_userdic}の引数をUtf8Pathにし、\0入りをエラーに
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Feb 7, 2024
1 parent a16714c commit 49ccd0c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 41 deletions.
1 change: 1 addition & 0 deletions crates/open_jtalk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.25"
edition = "2021"

[dependencies]
camino = "1.1.6"
open_jtalk-sys = { path = "../open_jtalk-sys", version = "0.16.111" }
thiserror = "1.0.31"

Expand Down
91 changes: 57 additions & 34 deletions crates/open_jtalk/src/mecab/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@ mod mecab_dict_index;
pub use mecab_dict_index::*;

use super::*;
use std::{ffi::CString, mem::MaybeUninit, path::Path};
use camino::{Utf8Path, Utf8PathBuf};
use std::{ffi::CString, mem::MaybeUninit};

/// Error returned by the dictionary-loading methods of [`Mecab`]
/// (`load` and `load_with_userdic`).
#[derive(thiserror::Error, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub enum MecabLoadError {
/// The underlying loader function reported failure; `function` holds the
/// name of that function (e.g. `"Mecab_load"`).
#[error("`{function}` failed")]
Unsuccessful { function: &'static str },
/// The given path could not be converted into a C string because it
/// contains a NUL byte; `filename` is the offending path.
#[error("file name contained an NUL byte: {filename:?}")]
Nul { filename: Utf8PathBuf },
}

#[derive(Default)]
pub struct Mecab(Option<open_jtalk_sys::Mecab>);
Expand Down Expand Up @@ -38,35 +47,42 @@ impl Mecab {
self.0.as_ref().unwrap() as *const open_jtalk_sys::Mecab as *mut open_jtalk_sys::Mecab
}

pub fn load(&mut self, dic_dir: impl AsRef<Path>) -> bool {
let dic_dir = CString::new(dic_dir.as_ref().to_str().unwrap()).unwrap();
unsafe {
bool_number_to_bool(open_jtalk_sys::Mecab_load(
self.as_raw_ptr(),
dic_dir.as_ptr(),
))
pub fn load(&mut self, dic_dir: impl AsRef<Utf8Path>) -> Result<(), MecabLoadError> {
let dic_dir = c_filename(dic_dir.as_ref())?;
let success = bool_number_to_bool(unsafe {
open_jtalk_sys::Mecab_load(self.as_raw_ptr(), dic_dir.as_ptr())
});
if !success {
return Err(MecabLoadError::Unsuccessful {
function: "Mecab_load",
});
}
}

/// # Panics
///
/// 次の場合にパニックする。
///
/// - `dic_dir`または`userdic`が`\0`を含む。
/// - `dic_dir`または`userdic`がUTF-8の文字列ではない。
pub fn load_with_userdic(&mut self, dic_dir: &Path, userdic: Option<&Path>) -> bool {
let dic_dir = CString::new(dic_dir.to_str().unwrap()).unwrap();
let userdic = &userdic.map(|userdic| CString::new(userdic.to_str().unwrap()).unwrap());
unsafe {
bool_number_to_bool(open_jtalk_sys::Mecab_load_with_userdic(
Ok(())
}

/// Loads the dictionary in `dic_dir`, optionally together with a user
/// dictionary, by calling `open_jtalk_sys::Mecab_load_with_userdic`.
///
/// # Errors
///
/// - [`MecabLoadError::Nul`] if `dic_dir` or `userdic` contains a NUL byte
///   and therefore cannot be converted into a C string.
/// - [`MecabLoadError::Unsuccessful`] if the result of
///   `Mecab_load_with_userdic`, converted by `bool_number_to_bool`, is `false`.
pub fn load_with_userdic(
&mut self,
dic_dir: &Utf8Path,
userdic: Option<&Utf8Path>,
) -> Result<(), MecabLoadError> {
// Convert both paths up front so a NUL byte is reported before any FFI call.
let dic_dir = c_filename(dic_dir)?;
// `Option<Result<..>>` -> `Result<Option<..>>`, so `?` can propagate the error.
let userdic = &userdic.map(c_filename).transpose()?;
let success = bool_number_to_bool(unsafe {
open_jtalk_sys::Mecab_load_with_userdic(
self.as_raw_ptr(),
dic_dir.as_ptr(),
// A missing user dictionary is passed to the C side as a null pointer.
match userdic {
Some(userdic) => userdic.as_ptr(),
None => std::ptr::null(),
},
// NOTE(review): the next line looks like residue of the pre-change version
// from the diff view — confirm against the actual file.
))
)
});
if !success {
return Err(MecabLoadError::Unsuccessful {
function: "Mecab_load_with_userdic",
});
}
Ok(())
}
pub fn get_feature(&self) -> Option<&MecabFeature> {
unsafe {
Expand Down Expand Up @@ -113,11 +129,16 @@ impl Mecab {
}
}

/// Converts a UTF-8 path into a NUL-terminated C string for the FFI layer.
///
/// # Errors
///
/// Returns [`MecabLoadError::Nul`] carrying an owned copy of `path` when
/// `CString::new` rejects the path (i.e. it contains a NUL byte).
fn c_filename(path: &Utf8Path) -> Result<CString, MecabLoadError> {
    match CString::new(path.as_str()) {
        Ok(c_path) => Ok(c_path),
        // The underlying `NulError` is dropped; only the offending path is reported.
        Err(_) => Err(MecabLoadError::Nul {
            filename: path.to_owned(),
        }),
    }
}

#[cfg(test)]
mod tests {
use std::{path::PathBuf, str::FromStr};

use super::*;
use camino::Utf8Path;
use pretty_assertions::{assert_eq, assert_ne};
use resources::Resource as _;

Expand All @@ -139,11 +160,12 @@ mod tests {
#[rstest]
fn mecab_load_works() {
let mut mecab = ManagedResource::<Mecab>::initialize();
assert!(mecab.load(
PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR"))
.unwrap()
.join("src/mecab/testdata/mecab_load"),
));
mecab
.load(
Utf8Path::new(std::env!("CARGO_MANIFEST_DIR"))
.join("src/mecab/testdata/mecab_load"),
)
.unwrap();
}

#[rstest]
Expand All @@ -156,11 +178,12 @@ mod tests {
#[case("h^o-d+e=s/A:2+3+2/B:22-xx_xx/C:10_7+2/D:xx+xx_xx/E:5_5!0_xx-0/F:4_1#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:1_5/I:1-4@2+1&2-1|6+4/J:xx_xx/K:2+2-9",true)]
fn mecab_analysis_works(#[case] input: &str, #[case] expected: bool) {
let mut mecab = ManagedResource::<Mecab>::initialize();
assert!(mecab.load(
PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR"))
.unwrap()
.join("src/mecab/testdata/mecab_load"),
));
mecab
.load(
Utf8Path::new(std::env!("CARGO_MANIFEST_DIR"))
.join("src/mecab/testdata/mecab_load"),
)
.unwrap();
let s = text2mecab(input).unwrap();
assert_eq!(expected, mecab.analysis(s));
assert_ne!(0, mecab.get_size());
Expand Down
14 changes: 7 additions & 7 deletions crates/open_jtalk/src/njd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ impl Njd {
#[cfg(test)]
mod tests {
use super::*;
use camino::Utf8Path;
use resources::Resource as _;
use std::path::PathBuf;
use std::str::FromStr;
#[rstest]
fn njd_initialize_and_clear_works() {
let mut njd = Njd::default();
Expand Down Expand Up @@ -131,11 +130,12 @@ mod tests {
let mut njd = ManagedResource::<Njd>::initialize();
let mut mecab = ManagedResource::<Mecab>::initialize();

assert!(mecab.load(
PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR"))
.unwrap()
.join("src/mecab/testdata/mecab_load"),
));
mecab
.load(
Utf8Path::new(std::env!("CARGO_MANIFEST_DIR"))
.join("src/mecab/testdata/mecab_load"),
)
.unwrap();
let s = text2mecab("h^o-d+e=s/A:2+3+2/B:22-xx_xx/C:10_7+2/D:xx+xx_xx/E:5_5!0_xx-0/F:4_1#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:1_5/I:1-4@2+1&2-1|6+4/J:xx_xx/K:2+2-9").unwrap();
assert!(mecab.analysis(s));
njd.mecab2njd(mecab.get_feature().unwrap(), mecab.get_size());
Expand Down

0 comments on commit 49ccd0c

Please sign in to comment.