diff --git a/build.rs b/build.rs index bc2081d..6a11084 100644 --- a/build.rs +++ b/build.rs @@ -13,13 +13,13 @@ fn main() { let mut lines = reader.lines().map(|x| x.unwrap()).skip_while(|x| x.starts_with('#')); let prob_start = lines.next().unwrap(); writeln!(&mut file, "#[allow(clippy::style)]").unwrap(); - write!(&mut file, "static INITIAL_PROBS: StatusSet = [").unwrap(); + write!(&mut file, "static INITIAL_PROBS: StateSet = [").unwrap(); for prob in prob_start.split(' ') { write!(&mut file, "{}, ", prob).unwrap(); } write!(&mut file, "];\n\n").unwrap(); writeln!(&mut file, "#[allow(clippy::style)]").unwrap(); - write!(&mut file, "static TRANS_PROBS: [StatusSet; 4] = [").unwrap(); + write!(&mut file, "static TRANS_PROBS: [StateSet; crate::hmm::NUM_STATES] = [").unwrap(); for line in lines .by_ref() .skip_while(|x| x.starts_with('#')) @@ -50,5 +50,5 @@ fn main() { i += 1; } writeln!(&mut file, "#[allow(clippy::style)]").unwrap(); - writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap(); + writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; crate::hmm::NUM_STATES] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap(); } diff --git a/src/hmm.rs b/src/hmm.rs index b33cbbd..4258349 100644 --- a/src/hmm.rs +++ b/src/hmm.rs @@ -10,21 +10,46 @@ lazy_static! { static ref RE_SKIP: Regex = Regex::new(r"([a-zA-Z0-9]+(?:.\d+)?%?)").unwrap(); } -pub type StatusSet = [f64; 4]; - +pub const NUM_STATES: usize = 4; + +pub type StateSet = [f64; NUM_STATES]; + +/// Result of hmm is a labeling of each Unicode Scalar Value in the input +/// string with Begin, Middle, End, or Single. These denote the proposed +/// segments. A segment is one of the following two patterns. 
+/// +/// Begin, [Middle...], End +/// Single +/// +/// Each state in the enum is also assigned an index value from 0-3 that +/// can be used as an index into an array representing data pertaining +/// to that state. +/// +/// WARNING: The comments in the hmm.model data file imply one can +/// reassign the index values of each state at the top but `build.rs` +/// currently ignores the mapping. Do not reassign these indices without +/// verifying how it interacts with `build.rs`. These indices must also +/// match the order of ALLOWED_PREV_STATUS. #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)] -pub enum Status { - B = 0, - E = 1, - M = 2, - S = 3, +pub enum State { + Begin = 0, + End = 1, + Middle = 2, + Single = 3, } -static PREV_STATUS: [[Status; 2]; 4] = [ - [Status::E, Status::S], // B - [Status::B, Status::M], // E - [Status::M, Status::B], // M - [Status::S, Status::E], // S +// Mapping representing the allowed transitions into the given state. +// +// WARNING: Ordering must match the indices in State. 
+static ALLOWED_PREV_STATUS: [[State; 2]; NUM_STATES] = [ + // Can precede State::Begin + [State::End, State::Single], + // Can precede State::End + [State::Begin, State::Middle], + // Can precede State::Middle + [State::Middle, State::Begin], + // Can precede State::Single + [State::Single, State::End], ]; include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs")); @@ -33,16 +58,16 @@ const MIN_FLOAT: f64 = -3.14e100; pub(crate) struct HmmContext { v: Vec<f64>, - prev: Vec<Option<Status>>, - best_path: Vec<Status>, + prev: Vec<Option<State>>, + best_path: Vec<State>, } impl HmmContext { - pub fn new(num_states: usize, num_characters: usize) -> Self { + pub fn new(num_characters: usize) -> Self { HmmContext { - v: vec![0.0; num_states * num_characters], - prev: vec![None; num_states * num_characters], - best_path: vec![Status::B; num_characters], + v: vec![0.0; NUM_STATES * num_characters], + prev: vec![None; NUM_STATES * num_characters], + best_path: vec![State::Begin; num_characters], } } } @@ -50,7 +75,7 @@ impl HmmContext { #[allow(non_snake_case)] fn viterbi(sentence: &str, hmm_context: &mut HmmContext) { let str_len = sentence.len(); - let states = [Status::B, Status::M, Status::E, Status::S]; + let states = [State::Begin, State::Middle, State::End, State::Single]; #[allow(non_snake_case)] let R = states.len(); let C = sentence.chars().count(); @@ -66,7 +91,7 @@ fn viterbi(sentence: &str, hmm_context: &mut HmmContext) { } if hmm_context.best_path.len() < C { - hmm_context.best_path.resize(C, Status::B); + hmm_context.best_path.resize(C, State::Begin); } let mut curr = sentence.char_indices().map(|x| x.0).peekable(); @@ -84,7 +109,7 @@ fn viterbi(sentence: &str, hmm_context: &mut HmmContext) { let byte_end = *curr.peek().unwrap_or(&str_len); let word = &sentence[byte_start..byte_end]; let em_prob = EMIT_PROBS[*y as usize].get(word).cloned().unwrap_or(MIN_FLOAT); - let (prob, state) = PREV_STATUS[*y as usize] + let (prob, state) = ALLOWED_PREV_STATUS[*y as usize] .iter() .map(|y0| { ( @@ -104,7 +129,7 @@ fn 
viterbi(sentence: &str, hmm_context: &mut HmmContext) { t += 1; } - let (_prob, state) = [Status::E, Status::S] + let (_prob, state) = [State::End, State::Single] .iter() .map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y)) .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)) @@ -137,20 +162,20 @@ pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context while let Some(curr_byte_offset) = curr.next() { let state = hmm_context.best_path[i]; match state { - Status::B => begin = curr_byte_offset, - Status::E => { + State::Begin => begin = curr_byte_offset, + State::End => { let byte_start = begin; let byte_end = *curr.peek().unwrap_or(&str_len); words.push(&sentence[byte_start..byte_end]); next_byte_offset = byte_end; } - Status::S => { + State::Single => { let byte_start = curr_byte_offset; let byte_end = *curr.peek().unwrap_or(&str_len); words.push(&sentence[byte_start..byte_end]); next_byte_offset = byte_end; } - Status::M => { /* do nothing */ } + State::Middle => { /* do nothing */ } } i += 1; @@ -193,8 +218,7 @@ pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<& #[allow(non_snake_case)] pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) { - // TODO: Is 4 just the number of variants in Status? - let mut hmm_context = HmmContext::new(4, sentence.chars().count()); + let mut hmm_context = HmmContext::new(sentence.chars().count()); cut_with_allocated_memory(sentence, words, &mut hmm_context) } @@ -206,14 +230,16 @@ mod tests { #[test] #[allow(non_snake_case)] fn test_viterbi() { - use super::Status::*; + use super::State::*; let sentence = "小明硕士毕业于中国科学院计算所"; - // TODO: Is 4 just the number of variants in Status? 
- let mut hmm_context = HmmContext::new(4, sentence.chars().count()); + let mut hmm_context = HmmContext::new(sentence.chars().count()); viterbi(sentence, &mut hmm_context); - assert_eq!(hmm_context.best_path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]); + assert_eq!( + hmm_context.best_path, + vec![Begin, End, Begin, End, Begin, Middle, End, Begin, End, Begin, Middle, End, Begin, End, Single] + ); } #[test] diff --git a/src/lib.rs b/src/lib.rs index a52d193..5305182 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -580,8 +580,7 @@ impl Jieba { let mut route = Vec::with_capacity(heuristic_capacity); let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity); - // TODO: Is 4 just the number of variants in Status? - let mut hmm_context = hmm::HmmContext::new(4, sentence.chars().count()); + let mut hmm_context = hmm::HmmContext::new(sentence.chars().count()); for state in splitter { match state {