Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract HmmContext struct. #107

Merged
merged 1 commit into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 48 additions & 49 deletions src/hmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,42 @@ include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

const MIN_FLOAT: f64 = -3.14e100;

/// Scratch buffers for HMM Viterbi decoding, bundled so callers hand around a
/// single `&mut HmmContext` instead of three separate vectors.
///
/// `viterbi` grows the buffers on demand and clears `v` and `prev` when it
/// finishes; `cut_internal` clears `best_path` after reading it.
pub(crate) struct HmmContext {
/// Flattened score table. `viterbi` addresses it as `v[t * R + state]`, where
/// `t` is the character position and `R` is the number of states.
v: Vec<f64>,
/// Back-pointers laid out like `v`: the predecessor state chosen for each
/// `(t, state)` cell. A `None` entry ends the back-trace in `viterbi`.
prev: Vec<Option<Status>>,
/// One state per character, written by the back-trace in `viterbi` and read
/// by `cut_internal` to decide where words begin and end.
best_path: Vec<Status>,
}

impl HmmContext {
    /// Creates a context with buffers pre-sized for a sentence of
    /// `num_characters` characters decoded over `num_states` states.
    ///
    /// `v` and `prev` each receive `num_states * num_characters` entries
    /// (`0.0` and `None` respectively); `best_path` receives `num_characters`
    /// entries of `Status::B`.
    pub fn new(num_states: usize, num_characters: usize) -> Self {
        // Both per-cell tables share the same flattened length.
        let table_len = num_states * num_characters;

        let v = vec![0.0; table_len];
        let prev = vec![None; table_len];
        let best_path = vec![Status::B; num_characters];

        Self { v, prev, best_path }
    }
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
let str_len = sentence.len();
let states = [Status::B, Status::M, Status::E, Status::S];
#[allow(non_snake_case)]
let R = states.len();
let C = sentence.chars().count();
assert!(C > 1);

if prev.len() < R * C {
prev.resize(R * C, None);
// TODO: Can code just do fill() with the default instead of clear() and resize?
if hmm_context.prev.len() < R * C {
hmm_context.prev.resize(R * C, None);
}

if V.len() < R * C {
V.resize(R * C, 0.0);
if hmm_context.v.len() < R * C {
hmm_context.v.resize(R * C, 0.0);
}

if best_path.len() < C {
best_path.resize(C, Status::B);
if hmm_context.best_path.len() < C {
hmm_context.best_path.resize(C, Status::B);
}

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
Expand All @@ -58,7 +75,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
for y in &states {
let first_word = &sentence[x1..x2];
let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
V[*y as usize] = prob;
hmm_context.v[*y as usize] = prob;
}

let mut t = 1;
Expand All @@ -71,7 +88,7 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.iter()
.map(|y0| {
(
V[(t - 1) * R + (*y0 as usize)]
hmm_context.v[(t - 1) * R + (*y0 as usize)]
+ TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
+ em_prob,
*y0,
Expand All @@ -80,51 +97,45 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();
let idx = (t * R) + (*y as usize);
V[idx] = prob;
prev[idx] = Some(state);
hmm_context.v[idx] = prob;
hmm_context.prev[idx] = Some(state);
}

t += 1;
}

let (_prob, state) = [Status::E, Status::S]
.iter()
.map(|y| (V[(C - 1) * R + (*y as usize)], y))
.map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y))
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
.unwrap();

let mut t = C - 1;
let mut curr = *state;

best_path[t] = *state;
while let Some(p) = prev[t * R + (curr as usize)] {
hmm_context.best_path[t] = *state;
while let Some(p) = hmm_context.prev[t * R + (curr as usize)] {
assert!(t > 0);
best_path[t - 1] = p;
hmm_context.best_path[t - 1] = p;
curr = p;
t -= 1;
}

prev.clear();
V.clear();
hmm_context.prev.clear();
hmm_context.v.clear();
}

#[allow(non_snake_case)]
pub fn cut_internal<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) {
let str_len = sentence.len();
viterbi(sentence, V, prev, path);
viterbi(sentence, hmm_context);
let mut begin = 0;
let mut next_byte_offset = 0;
let mut i = 0;

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
while let Some(curr_byte_offset) = curr.next() {
let state = path[i];
let state = hmm_context.best_path[i];
match state {
Status::B => begin = curr_byte_offset,
Status::E => {
Expand All @@ -150,17 +161,11 @@ pub fn cut_internal<'a>(
words.push(&sentence[byte_start..]);
}

path.clear();
hmm_context.best_path.clear();
}

#[allow(non_snake_case)]
pub(crate) fn cut_with_allocated_memory<'a>(
sentence: &'a str,
words: &mut Vec<&'a str>,
V: &mut Vec<f64>,
prev: &mut Vec<Option<Status>>,
path: &mut Vec<Status>,
) {
pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context: &mut HmmContext) {
let splitter = SplitMatches::new(&RE_HAN, sentence);
for state in splitter {
let block = state.into_str();
Expand All @@ -169,7 +174,7 @@ pub(crate) fn cut_with_allocated_memory<'a>(
}
if RE_HAN.is_match(block) {
if block.chars().count() > 1 {
cut_internal(block, words, V, prev, path);
cut_internal(block, words, hmm_context);
} else {
words.push(block);
}
Expand All @@ -188,18 +193,15 @@ pub(crate) fn cut_with_allocated_memory<'a>(

#[allow(non_snake_case)]
pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
let R = 4;
let C = sentence.chars().count();
let mut V = vec![0.0; R * C];
let mut prev: Vec<Option<Status>> = vec![None; R * C];
let mut path: Vec<Status> = vec![Status::B; C];
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());

cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
cut_with_allocated_memory(sentence, words, &mut hmm_context)
}

#[cfg(test)]
mod tests {
use super::{cut, viterbi, Status};
use super::{cut, viterbi, HmmContext};

#[test]
#[allow(non_snake_case)]
Expand All @@ -208,13 +210,10 @@ mod tests {

let sentence = "小明硕士毕业于中国科学院计算所";

let R = 4;
let C = sentence.chars().count();
let mut V = vec![0.0; R * C];
let mut prev: Vec<Option<Status>> = vec![None; R * C];
let mut path: Vec<Status> = vec![Status::B; C];
viterbi(sentence, &mut V, &mut prev, &mut path);
assert_eq!(path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]);
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());
viterbi(sentence, &mut hmm_context);
assert_eq!(hmm_context.best_path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]);
}

#[test]
Expand Down
15 changes: 5 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,7 @@ impl Jieba {
words: &mut Vec<&'a str>,
route: &mut Vec<(f64, usize)>,
dag: &mut StaticSparseDAG,
V: &mut Vec<f64>,
prev: &mut Vec<Option<hmm::Status>>,
path: &mut Vec<hmm::Status>,
hmm_context: &mut hmm::HmmContext,
) {
self.dag(sentence, dag);
self.calc(sentence, dag, route);
Expand All @@ -526,7 +524,7 @@ impl Jieba {
if word.chars().count() == 1 {
words.push(word);
} else if self.cedar.exact_match_search(word).is_none() {
hmm::cut_with_allocated_memory(word, words, V, prev, path);
hmm::cut_with_allocated_memory(word, words, hmm_context);
} else {
let mut word_indices = word.char_indices().map(|x| x.0).peekable();
while let Some(byte_start) = word_indices.next() {
Expand Down Expand Up @@ -582,11 +580,8 @@ impl Jieba {
let mut route = Vec::with_capacity(heuristic_capacity);
let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity);

let R = 4;
let C = sentence.chars().count();
let mut V = if hmm { vec![0.0; R * C] } else { Vec::new() };
let mut prev: Vec<Option<hmm::Status>> = if hmm { vec![None; R * C] } else { Vec::new() };
let mut path: Vec<hmm::Status> = if hmm { vec![hmm::Status::B; C] } else { Vec::new() };
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = hmm::HmmContext::new(4, sentence.chars().count());

for state in splitter {
match state {
Expand All @@ -597,7 +592,7 @@ impl Jieba {
if cut_all {
self.cut_all_internal(block, &mut words);
} else if hmm {
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut V, &mut prev, &mut path);
self.cut_dag_hmm(block, &mut words, &mut route, &mut dag, &mut hmm_context);
} else {
self.cut_dag_no_hmm(block, &mut words, &mut route, &mut dag);
}
Expand Down