Skip to content

Commit

Permalink
Extract HmmContext struct.
Browse files Browse the repository at this point in the history
Group V, prev, and path together into one struct that provides all the context
for the HMM model to function.
  • Loading branch information
awong-dev committed Apr 14, 2024
1 parent f7e5793 commit cbb0fe1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 51 deletions.
84 changes: 43 additions & 41 deletions src/hmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,42 @@ include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

const MIN_FLOAT: f64 = -3.14e100;

/// Scratch buffers used by the HMM Viterbi decoder.
///
/// Keeping the three buffers in one value lets a caller allocate them once and
/// pass the same context to successive decoding calls.
pub(crate) struct HmmContext {
    /// Flattened Viterbi score table; `viterbi` addresses a cell as
    /// `char_index * num_states + state`.
    v: Vec<f64>,
    /// Back-pointer table laid out like `v`: the predecessor state recorded
    /// for a cell, or `None` when nothing has been recorded for it.
    prev: Vec<Option<Status>>,
    /// One decoded state per character of the input.
    best_path: Vec<Status>,
}

impl HmmContext {
    /// Builds a context pre-sized for `num_characters` characters and
    /// `num_states` hidden states.
    ///
    /// Both tables start with `num_states * num_characters` cells (scores at
    /// `0.0`, back-pointers at `None`); the path starts with `num_characters`
    /// entries of `Status::B`.
    pub fn new(num_states: usize, num_characters: usize) -> Self {
        // Both tables hold one cell per (character, state) pair.
        let table_len = num_states * num_characters;

        Self {
            v: vec![0.0; table_len],
            prev: vec![None; table_len],
            best_path: vec![Status::B; num_characters],
        }
    }
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
fn viterbi(sentence: &str, hmm_state: &mut HmmContext) {
let str_len = sentence.len();
let states = [Status::B, Status::M, Status::E, Status::S];
#[allow(non_snake_case)]
let R = states.len();
let C = sentence.chars().count();
assert!(C > 1);

if prev.len() < R * C {
prev.resize(R * C, None);
// TODO: Can code just do fill() with the default instead of clear() and resize?
if hmm_state.prev.len() < R * C {
hmm_state.prev.resize(R * C, None);
}

if V.len() < R * C {
V.resize(R * C, 0.0);
if hmm_state.v.len() < R * C {
hmm_state.v.resize(R * C, 0.0);
}

if best_path.len() < C {
best_path.resize(C, Status::B);
if hmm_state.best_path.len() < C {
hmm_state.best_path.resize(C, Status::B);
}

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
Expand All @@ -58,7 +75,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
for y in &states {
let first_word = &sentence[x1..x2];
let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
V[*y as usize] = prob;
hmm_state.v[*y as usize] = prob;
}

let mut t = 1;
Expand All @@ -71,7 +88,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.iter()
.map(|y0| {
(
V[(t - 1) * R + (*y0 as usize)]
hmm_state.v[(t - 1) * R + (*y0 as usize)]
+ TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
+ em_prob,
*y0,
Expand All @@ -80,51 +97,45 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();
let idx = (t * R) + (*y as usize);
V[idx] = prob;
prev[idx] = Some(state);
hmm_state.v[idx] = prob;
hmm_state.prev[idx] = Some(state);
}

t += 1;
}

let (_prob, state) = [Status::E, Status::S]
.iter()
.map(|y| (V[(C - 1) * R + (*y as usize)], y))
.map(|y| (hmm_state.v[(C - 1) * R + (*y as usize)], y))
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();

let mut t = C - 1;
let mut curr = *state;

best_path[t] = *state;
while let Some(p) = prev[t * R + (curr as usize)] {
hmm_state.best_path[t] = *state;
while let Some(p) = hmm_state.prev[t * R + (curr as usize)] {
assert!(t > 0);
best_path[t - 1] = p;
hmm_state.best_path[t - 1] = p;
curr = p;
t -= 1;
}

prev.clear();
V.clear();
hmm_state.prev.clear();
hmm_state.v.clear();
}

#[allow(non_snake_case)]
pub fn cut_internal<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_state: &mut HmmContext) {
let str_len = sentence.len();
viterbi(sentence, V, prev, path);
viterbi(sentence, hmm_state);
let mut begin = 0;
let mut next_byte_offset = 0;
let mut i = 0;

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
while let Some(curr_byte_offset) = curr.next() {
let state = path[i];
let state = hmm_state.best_path[i];
match state {
Status::B => begin = curr_byte_offset,
Status::E => {
Expand All @@ -150,17 +161,11 @@ pub fn cut_internal<'a>(
words.push(&sentence[byte_start..]);
}

path.clear();
hmm_state.best_path.clear();
}

#[allow(non_snake_case)]
pub(crate) fn cut_with_allocated_memory<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_state: &mut HmmContext) {
let splitter = SplitMatches::new(&RE_HAN, sentence);
for state in splitter {
let block = state.into_str();
Expand All @@ -169,7 +174,7 @@ pub(crate) fn cut_with_allocated_memory<'a>(
}
if RE_HAN.is_match(block) {
if block.chars().count() > 1 {
cut_internal(block, words, V, prev, path);
cut_internal(block, words, hmm_state);
} else {
words.push(block);
}
Expand All @@ -188,13 +193,10 @@ pub(crate) fn cut_with_allocated_memory<'a>(

#[allow(non_snake_case)]
pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
let R = 4;
let C = sentence.chars().count();
let mut V = vec![0.0; R * C];
let mut prev: Vec<Option<Status>> = vec![None; R * C];
let mut path: Vec<Status> = vec![Status::B; C];
// TODO: Is 4 just the number of variants in Status?
let mut hmm_state = HmmContext::new(4, sentence.chars().count());

cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
cut_with_allocated_memory(sentence, words, &mut hmm_state)
}

#[cfg(test)]
Expand Down
15 changes: 5 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,7 @@ impl Jieba {
words: &mut Vec<&'a str>,
route: &mut Vec<(f64, usize)>,
dag: &mut StaticSparseDAG,
V: &mut Vec<f64>,
prev: &mut Vec<Option<hmm::Status>>,
path: &mut Vec<hmm::Status>,
hmm_state: &mut hmm::HmmContext,
) {
self.dag(sentence, dag);
self.calc(sentence, dag, route);
Expand All @@ -526,7 +524,7 @@ impl Jieba {
if word.chars().count() == 1 {
words.push(word);
} else if self.cedar.exact_match_search(word).is_none() {
hmm::cut_with_allocated_memory(word, words, V, prev, path);
hmm::cut_with_allocated_memory(word, words, hmm_state);
} else {
let mut word_indices = word.char_indices().map(|x| x.0).peekable();
while let Some(byte_start) = word_indices.next() {
Expand Down Expand Up @@ -582,11 +580,8 @@ impl Jieba {
let mut route = Vec::with_capacity(heuristic_capacity);
let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity);

let R = 4;
let C = sentence.chars().count();
let mut V = if hmm { vec![0.0; R * C] } else { Vec::new() };
let mut prev: Vec<Option<hmm::Status>> = if hmm { vec![None; R * C] } else { Vec::new() };
let mut path: Vec<hmm::Status> = if hmm { vec![hmm::Status::B; C] } else { Vec::new() };
// TODO: Is 4 just the number of variants in Status?
let mut hmm_state = hmm::HmmContext::new(4, sentence.chars().count());

for state in splitter {
match state {
Expand All @@ -597,7 +592,7 @@ impl Jieba {
if cut_all {
self.cut_all_internal(block, &mut words);
} else if hmm {
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut V, &mut prev, &mut path);
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut hmm_state);
} else {
self.cut_dag_no_hmm(block, &mut words, &mut route, &mut dag);
}
Expand Down

0 comments on commit cbb0fe1

Please sign in to comment.