From 64158fcf0f996f66a18bea16536e1f706ef0bcc3 Mon Sep 17 00:00:00 2001
From: name1e5s
Date: Sun, 18 Dec 2022 21:34:30 +0800
Subject: [PATCH] allow users to load custom hmm model

---
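Usage note (not part of the commit message): a rough sketch of how the new API
might be called from application code. The probability values below are
placeholders rather than a trained model, the function name
`install_custom_model` is made up for the example, and the sketch assumes
`HmmModel` itself is reachable from the caller (this patch only re-exports the
two helper functions from `src/lib.rs`).

    use std::collections::HashMap;

    use jieba_rs::{get_custom_hmm_model, init_custom_hmm_model};

    fn install_custom_model() {
        // Placeholder log-probabilities for the four states B/E/M/S; the
        // emission maps are left empty for brevity.
        let model = HmmModel {
            initial_probs: [-0.7, -3.14e100, -3.14e100, -0.7],
            trans_probs: [[-3.14e100; 4]; 4],
            emit_probs: [HashMap::new(), HashMap::new(), HashMap::new(), HashMap::new()],
        };
        // OnceCell semantics: only the first call succeeds; later calls get
        // the rejected model back as `Err`.
        if init_custom_hmm_model(model).is_err() {
            eprintln!("custom HMM model already initialized");
        }
        assert!(get_custom_hmm_model().is_some());
    }
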
 Cargo.toml |   5 +-
 build.rs   |   8 +--
 src/hmm.rs | 149 +++++++++++++++++++++++++++++++++++++++++++++++++----
 src/lib.rs |  11 +++-
 4 files changed, 156 insertions(+), 17 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 5d26926..55f7892 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,17 +30,18 @@ required-features = ["tfidf", "textrank"]
 regex = "1.0"
 lazy_static = "1.0"
 phf = "0.11"
-hashbrown = { version = "0.12", default-features = false, features = ["inline-more"] }
 cedarwood = "0.4"
 ordered-float = { version = "3.0", optional = true }
+once_cell = "1"
 fxhash = "0.2.1"
 
 [build-dependencies]
 phf_codegen = "0.11"
 
 [features]
-default = ["default-dict"]
+default = ["default-dict", "default-hmm-model"]
 default-dict = []
+default-hmm-model = []
 tfidf = ["ordered-float"]
 textrank = ["ordered-float"]
 
diff --git a/build.rs b/build.rs
index fb77314..f3b5aeb 100644
--- a/build.rs
+++ b/build.rs
@@ -13,13 +13,13 @@ fn main() {
     let mut lines = reader.lines().map(|x| x.unwrap()).skip_while(|x| x.starts_with('#'));
     let prob_start = lines.next().unwrap();
     writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
-    write!(&mut file, "static INITIAL_PROBS: StatusSet = [").unwrap();
+    write!(&mut file, "pub static INITIAL_PROBS: StatusSet = [").unwrap();
     for prob in prob_start.split(' ') {
         write!(&mut file, "{}, ", prob).unwrap();
     }
     write!(&mut file, "];\n\n").unwrap();
     writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
-    write!(&mut file, "static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
+    write!(&mut file, "pub static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
     for line in lines
         .by_ref()
         .skip_while(|x| x.starts_with('#'))
@@ -38,7 +38,7 @@
             continue;
         }
         writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
-        write!(&mut file, "static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
+        write!(&mut file, "pub static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
         let mut map = phf_codegen::Map::new();
         for word_prob in line.split(',') {
             let mut parts = word_prob.split(':');
@@ -50,5 +50,5 @@
         i += 1;
     }
     writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
-    writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
+    writeln!(&mut file, "pub static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
 }
diff --git a/src/hmm.rs b/src/hmm.rs
index b8f51cb..b349c14 100644
--- a/src/hmm.rs
+++ b/src/hmm.rs
@@ -1,7 +1,8 @@
-use std::cmp::Ordering;
-
 use lazy_static::lazy_static;
+use once_cell::sync::OnceCell;
 use regex::Regex;
+use std::cmp::Ordering;
+use std::collections::HashMap;
 
 use crate::SplitMatches;
 
@@ -20,6 +21,22 @@ pub enum Status {
     S = 3,
 }
 
+pub struct HmmModel {
+    pub initial_probs: StatusSet,
+    pub trans_probs: [StatusSet; 4],
+    pub emit_probs: [HashMap<String, f64>; 4],
+}
+
+static CUSTOM_HMM_MODEL: OnceCell<HmmModel> = OnceCell::new();
+
+pub fn get_custom_hmm_model() -> Option<&'static HmmModel> {
+    CUSTOM_HMM_MODEL.get()
+}
+
+pub fn init_custom_hmm_model(model: HmmModel) -> Result<(), HmmModel> {
+    CUSTOM_HMM_MODEL.set(model)
+}
+
 static PREV_STATUS: [[Status; 2]; 4] = [
     [Status::E, Status::S], // B
     [Status::B, Status::M], // E
@@ -27,10 +44,85 @@ static PREV_STATUS: [[Status; 2]; 4] = [
     [Status::S, Status::E], // S
 ];
 
-include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
-
 const MIN_FLOAT: f64 = -3.14e100;
 
+#[cfg(feature = "default-hmm-model")]
+mod default_hmm {
+    use super::*;
+
+    include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
+}
+
+#[inline]
+fn get_initial_prob(index: usize) -> f64 {
+    debug_assert!(index < 4);
+    if let Some(model) = get_custom_hmm_model() {
+        model.initial_probs[index]
+    } else {
+        #[cfg(feature = "default-hmm-model")]
+        {
+            default_hmm::INITIAL_PROBS[index]
+        }
+        #[cfg(not(feature = "default-hmm-model"))]
+        {
+            debug_assert!(
+                false,
+                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
+            );
+            MIN_FLOAT
+        }
+    }
+}
+
+#[inline]
+fn get_emit_prob(index: usize, word: &str) -> f64 {
+    debug_assert!(index < 4);
+    if let Some(model) = get_custom_hmm_model() {
+        model.emit_probs[index].get(word).cloned().unwrap_or(MIN_FLOAT)
+    } else {
+        #[cfg(feature = "default-hmm-model")]
+        {
+            default_hmm::EMIT_PROBS[index].get(word).cloned().unwrap_or(MIN_FLOAT)
+        }
+        #[cfg(not(feature = "default-hmm-model"))]
+        {
+            debug_assert!(
+                false,
+                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
+            );
+            MIN_FLOAT
+        }
+    }
+}
+
+#[inline]
+fn get_trans_prob(from_index: usize, to_index: usize) -> f64 {
+    debug_assert!(from_index < 4);
+    debug_assert!(to_index < 4);
+    if let Some(model) = get_custom_hmm_model() {
+        model.trans_probs[from_index]
+            .get(to_index)
+            .cloned()
+            .unwrap_or(MIN_FLOAT)
+    } else {
+        #[cfg(feature = "default-hmm-model")]
+        {
+            default_hmm::TRANS_PROBS[from_index]
+                .get(to_index)
+                .cloned()
+                .unwrap_or(MIN_FLOAT)
+        }
+        #[cfg(not(feature = "default-hmm-model"))]
+        {
+            debug_assert!(
+                false,
+                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
+            );
+            MIN_FLOAT
+        }
+    }
+}
+
 #[allow(non_snake_case)]
 fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
     let str_len = sentence.len();
@@ -57,7 +149,8 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
     let x2 = *curr.peek().unwrap();
     for y in &states {
         let first_word = &sentence[x1..x2];
-        let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
+        let index = *y as usize;
+        let prob = get_initial_prob(index) + get_emit_prob(index, first_word);
         V[*y as usize] = prob;
     }
 
@@ -66,14 +159,12 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
         for y in &states {
             let byte_end = *curr.peek().unwrap_or(&str_len);
             let word = &sentence[byte_start..byte_end];
-            let em_prob = EMIT_PROBS[*y as usize].get(word).cloned().unwrap_or(MIN_FLOAT);
+            let em_prob = get_emit_prob(*y as usize, word);
             let (prob, state) = PREV_STATUS[*y as usize]
                 .iter()
                 .map(|y0| {
                     (
-                        V[(t - 1) * R + (*y0 as usize)]
-                            + TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
-                            + em_prob,
+                        V[(t - 1) * R + (*y0 as usize)] + get_trans_prob(*y0 as usize, *y as usize) + em_prob,
                         *y0,
                     )
                 })
@@ -197,15 +288,52 @@ pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
     cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
 }
 
+#[cfg(all(test, not(feature = "default-hmm-model")))]
+pub fn test_init_custom_hmm_model() {
+    use std::convert::TryInto;
+
+    mod hmm_prob {
+        use super::*;
+        include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
+    }
+
+    if get_custom_hmm_model().is_none() {
+        let initial_probs = hmm_prob::INITIAL_PROBS;
+        let trans_probs = hmm_prob::TRANS_PROBS;
+        let emit_probs: [HashMap<_, _>; 4] = {
+            let mut emit_probs = Vec::new();
+            for prob in hmm_prob::EMIT_PROBS {
+                let mut probs = HashMap::new();
+                for (k, v) in prob {
+                    probs.insert(k.to_string(), *v);
+                }
+                emit_probs.push(probs);
+            }
+            emit_probs.try_into().unwrap()
+        };
+        let _ = init_custom_hmm_model(HmmModel {
+            initial_probs,
+            trans_probs,
+            emit_probs,
+        });
+    }
+}
+
+#[cfg(all(test, feature = "default-hmm-model"))]
+pub fn test_init_custom_hmm_model() {
+    // nothing
+}
+
 #[cfg(test)]
 mod tests {
-    use super::{cut, viterbi, Status};
+    use super::*;
 
     #[test]
     #[allow(non_snake_case)]
     fn test_viterbi() {
         use super::Status::*;
 
+        test_init_custom_hmm_model();
         let sentence = "小明硕士毕业于中国科学院计算所";
         let R = 4;
 
@@ -219,6 +347,7 @@
 
     #[test]
     fn test_hmm_cut() {
+        test_init_custom_hmm_model();
         let sentence = "小明硕士毕业于中国科学院计算所";
         let mut words = Vec::with_capacity(sentence.chars().count() / 2);
         cut(sentence, &mut words);
diff --git a/src/lib.rs b/src/lib.rs
index b5555a6..ff3c0ea 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -75,8 +75,8 @@ use std::cmp::Ordering;
 use std::io::BufRead;
 
 use cedarwood::Cedar;
-use hashbrown::HashMap;
 use regex::{Match, Matches, Regex};
+use std::collections::HashMap;
 
 pub(crate) type FxHashMap<K, V> = HashMap<K, V, fxhash::FxBuildHasher>;
 
@@ -88,6 +88,8 @@ pub use crate::keywords::tfidf::TFIDF;
 #[cfg(any(feature = "tfidf", feature = "textrank"))]
 pub use crate::keywords::{Keyword, KeywordExtract};
 
+pub use crate::hmm::{get_custom_hmm_model, init_custom_hmm_model};
+
 mod errors;
 mod hmm;
 #[cfg(any(feature = "tfidf", feature = "textrank"))]
@@ -806,6 +808,7 @@ impl Jieba {
 #[cfg(test)]
 mod tests {
     use super::{Jieba, SplitMatches, SplitState, Tag, Token, TokenizeMode, RE_HAN_DEFAULT};
+    use crate::hmm::test_init_custom_hmm_model;
     use std::io::BufReader;
 
     #[test]
@@ -900,6 +903,7 @@
 
     #[test]
     fn test_cut_with_hmm() {
+        test_init_custom_hmm_model();
         let jieba = Jieba::new();
         let words = jieba.cut("我们中出了一个叛徒", false);
         assert_eq!(words, vec!["我们", "中", "出", "了", "一个", "叛徒"]);
@@ -917,6 +921,7 @@
 
     #[test]
     fn test_cut_weicheng() {
+        test_init_custom_hmm_model();
         static WEICHENG_TXT: &str = include_str!("../examples/weicheng/src/weicheng.txt");
         let jieba = Jieba::new();
         for line in WEICHENG_TXT.split('\n') {
@@ -926,6 +931,7 @@
 
     #[test]
     fn test_cut_for_search() {
+        test_init_custom_hmm_model();
         let jieba = Jieba::new();
         let words = jieba.cut_for_search("南京市长江大桥", true);
         assert_eq!(words, vec!["南京", "京市", "南京市", "长江", "大桥", "长江大桥"]);
@@ -962,6 +968,7 @@
 
     #[test]
     fn test_tag() {
+        test_init_custom_hmm_model();
         let jieba = Jieba::new();
         let tags = jieba.tag(
             "我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。",
@@ -1078,6 +1085,7 @@
 
     #[test]
     fn test_tokenize() {
+        test_init_custom_hmm_model();
         let jieba = Jieba::new();
         let tokens = jieba.tokenize("南京市长江大桥", TokenizeMode::Default, false);
         assert_eq!(
@@ -1305,6 +1313,7 @@
 
     #[test]
     fn test_userdict_hmm() {
+        test_init_custom_hmm_model();
         let mut jieba = Jieba::new();
         let tokens = jieba.tokenize("我们中出了一个叛徒", TokenizeMode::Default, true);
         assert_eq!(