Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow users to load custom hmm model #92

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ required-features = ["tfidf", "textrank"]
regex = "1.0"
lazy_static = "1.0"
phf = "0.11"
hashbrown = { version = "0.12", default-features = false, features = ["inline-more"] }
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would need a benchmark result to justify it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if instead HMMModel should be a separate struct and Jieba's interface should be expanded to be

cut()
hmm_cut()

Will bring the discussion back in #48

cedarwood = "0.4"
ordered-float = { version = "3.0", optional = true }
once_cell = "1"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no point in using both lazy_static and once_cell; please also replace lazy_static with once_cell.

fxhash = "0.2.1"

[build-dependencies]
phf_codegen = "0.11"

[features]
default = ["default-dict"]
default = ["default-dict", "default-hmm-model"]
default-dict = []
default-hmm-model = []
tfidf = ["ordered-float"]
textrank = ["ordered-float"]

Expand Down
8 changes: 4 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ fn main() {
let mut lines = reader.lines().map(|x| x.unwrap()).skip_while(|x| x.starts_with('#'));
let prob_start = lines.next().unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static INITIAL_PROBS: StatusSet = [").unwrap();
write!(&mut file, "pub static INITIAL_PROBS: StatusSet = [").unwrap();
for prob in prob_start.split(' ') {
write!(&mut file, "{}, ", prob).unwrap();
}
write!(&mut file, "];\n\n").unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
write!(&mut file, "pub static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
for line in lines
.by_ref()
.skip_while(|x| x.starts_with('#'))
Expand All @@ -38,7 +38,7 @@ fn main() {
continue;
}
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
write!(&mut file, "pub static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
let mut map = phf_codegen::Map::new();
for word_prob in line.split(',') {
let mut parts = word_prob.split(':');
Expand All @@ -50,5 +50,5 @@ fn main() {
i += 1;
}
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
writeln!(&mut file, "pub static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
}
149 changes: 139 additions & 10 deletions src/hmm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::cmp::Ordering;

use lazy_static::lazy_static;
use once_cell::sync::OnceCell;
use regex::Regex;
use std::cmp::Ordering;
use std::collections::HashMap;

use crate::SplitMatches;

Expand All @@ -20,17 +21,108 @@ pub enum Status {
S = 3,
}

/// A user-suppliable hidden Markov model for word segmentation.
///
/// Mirrors the layout of the tables generated by `build.rs`: each outer
/// index is a `Status` discriminant (B/E/M/S, in that order). Probabilities
/// are stored in log domain — the Viterbi routine sums them, and absent
/// entries fall back to the `MIN_FLOAT` sentinel.
pub struct HmmModel {
    /// Log-probability of a sequence starting in each status.
    pub initial_probs: StatusSet,
    /// Log-probability of transitioning from status `i` (outer) to status `j` (inner).
    pub trans_probs: [StatusSet; 4],
    /// Per-status log-probability of emitting a given character string.
    pub emit_probs: [HashMap<String, f64>; 4],
}

// Process-wide, write-once slot for a user-supplied model; set via
// `init_custom_hmm_model`, read via `get_custom_hmm_model`.
static CUSTOM_HMM_MODEL: OnceCell<HmmModel> = OnceCell::new();

/// Returns the custom HMM model previously installed with
/// [`init_custom_hmm_model`], or `None` if no model has been set.
pub fn get_custom_hmm_model() -> Option<&'static HmmModel> {
    CUSTOM_HMM_MODEL.get()
}

/// Installs `model` as the process-wide custom HMM model.
///
/// Succeeds at most once per process: if a model is already installed,
/// the existing one is kept and `model` is handed back in `Err`.
pub fn init_custom_hmm_model(model: HmmModel) -> Result<(), HmmModel> {
    CUSTOM_HMM_MODEL.set(model)
}

// For each status (indexed by `Status as usize`), the two statuses that may
// legally precede it in a B/E/M/S tagging sequence; Viterbi only searches
// these transitions instead of all four predecessors.
static PREV_STATUS: [[Status; 2]; 4] = [
    [Status::E, Status::S], // B
    [Status::B, Status::M], // E
    [Status::M, Status::B], // M
    [Status::S, Status::E], // S
];

include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

// Log-probability sentinel for "effectively impossible" (stands in for
// negative infinity); returned when a word/transition is absent from the model.
const MIN_FLOAT: f64 = -3.14e100;

// Compiled-in default model tables (`INITIAL_PROBS`, `TRANS_PROBS`,
// `EMIT_PROBS`), generated by `build.rs` into `$OUT_DIR/hmm_prob.rs`.
// Only compiled when the `default-hmm-model` feature is enabled.
#[cfg(feature = "default-hmm-model")]
mod default_hmm {
    use super::*;

    include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
}

/// Log-probability of a sequence starting in the status with discriminant
/// `index` (0..4, i.e. B/E/M/S).
///
/// A user-installed custom model takes precedence over the compiled-in
/// default. If neither exists, returns the `MIN_FLOAT` sentinel.
#[inline]
fn get_initial_prob(index: usize) -> f64 {
    debug_assert!(index < 4);
    if let Some(model) = get_custom_hmm_model() {
        model.initial_probs[index]
    } else {
        #[cfg(feature = "default-hmm-model")]
        {
            default_hmm::INITIAL_PROBS[index]
        }
        #[cfg(not(feature = "default-hmm-model"))]
        {
            // Bug fix: the original used `debug_assert!(true, ..)`, which can
            // never fire. Reaching this branch means no model is available at
            // all — surface that in debug builds; release builds degrade to
            // the sentinel. Also corrected the function name in the message
            // (`init_custom_hmm_model`, not `set_custom_hmm_model`).
            debug_assert!(
                false,
                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
            );
            MIN_FLOAT
        }
    }
}

/// Log-probability that the status with discriminant `index` (0..4) emits
/// `word`; `MIN_FLOAT` when the word is absent from the model.
///
/// A user-installed custom model takes precedence over the compiled-in
/// default.
#[inline]
fn get_emit_prob(index: usize, word: &str) -> f64 {
    debug_assert!(index < 4);
    if let Some(model) = get_custom_hmm_model() {
        // `.copied()` instead of `.cloned()`: the values are plain `f64`s.
        model.emit_probs[index].get(word).copied().unwrap_or(MIN_FLOAT)
    } else {
        #[cfg(feature = "default-hmm-model")]
        {
            default_hmm::EMIT_PROBS[index].get(word).copied().unwrap_or(MIN_FLOAT)
        }
        #[cfg(not(feature = "default-hmm-model"))]
        {
            // Bug fix: the original used `debug_assert!(true, ..)`, which can
            // never fire. Reaching this branch means no model is available at
            // all — surface that in debug builds; release builds degrade to
            // the sentinel. Also corrected the function name in the message
            // (`init_custom_hmm_model`, not `set_custom_hmm_model`).
            debug_assert!(
                false,
                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
            );
            MIN_FLOAT
        }
    }
}

/// Log-probability of transitioning from the status with discriminant
/// `from_index` to the one with `to_index` (both 0..4); `MIN_FLOAT` when
/// the transition is absent from the model.
///
/// A user-installed custom model takes precedence over the compiled-in
/// default.
#[inline]
fn get_trans_prob(from_index: usize, to_index: usize) -> f64 {
    debug_assert!(from_index < 4);
    debug_assert!(to_index < 4);
    if let Some(model) = get_custom_hmm_model() {
        // `.copied()` instead of `.cloned()`: the values are plain `f64`s.
        model.trans_probs[from_index]
            .get(to_index)
            .copied()
            .unwrap_or(MIN_FLOAT)
    } else {
        #[cfg(feature = "default-hmm-model")]
        {
            default_hmm::TRANS_PROBS[from_index]
                .get(to_index)
                .copied()
                .unwrap_or(MIN_FLOAT)
        }
        #[cfg(not(feature = "default-hmm-model"))]
        {
            // Bug fix: the original used `debug_assert!(true, ..)`, which can
            // never fire. Reaching this branch means no model is available at
            // all — surface that in debug builds; release builds degrade to
            // the sentinel. Also corrected the function name in the message
            // (`init_custom_hmm_model`, not `set_custom_hmm_model`).
            debug_assert!(
                false,
                "No default hmm model is provided, please use `init_custom_hmm_model` to set a custom model."
            );
            MIN_FLOAT
        }
    }
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
let str_len = sentence.len();
Expand All @@ -57,7 +149,8 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
let x2 = *curr.peek().unwrap();
for y in &states {
let first_word = &sentence[x1..x2];
let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
let index = *y as usize;
let prob = get_initial_prob(index) + get_emit_prob(index, first_word);
V[*y as usize] = prob;
}

Expand All @@ -66,14 +159,12 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
for y in &states {
let byte_end = *curr.peek().unwrap_or(&str_len);
let word = &sentence[byte_start..byte_end];
let em_prob = EMIT_PROBS[*y as usize].get(word).cloned().unwrap_or(MIN_FLOAT);
let em_prob = get_emit_prob(*y as usize, word);
let (prob, state) = PREV_STATUS[*y as usize]
.iter()
.map(|y0| {
(
V[(t - 1) * R + (*y0 as usize)]
+ TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
+ em_prob,
V[(t - 1) * R + (*y0 as usize)] + get_trans_prob(*y0 as usize, *y as usize) + em_prob,
*y0,
)
})
Expand Down Expand Up @@ -197,15 +288,52 @@ pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
}

// Test helper for builds *without* the default model: installs the
// build-script-generated tables as a custom model so the tests have a model
// to segment with. Idempotent — does nothing when a custom model is already
// installed (the `init_custom_hmm_model` result is deliberately ignored).
#[cfg(all(test, not(feature = "default-hmm-model")))]
pub fn test_init_custom_hmm_model() {
    use std::convert::TryInto;

    // The generated file is included in a local module here because the
    // module-level include (`default_hmm`) is gated on `default-hmm-model`.
    mod hmm_prob {
        use super::*;
        include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
    }

    if get_custom_hmm_model().is_none() {
        let initial_probs = hmm_prob::INITIAL_PROBS;
        let trans_probs = hmm_prob::TRANS_PROBS;
        // Convert the generated phf maps into the owned `HashMap`s that
        // `HmmModel` expects, one per status.
        let emit_probs: [HashMap<_, _>; 4] = {
            let mut emit_probs = Vec::new();
            for prob in hmm_prob::EMIT_PROBS {
                let mut probs = HashMap::new();
                for (k, v) in prob {
                    probs.insert(k.to_string(), *v);
                }
                emit_probs.push(probs);
            }
            emit_probs.try_into().unwrap()
        };
        let _ = init_custom_hmm_model(HmmModel {
            initial_probs,
            trans_probs,
            emit_probs,
        });
    }
}

// Counterpart test helper for builds *with* the default model: the compiled-in
// tables are already available, so no custom model needs to be installed.
#[cfg(all(test, feature = "default-hmm-model"))]
pub fn test_init_custom_hmm_model() {
    // nothing
}

#[cfg(test)]
mod tests {
use super::{cut, viterbi, Status};
use super::*;

#[test]
#[allow(non_snake_case)]
fn test_viterbi() {
use super::Status::*;

test_init_custom_hmm_model();
let sentence = "小明硕士毕业于中国科学院计算所";

let R = 4;
Expand All @@ -219,6 +347,7 @@ mod tests {

#[test]
fn test_hmm_cut() {
test_init_custom_hmm_model();
let sentence = "小明硕士毕业于中国科学院计算所";
let mut words = Vec::with_capacity(sentence.chars().count() / 2);
cut(sentence, &mut words);
Expand Down
11 changes: 10 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ use std::cmp::Ordering;
use std::io::BufRead;

use cedarwood::Cedar;
use hashbrown::HashMap;
use regex::{Match, Matches, Regex};
use std::collections::HashMap;

pub(crate) type FxHashMap<K, V> = HashMap<K, V, fxhash::FxBuildHasher>;

Expand All @@ -88,6 +88,8 @@ pub use crate::keywords::tfidf::TFIDF;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
pub use crate::keywords::{Keyword, KeywordExtract};

pub use crate::hmm::{get_custom_hmm_model, init_custom_hmm_model};

mod errors;
mod hmm;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
Expand Down Expand Up @@ -806,6 +808,7 @@ impl Jieba {
#[cfg(test)]
mod tests {
use super::{Jieba, SplitMatches, SplitState, Tag, Token, TokenizeMode, RE_HAN_DEFAULT};
use crate::hmm::test_init_custom_hmm_model;
use std::io::BufReader;

#[test]
Expand Down Expand Up @@ -900,6 +903,7 @@ mod tests {

#[test]
fn test_cut_with_hmm() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let words = jieba.cut("我们中出了一个叛徒", false);
assert_eq!(words, vec!["我们", "中", "出", "了", "一个", "叛徒"]);
Expand All @@ -917,6 +921,7 @@ mod tests {

#[test]
fn test_cut_weicheng() {
test_init_custom_hmm_model();
static WEICHENG_TXT: &str = include_str!("../examples/weicheng/src/weicheng.txt");
let jieba = Jieba::new();
for line in WEICHENG_TXT.split('\n') {
Expand All @@ -926,6 +931,7 @@ mod tests {

#[test]
fn test_cut_for_search() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let words = jieba.cut_for_search("南京市长江大桥", true);
assert_eq!(words, vec!["南京", "京市", "南京市", "长江", "大桥", "长江大桥"]);
Expand Down Expand Up @@ -962,6 +968,7 @@ mod tests {

#[test]
fn test_tag() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let tags = jieba.tag(
"我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。",
Expand Down Expand Up @@ -1078,6 +1085,7 @@ mod tests {

#[test]
fn test_tokenize() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let tokens = jieba.tokenize("南京市长江大桥", TokenizeMode::Default, false);
assert_eq!(
Expand Down Expand Up @@ -1305,6 +1313,7 @@ mod tests {

#[test]
fn test_userdict_hmm() {
test_init_custom_hmm_model();
let mut jieba = Jieba::new();
let tokens = jieba.tokenize("我们中出了一个叛徒", TokenizeMode::Default, true);
assert_eq!(
Expand Down