diff --git a/src/hmm.rs b/src/hmm.rs index b8f51cb..b33cbbd 100644 --- a/src/hmm.rs +++ b/src/hmm.rs @@ -31,8 +31,24 @@ include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs")); const MIN_FLOAT: f64 = -3.14e100; +pub(crate) struct HmmContext { + v: Vec, + prev: Vec>, + best_path: Vec, +} + +impl HmmContext { + pub fn new(num_states: usize, num_characters: usize) -> Self { + HmmContext { + v: vec![0.0; num_states * num_characters], + prev: vec![None; num_states * num_characters], + best_path: vec![Status::B; num_characters], + } + } +} + #[allow(non_snake_case)] -fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, best_path: &mut Vec) { +fn viterbi(sentence: &str, hmm_context: &mut HmmContext) { let str_len = sentence.len(); let states = [Status::B, Status::M, Status::E, Status::S]; #[allow(non_snake_case)] @@ -40,16 +56,17 @@ fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, bes let C = sentence.chars().count(); assert!(C > 1); - if prev.len() < R * C { - prev.resize(R * C, None); + // TODO: Can code just do fill() with the default instead of clear() and resize? + if hmm_context.prev.len() < R * C { + hmm_context.prev.resize(R * C, None); } - if V.len() < R * C { - V.resize(R * C, 0.0); + if hmm_context.v.len() < R * C { + hmm_context.v.resize(R * C, 0.0); } - if best_path.len() < C { - best_path.resize(C, Status::B); + if hmm_context.best_path.len() < C { + hmm_context.best_path.resize(C, Status::B); } let mut curr = sentence.char_indices().map(|x| x.0).peekable(); @@ -58,7 +75,7 @@ fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, bes for y in &states { let first_word = &sentence[x1..x2]; let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT); - V[*y as usize] = prob; + hmm_context.v[*y as usize] = prob; } let mut t = 1; @@ -71,7 +88,7 @@ fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, bes .iter() .map(|y0| { ( - V[(t - 1) * R + (*y0 as usize)] + hmm_context.v[(t - 1) * R + (*y0 as usize)] + TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT) + em_prob, *y0, @@ -80,8 +97,8 @@ fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, bes .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)) .unwrap(); let idx = (t * R) + (*y as usize); - V[idx] = prob; - prev[idx] = Some(state); + hmm_context.v[idx] = prob; + hmm_context.prev[idx] = Some(state); } t += 1; @@ -89,42 +106,36 @@ fn viterbi(sentence: &str, V: &mut Vec, prev: &mut Vec>, bes let (_prob, state) = [Status::E, Status::S] .iter() - .map(|y| (V[(C - 1) * R + (*y as usize)], y)) + .map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y)) .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)) .unwrap(); let mut t = C - 1; let mut curr = *state; - best_path[t] = *state; - while let Some(p) = prev[t * R + (curr as usize)] { + hmm_context.best_path[t] = *state; + while let Some(p) = hmm_context.prev[t * R + (curr as usize)] { assert!(t > 0); - best_path[t - 1] = p; + hmm_context.best_path[t - 1] = p; curr = p; t -= 1; } - prev.clear(); - V.clear(); + hmm_context.prev.clear(); + hmm_context.v.clear(); } #[allow(non_snake_case)] -pub fn cut_internal<'a>( - sentence: &'a str, - words: &mut Vec<&'a str>, - V: &mut Vec, - prev: &mut Vec>, - path: &mut Vec, -) { +pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) { let str_len = sentence.len(); - viterbi(sentence, V, prev, path); + viterbi(sentence, hmm_context); let mut begin = 0; let mut next_byte_offset = 0; let mut i = 0; let mut curr = sentence.char_indices().map(|x| x.0).peekable(); while let Some(curr_byte_offset) = curr.next() { - let state = path[i]; + let state = hmm_context.best_path[i]; match state { Status::B => begin = curr_byte_offset, Status::E => { @@ -150,17 +161,11 @@ pub fn cut_internal<'a>( words.push(&sentence[byte_start..]); } - path.clear(); + hmm_context.best_path.clear(); } #[allow(non_snake_case)] -pub(crate) fn cut_with_allocated_memory<'a>( - sentence: &'a str, - words: &mut Vec<&'a str>, - V: &mut Vec, - prev: &mut Vec>, - path: &mut Vec, -) { +pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) { let splitter = SplitMatches::new(&RE_HAN, sentence); for state in splitter { let block = state.into_str(); @@ -169,7 +174,7 @@ pub(crate) fn cut_with_allocated_memory<'a>( } if RE_HAN.is_match(block) { if block.chars().count() > 1 { - cut_internal(block, words, V, prev, path); + cut_internal(block, words, hmm_context); } else { words.push(block); } @@ -188,18 +193,15 @@ pub(crate) fn cut_with_allocated_memory<'a>( #[allow(non_snake_case)] pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) { - let R = 4; - let C = sentence.chars().count(); - let mut V = vec![0.0; R * C]; - let mut prev: Vec> = vec![None; R * C]; - let mut path: Vec = vec![Status::B; C]; + // TODO: Is 4 just the number of variants in Status? + let mut hmm_context = HmmContext::new(4, sentence.chars().count()); - cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path); + cut_with_allocated_memory(sentence, words, &mut hmm_context) } #[cfg(test)] mod tests { - use super::{cut, viterbi, Status}; + use super::{cut, viterbi, HmmContext}; #[test] #[allow(non_snake_case)] @@ -208,13 +210,10 @@ mod tests { let sentence = "小明硕士毕业于中国科学院计算所"; - let R = 4; - let C = sentence.chars().count(); - let mut V = vec![0.0; R * C]; - let mut prev: Vec> = vec![None; R * C]; - let mut path: Vec = vec![Status::B; C]; - viterbi(sentence, &mut V, &mut prev, &mut path); - assert_eq!(path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]); + // TODO: Is 4 just the number of variants in Status? + let mut hmm_context = HmmContext::new(4, sentence.chars().count()); + viterbi(sentence, &mut hmm_context); + assert_eq!(hmm_context.best_path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]); } #[test] diff --git a/src/lib.rs b/src/lib.rs index efc1a75..a52d193 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -498,9 +498,7 @@ impl Jieba { words: &mut Vec<&'a str>, route: &mut Vec<(f64, usize)>, dag: &mut StaticSparseDAG, - V: &mut Vec, - prev: &mut Vec>, - path: &mut Vec, + hmm_context: &mut hmm::HmmContext, ) { self.dag(sentence, dag); self.calc(sentence, dag, route); @@ -526,7 +524,7 @@ impl Jieba { if word.chars().count() == 1 { words.push(word); } else if self.cedar.exact_match_search(word).is_none() { - hmm::cut_with_allocated_memory(word, words, V, prev, path); + hmm::cut_with_allocated_memory(word, words, hmm_context); } else { let mut word_indices = word.char_indices().map(|x| x.0).peekable(); while let Some(byte_start) = word_indices.next() { @@ -582,11 +580,8 @@ impl Jieba { let mut route = Vec::with_capacity(heuristic_capacity); let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity); - let R = 4; - let C = sentence.chars().count(); - let mut V = if hmm { vec![0.0; R * C] } else { Vec::new() }; - let mut prev: Vec> = if hmm { vec![None; R * C] } else { Vec::new() }; - let mut path: Vec = if hmm { vec![hmm::Status::B; C] } else { Vec::new() }; + // TODO: Is 4 just the number of variants in Status? + let mut hmm_context = hmm::HmmContext::new(4, sentence.chars().count()); for state in splitter { match state { @@ -597,7 +592,7 @@ impl Jieba { if cut_all { self.cut_all_internal(block, &mut words); } else if hmm { - self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut V, &mut prev, &mut path); + self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut hmm_context); } else { self.cut_dag_no_hmm(block, &mut words, &mut route, &mut dag); }