Add TFIDFState + APIs to avoid repeat loading of TFIDF data
awong-dev committed Apr 4, 2024
1 parent 89f6305 commit 538b99c
Showing 2 changed files with 57 additions and 7 deletions.
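The intended round trip, assembled from the APIs added in the diffs below (`TFIDF::new_with_jieba`, `TFIDFState::new`, and the new `TFIDF::new`), looks roughly like the following sketch. It is not code from the commit, and it assumes `TFIDFState` is reachable as `jieba_rs::TFIDFState`, which the CI annotations further down show is not yet the case at this commit. A second sketch after the src/keywords/tfidf.rs diff illustrates the binding-side pattern described in the new doc comment.

```rust
use jieba_rs::{Jieba, TFIDF, TFIDFState};

fn main() {
    let jieba = Jieba::new();

    // First construction parses the bundled IDF dictionary.
    let tfidf = TFIDF::new_with_jieba(&jieba);

    // Freeze the parsed idf_dict, median_idf, and stop_words. The state owns
    // its data and holds no reference to `jieba`, so it can outlive `tfidf`.
    let state = TFIDFState::new(tfidf);

    // Later: rebuild an extractor from the saved state without re-parsing.
    let _tfidf_again = TFIDF::new(&jieba, state);
}
```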
capi/src/lib.rs: 2 changes (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
use c_fixed_string::CFixedStr;
-use jieba_rs::{Jieba, KeywordExtract, TextRank, TFIDF};
+use jieba_rs::{Jieba, KeywordExtract, TFIDFState, TextRank, TFIDF};

Check failure on line 2 in capi/src/lib.rs (GitHub Actions / Test Suite, ubuntu-latest): unresolved import `jieba_rs::TFIDFState`
Check failure on line 2 in capi/src/lib.rs (GitHub Actions / Test Suite, macos-latest): unresolved import `jieba_rs::TFIDFState`
use std::boxed::Box;
use std::os::raw::c_char;
use std::{mem, ptr};
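Both CI failures above report the same problem: `TFIDFState` is defined in src/keywords/tfidf.rs (next file) but is not re-exported from the jieba_rs crate root, so `use jieba_rs::TFIDFState` in capi/src/lib.rs does not resolve. A minimal sketch of the kind of follow-up this implies, assuming the crate re-exports its keyword extractors from src/lib.rs and that the module path is `keywords::tfidf` (neither is shown in this diff):

```rust
// src/lib.rs (hypothetical follow-up, not part of this commit): re-export the
// new type next to the existing TFIDF/TextRank exports so that
// `use jieba_rs::TFIDFState;` compiles in the capi crate.
pub use crate::keywords::tfidf::TFIDFState;
```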
src/keywords/tfidf.rs: 62 changes (56 additions, 6 deletions)
@@ -39,21 +39,64 @@ pub struct TFIDF<'a> {
stop_words: BTreeSet<String>,
}

+/// Frozen state of TF-IDF keywords extractor without Jieba reference.
+///
+/// This can be used to save the state (stop words, idf_dictionary, etc)
+/// of the TFIDF extractor beyond the lifetime of the `TFIDF<'a>` object.
+/// The state can then be used to construct a new `TFIDF<'a>` object without
+/// reparsing and constructing this data.
+///
+/// This is useful in situations where use of the extractor extends
+/// beyond a stack frame, such as when implementing API bindings into a
+/// programming language with refcounted lifetimes.
+#[derive(Debug)]
+pub struct TFIDFState {
+idf_dict: HashMap<String, f64>,
+median_idf: f64,
+stop_words: BTreeSet<String>,
+}
+
+impl TFIDFState {
+pub fn new<'a>(tfidf: TFIDF<'a>) -> Self {
+TFIDFState {
+idf_dict: tfidf.idf_dict,
+median_idf: tfidf.median_idf,
+stop_words: tfidf.stop_words,
+}
+}
+}
+
impl<'a> TFIDF<'a> {
-pub fn new_with_jieba(jieba: &'a Jieba) -> Self {
-let mut instance = TFIDF {
+pub fn new(jieba: &'a Jieba, tfidf_state: TFIDFState) -> Self {
+TFIDF {
jieba,
+idf_dict: tfidf_state.idf_dict,
+median_idf: tfidf_state.median_idf,
+stop_words: tfidf_state.stop_words,
+}
+}
+
+pub fn new_with_jieba(jieba: &'a Jieba) -> Self {
+let mut state = TFIDFState {
idf_dict: HashMap::default(),
median_idf: 0.0,
stop_words: STOP_WORDS.clone(),
};

let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes());
-instance.load_dict(&mut default_dict).unwrap();
-instance
+Self::load_dict_internal(&mut state.idf_dict, &mut state.median_idf, &mut default_dict).unwrap();
+Self::new(jieba, state)
}

pub fn load_dict<R: BufRead>(&mut self, dict: &mut R) -> io::Result<()> {
+Self::load_dict_internal(&mut self.idf_dict, &mut self.median_idf, dict)
+}
+
+fn load_dict_internal<R: BufRead>(
+idf_dict: &mut HashMap<String, f64>,
+median_idf: &mut f64,
+dict: &mut R,
+) -> io::Result<()> {
let mut buf = String::new();
let mut idf_heap = BinaryHeap::new();
while dict.read_line(&mut buf)? > 0 {
@@ -64,7 +107,7 @@ impl<'a> TFIDF<'a> {

let word = parts[0];
if let Some(idf) = parts.get(1).and_then(|x| x.parse::<f64>().ok()) {
-self.idf_dict.insert(word.to_string(), idf);
+idf_dict.insert(word.to_string(), idf);
idf_heap.push(OrderedFloat(idf));
}

@@ -76,7 +119,7 @@ impl<'a> TFIDF<'a> {
idf_heap.pop();
}

-self.median_idf = idf_heap.pop().unwrap().into_inner();
+*median_idf = idf_heap.pop().unwrap().into_inner();

Ok(())
}
@@ -172,6 +215,13 @@ mod tests {
let _ = TFIDF::new_with_jieba(&jieba);
}

+#[test]
+fn test_init_tfidfstate() {
+let jieba = super::Jieba::new();
+let tfidf = TFIDF::new_with_jieba(&jieba);
+let _ = TFIDFState::new(tfidf);
+}

#[test]
fn test_extract_tags() {
let jieba = super::Jieba::new();
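The doc comment on `TFIDFState` motivates it with API bindings into languages with refcounted lifetimes, where a `TFIDF<'a>` cannot keep borrowing a `Jieba` across calls. A minimal sketch of that pattern follows; the handle type and method names are illustrative rather than part of the capi surface, and the `extract_tags` call assumes the `KeywordExtract` trait signature from the 0.6.x releases (`sentence`, `top_k`, `allowed_pos`).

```rust
use jieba_rs::{Jieba, Keyword, KeywordExtract, TFIDF, TFIDFState};

// Owned by the host language, e.g. behind a refcounted handle. The frozen
// state is kept between calls; a short-lived TFIDF<'_> is rebuilt per call.
struct KeywordHandle {
    jieba: Jieba,
    state: Option<TFIDFState>,
}

impl KeywordHandle {
    fn new() -> Self {
        let jieba = Jieba::new();
        // Parse the IDF dictionary once, then immediately freeze it.
        let state = TFIDFState::new(TFIDF::new_with_jieba(&jieba));
        KeywordHandle { jieba, state: Some(state) }
    }

    fn extract(&mut self, sentence: &str, top_k: usize) -> Vec<Keyword> {
        // Rehydrate an extractor from the frozen state and run it...
        let state = self.state.take().expect("state already in use");
        let tfidf = TFIDF::new(&self.jieba, state);
        let tags = tfidf.extract_tags(sentence, top_k, Vec::new());
        // ...then freeze again so the next call skips the dictionary parse.
        self.state = Some(TFIDFState::new(tfidf));
        tags
    }
}
```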
