Skip to content

Commit

Permalink
Extract HmmContext struct. (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
awong-dev authored Apr 14, 2024
1 parent f7e5793 commit 46e4273
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 59 deletions.
97 changes: 48 additions & 49 deletions src/hmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,42 @@ include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

const MIN_FLOAT: f64 = -3.14e100;

/// Scratch buffers used by the HMM Viterbi decoder, bundled so callers can
/// pass one value instead of three separate vectors.
pub(crate) struct HmmContext {
    /// Flattened score table; `viterbi` addresses it as `t * num_states + state`.
    v: Vec<f64>,
    /// Flattened back-pointer table, addressed the same way as `v`.
    /// `None` marks a cell with no recorded predecessor.
    prev: Vec<Option<Status>>,
    /// Decoded state for each character position.
    best_path: Vec<Status>,
}

impl HmmContext {
    /// Creates a context sized for `num_characters` characters and
    /// `num_states` hidden states.
    ///
    /// The two tables hold `num_states * num_characters` entries each
    /// (scores start at `0.0`, back-pointers at `None`); the path holds
    /// `num_characters` entries, all initialised to `Status::B`.
    pub fn new(num_states: usize, num_characters: usize) -> Self {
        let table_len = num_states * num_characters;
        Self {
            v: vec![0.0; table_len],
            prev: vec![None; table_len],
            best_path: vec![Status::B; num_characters],
        }
    }
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
let str_len = sentence.len();
let states = [Status::B, Status::M, Status::E, Status::S];
#[allow(non_snake_case)]
let R = states.len();
let C = sentence.chars().count();
assert!(C > 1);

if prev.len() < R * C {
prev.resize(R * C, None);
// TODO: Can code just do fill() with the default instead of clear() and resize?
if hmm_context.prev.len() < R * C {
hmm_context.prev.resize(R * C, None);
}

if V.len() < R * C {
V.resize(R * C, 0.0);
if hmm_context.v.len() < R * C {
hmm_context.v.resize(R * C, 0.0);
}

if best_path.len() < C {
best_path.resize(C, Status::B);
if hmm_context.best_path.len() < C {
hmm_context.best_path.resize(C, Status::B);
}

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
Expand All @@ -58,7 +75,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
for y in &states {
let first_word = &sentence[x1..x2];
let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
V[*y as usize] = prob;
hmm_context.v[*y as usize] = prob;
}

let mut t = 1;
Expand All @@ -71,7 +88,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.iter()
.map(|y0| {
(
V[(t - 1) * R + (*y0 as usize)]
hmm_context.v[(t - 1) * R + (*y0 as usize)]
+ TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
+ em_prob,
*y0,
Expand All @@ -80,51 +97,45 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();
let idx = (t * R) + (*y as usize);
V[idx] = prob;
prev[idx] = Some(state);
hmm_context.v[idx] = prob;
hmm_context.prev[idx] = Some(state);
}

t += 1;
}

let (_prob, state) = [Status::E, Status::S]
.iter()
.map(|y| (V[(C - 1) * R + (*y as usize)], y))
.map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y))
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();

let mut t = C - 1;
let mut curr = *state;

best_path[t] = *state;
while let Some(p) = prev[t * R + (curr as usize)] {
hmm_context.best_path[t] = *state;
while let Some(p) = hmm_context.prev[t * R + (curr as usize)] {
assert!(t > 0);
best_path[t - 1] = p;
hmm_context.best_path[t - 1] = p;
curr = p;
t -= 1;
}

prev.clear();
V.clear();
hmm_context.prev.clear();
hmm_context.v.clear();
}

#[allow(non_snake_case)]
pub fn cut_internal<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) {
let str_len = sentence.len();
viterbi(sentence, V, prev, path);
viterbi(sentence, hmm_context);
let mut begin = 0;
let mut next_byte_offset = 0;
let mut i = 0;

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
while let Some(curr_byte_offset) = curr.next() {
let state = path[i];
let state = hmm_context.best_path[i];
match state {
Status::B => begin = curr_byte_offset,
Status::E => {
Expand All @@ -150,17 +161,11 @@ pub fn cut_internal<'a>(
words.push(&sentence[byte_start..]);
}

path.clear();
hmm_context.best_path.clear();
}

#[allow(non_snake_case)]
pub(crate) fn cut_with_allocated_memory<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) {
let splitter = SplitMatches::new(&RE_HAN, sentence);
for state in splitter {
let block = state.into_str();
Expand All @@ -169,7 +174,7 @@ pub(crate) fn cut_with_allocated_memory<'a>(
}
if RE_HAN.is_match(block) {
if block.chars().count() > 1 {
cut_internal(block, words, V, prev, path);
cut_internal(block, words, hmm_context);
} else {
words.push(block);
}
Expand All @@ -188,18 +193,15 @@ pub(crate) fn cut_with_allocated_memory<'a>(

#[allow(non_snake_case)]
pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
let R = 4;
let C = sentence.chars().count();
let mut V = vec![0.0; R * C];
let mut prev: Vec<Option<Status>> = vec![None; R * C];
let mut path: Vec<Status> = vec![Status::B; C];
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());

cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
cut_with_allocated_memory(sentence, words, &mut hmm_context)
}

#[cfg(test)]
mod tests {
use super::{cut, viterbi, Status};
use super::{cut, viterbi, HmmContext};

#[test]
#[allow(non_snake_case)]
Expand All @@ -208,13 +210,10 @@ mod tests {

let sentence = "小明硕士毕业于中国科学院计算所";

let R = 4;
let C = sentence.chars().count();
let mut V = vec![0.0; R * C];
let mut prev: Vec<Option<Status>> = vec![None; R * C];
let mut path: Vec<Status> = vec![Status::B; C];
viterbi(sentence, &mut V, &mut prev, &mut path);
assert_eq!(path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]);
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());
viterbi(sentence, &mut hmm_context);
assert_eq!(hmm_context.best_path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]);
}

#[test]
Expand Down
15 changes: 5 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,7 @@ impl Jieba {
words: &mut Vec<&'a str>,
route: &mut Vec<(f64, usize)>,
dag: &mut StaticSparseDAG,
V: &mut Vec<f64>,
prev: &mut Vec<Option<hmm::Status>>,
path: &mut Vec<hmm::Status>,
hmm_context: &mut hmm::HmmContext,
) {
self.dag(sentence, dag);
self.calc(sentence, dag, route);
Expand All @@ -526,7 +524,7 @@ impl Jieba {
if word.chars().count() == 1 {
words.push(word);
} else if self.cedar.exact_match_search(word).is_none() {
hmm::cut_with_allocated_memory(word, words, V, prev, path);
hmm::cut_with_allocated_memory(word, words, hmm_context);
} else {
let mut word_indices = word.char_indices().map(|x| x.0).peekable();
while let Some(byte_start) = word_indices.next() {
Expand Down Expand Up @@ -582,11 +580,8 @@ impl Jieba {
let mut route = Vec::with_capacity(heuristic_capacity);
let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity);

let R = 4;
let C = sentence.chars().count();
let mut V = if hmm { vec![0.0; R * C] } else { Vec::new() };
let mut prev: Vec<Option<hmm::Status>> = if hmm { vec![None; R * C] } else { Vec::new() };
let mut path: Vec<hmm::Status> = if hmm { vec![hmm::Status::B; C] } else { Vec::new() };
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = hmm::HmmContext::new(4, sentence.chars().count());

for state in splitter {
match state {
Expand All @@ -597,7 +592,7 @@ impl Jieba {
if cut_all {
self.cut_all_internal(block, &mut words);
} else if hmm {
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut V, &mut prev, &mut path);
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut hmm_context);
} else {
self.cut_dag_no_hmm(block, &mut words, &mut route, &mut dag);
}
Expand Down

0 comments on commit 46e4273

Please sign in to comment.