From 350fa114d9a7e9720f4a4ee2be3b4c05f9c3c6fe Mon Sep 17 00:00:00 2001 From: "Albert J. Wong" Date: Mon, 8 Apr 2024 01:12:55 -0700 Subject: [PATCH] DRY-up API/impl by extracting a KeywordExtractConfig struct --- src/keywords/mod.rs | 73 +++++++++++++++++++++++++ src/keywords/textrank.rs | 83 +++++----------------------- src/keywords/tfidf.rs | 115 ++++++++++----------------------------- src/lib.rs | 8 +-- 4 files changed, 118 insertions(+), 161 deletions(-) diff --git a/src/keywords/mod.rs b/src/keywords/mod.rs index 3928d9b..e5e6c6e 100644 --- a/src/keywords/mod.rs +++ b/src/keywords/mod.rs @@ -32,6 +32,79 @@ pub struct Keyword { pub weight: f64, } +#[derive(Debug)] +pub struct KeywordExtractConfig { + stop_words: BTreeSet, + min_keyword_length: usize, + use_hmm: bool, +} + +impl KeywordExtractConfig { + /// Creates a KeywordExtractConfig state that contains filter criteria as + /// well as segmentation configuration for use by keyword extraction + /// implementations. + pub fn new(stop_words: BTreeSet, min_keyword_length: usize, use_hmm: bool) -> Self { + KeywordExtractConfig { + stop_words, + min_keyword_length, + use_hmm, + } + } + + /// Add a new stop word. + pub fn add_stop_word(&mut self, word: String) -> bool { + self.stop_words.insert(word) + } + + /// Remove an existing stop word. + pub fn remove_stop_word(&mut self, word: &str) -> bool { + self.stop_words.remove(word) + } + + /// Replace all stop words with new stop words set. + pub fn set_stop_words(&mut self, stop_words: BTreeSet) { + self.stop_words = stop_words + } + + /// Get current set of stop words. + pub fn get_stop_words(&self) -> &BTreeSet { + &self.stop_words + } + + /// True if hmm is used during segmentation in `extract_tags`. + pub fn get_use_hmm(&self) -> bool { + self.use_hmm + } + + /// Sets whether or not to use hmm during segmentation in `extract_tags`. 
+ pub fn set_use_hmm(&mut self, use_hmm: bool) { + self.use_hmm = use_hmm + } + + /// Gets the minimum number of Unicode Scalar Values required per keyword. + pub fn get_min_keyword_length(&self) -> usize { + self.min_keyword_length + } + + /// Sets the minimum number of Unicode Scalar Values required per keyword. + /// + /// The default is 2. There is likely not much reason to change this. + pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) { + self.min_keyword_length = min_keyword_length + } + + #[inline] + pub fn filter(&self, s: &str) -> bool { + s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase()) + } +} + +impl Default for KeywordExtractConfig { + fn default() -> Self { + KeywordExtractConfig::new(DEFAULT_STOP_WORDS.clone(), 2, false) + } +} + pub trait KeywordExtract { fn extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec) -> Vec; } diff --git a/src/keywords/textrank.rs b/src/keywords/textrank.rs index 06e8217..02d76dc 100644 --- a/src/keywords/textrank.rs +++ b/src/keywords/textrank.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeSet, BinaryHeap}; use ordered_float::OrderedFloat; -use super::{JiebaKeywordExtract, Keyword, KeywordExtract, DEFAULT_STOP_WORDS}; +use super::{JiebaKeywordExtract, Keyword, KeywordExtract, KeywordExtractConfig}; use crate::FxHashMap as HashMap; use crate::Jieba; @@ -74,9 +74,7 @@ impl StateDiagram { #[derive(Debug)] pub struct UnboundTextRank { span: usize, - stop_words: BTreeSet, - min_keyword_length: usize, - use_hmm: bool, + config: KeywordExtractConfig, } impl UnboundTextRank { @@ -93,78 +91,23 @@ impl UnboundTextRank { /// BTreeSet::from(["a", "the", "of"].map(|s| s.to_string())); /// jieba_rs::UnboundTextRank::new( /// 5, - /// stop_words, - /// 2, - /// false); + /// jieba_rs::KeywordExtractConfig::default()); /// ``` - pub fn new(span: usize, stop_words: BTreeSet, min_keyword_length: usize, use_hmm: bool) -> Self { - UnboundTextRank { - stop_words, - span, -
min_keyword_length, - use_hmm, - } - } - - /// Add a new stop word. - pub fn add_stop_word(&mut self, word: String) -> bool { - self.stop_words.insert(word) - } - - /// Remove an existing stop word. - pub fn remove_stop_word(&mut self, word: &str) -> bool { - self.stop_words.remove(word) - } - - /// Replace all stop words with new stop words set. - pub fn set_stop_words(&mut self, stop_words: BTreeSet) { - self.stop_words = stop_words - } - - /// Get current set of stop words. - pub fn get_stop_words(&self) -> &BTreeSet { - &self.stop_words - } - - /// True if hmm is used during segmentation in `extract_tags`. - pub fn get_use_hmm(&self) -> bool { - self.use_hmm - } - - /// Sets whether or not to use hmm during segmentation in `extract_tags`. - pub fn set_use_hmm(&mut self, use_hmm: bool) { - self.use_hmm = use_hmm - } - - /// Gets the minimum number of Unicode Scalar Values required per keyword. - pub fn get_min_keyword_length(&self) -> usize { - self.min_keyword_length - } - - /// Sets the minimum number of Unicode Scalar Values required per keyword. - /// - /// The default is 2. There is likely not much reason to change this. - pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) { - self.min_keyword_length = min_keyword_length - } - - #[inline] - fn filter(&self, s: &str) -> bool { - s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase()) + pub fn new(span: usize, config: KeywordExtractConfig) -> Self { + UnboundTextRank { span, config } } } impl Default for UnboundTextRank { - /// Creates UnboundTextRank with 5 Unicode Scalar Value spans, - /// DEFAULT_STOP_WORDS, and no hmm in segmentation. 
+ /// Creates UnboundTextRank with 5 Unicode Scalar Value spans and the default KeywordExtractConfig. fn default() -> Self { - UnboundTextRank::new(5, DEFAULT_STOP_WORDS.clone(), 2, false) + UnboundTextRank::new(5, KeywordExtractConfig::default()) } } impl JiebaKeywordExtract for UnboundTextRank { fn extract_tags(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec) -> Vec { - let tags = jieba.tag(sentence, self.use_hmm); + let tags = jieba.tag(sentence, self.config.get_use_hmm()); let mut allowed_pos_set = BTreeSet::new(); for s in allowed_pos { @@ -190,7 +133,7 @@ impl JiebaKeywordExtract for UnboundTextRank { continue; } - if !self.filter(t.word) { + if !self.config.filter(t.word) { continue; } @@ -203,7 +146,7 @@ impl JiebaKeywordExtract for UnboundTextRank { continue; } - if !self.filter(tags[j].word) { + if !self.config.filter(tags[j].word) { continue; } @@ -267,17 +210,17 @@ impl<'a> TextRank<'a> { /// Add a new stop word pub fn add_stop_word(&mut self, word: String) -> bool { - self.unbound_text_rank.add_stop_word(word) + self.unbound_text_rank.config.add_stop_word(word) } /// Remove an existing stop word pub fn remove_stop_word(&mut self, word: &str) -> bool { - self.unbound_text_rank.remove_stop_word(word) + self.unbound_text_rank.config.remove_stop_word(word) } /// Replace all stop words with new stop words set pub fn set_stop_words(&mut self, stop_words: BTreeSet) { - self.unbound_text_rank.set_stop_words(stop_words) + self.unbound_text_rank.config.set_stop_words(stop_words) } } diff --git a/src/keywords/tfidf.rs b/src/keywords/tfidf.rs index 8ce4fb3..267c4de 100644 --- a/src/keywords/tfidf.rs +++ b/src/keywords/tfidf.rs @@ -4,7 +4,7 @@ use std::io::{self, BufRead, BufReader}; use ordered_float::OrderedFloat; -use super::{JiebaKeywordExtract, Keyword, KeywordExtract, DEFAULT_STOP_WORDS}; +use super::{JiebaKeywordExtract, Keyword, KeywordExtract, KeywordExtractConfig}; use crate::FxHashMap as HashMap; use crate::Jieba; @@ -35,9 +35,7 @@ impl<'a> PartialOrd for HeapNode<'a> {
pub struct UnboundTfidf { idf_dict: HashMap, median_idf: f64, - stop_words: BTreeSet, - min_keyword_length: usize, - use_hmm: bool, + config: KeywordExtractConfig, } /// Implementation of JiebaKeywordExtract using a TFIDF dictionary. @@ -50,44 +48,27 @@ impl UnboundTfidf { /// /// # Examples /// - /// New instance with custom stop words and idf dictionary. Also uses hmm - /// for unknown words during segmentation and allows keywords of length 1. + /// New instance with custom idf dictionary. /// ``` - /// use std::collections::BTreeSet; - /// - /// let stop_words : BTreeSet = - /// BTreeSet::from(["a", "the", "of"].map(|s| s.to_string())); /// let mut sample_idf = "劳动防护 13.900677652\n\ /// 生化学 13.900677652\n"; /// jieba_rs::UnboundTfidf::new( /// Some(&mut sample_idf.as_bytes()), - /// stop_words, - /// 1, - /// true); + /// jieba_rs::KeywordExtractConfig::default()); /// ``` /// /// New instance with module default stop words and no initial IDF /// dictionary. Dictionary should be loaded later with `load_dict()` calls. - /// No hmm and more standard minimal of length 2 keywords. 
/// ``` /// jieba_rs::UnboundTfidf::new( /// None::<&mut std::io::Empty>, - /// jieba_rs::DEFAULT_STOP_WORDS.clone(), - /// 2, - /// false); + /// jieba_rs::KeywordExtractConfig::default()); /// ``` - pub fn new( - opt_dict: Option<&mut impl BufRead>, - stop_words: BTreeSet, - min_keyword_length: usize, - use_hmm: bool, - ) -> Self { + pub fn new(opt_dict: Option<&mut impl BufRead>, config: KeywordExtractConfig) -> Self { let mut instance = UnboundTfidf { idf_dict: HashMap::default(), median_idf: 0.0, - stop_words, - min_keyword_length, - use_hmm, + config, }; if let Some(dict) = opt_dict { instance.load_dict(dict).unwrap(); @@ -99,32 +80,33 @@ impl UnboundTfidf { /// /// ``` /// use jieba_rs::{Jieba, JiebaKeywordExtract, Keyword, - /// UnboundTfidf, DEFAULT_STOP_WORDS}; + /// KeywordExtractConfig, UnboundTfidf}; /// /// let jieba = Jieba::default(); /// let mut init_idf = "生化学 13.900677652\n"; /// /// let mut tfidf = UnboundTfidf::new( /// Some(&mut init_idf.as_bytes()), - /// DEFAULT_STOP_WORDS.clone(), - /// true); - /// let top_k = tfidf.extract_tags(&jieba, "生化学很難", 3, vec![]); + /// KeywordExtractConfig::default()); + /// let top_k = tfidf.extract_tags(&jieba, "生化学不是光化学的,", 3, vec![]); /// assert_eq!( /// top_k, /// vec![ - /// Keyword { keyword: "很難".to_string(), weight: 6.950338826 }, - /// Keyword { keyword: "生化学".to_string(), weight: 6.950338826 } + /// Keyword { keyword: "不是".to_string(), weight: 4.6335592173333335 }, + /// Keyword { keyword: "光化学".to_string(), weight: 4.6335592173333335 }, + /// Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 } /// ] /// ); /// - /// let mut init_idf = "很難 99.123456789\n"; + /// let mut init_idf = "光化学 99.123456789\n"; /// tfidf.load_dict(&mut init_idf.as_bytes()); - /// let top_k = tfidf.extract_tags(&jieba, "生化学很難", 3, vec![]); + /// let new_top_k = tfidf.extract_tags(&jieba, "生化学不是光化学的,", 3, vec![]); /// assert_eq!( - /// top_k, + /// new_top_k, /// vec![ - /// Keyword { keyword: "很難".to_string(), 
weight: 49.5617283945 }, - /// Keyword { keyword: "生化学".to_string(), weight: 6.950338826 } + /// Keyword { keyword: "不是".to_string(), weight: 33.041152263 }, + /// Keyword { keyword: "光化学".to_string(), weight: 33.041152263 }, + /// Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 } /// ] /// ); /// ``` @@ -156,51 +138,12 @@ impl UnboundTfidf { Ok(()) } - /// Add a new stop word. - pub fn add_stop_word(&mut self, word: String) -> bool { - self.stop_words.insert(word) - } - - /// Remove an existing stop word. - pub fn remove_stop_word(&mut self, word: &str) -> bool { - self.stop_words.remove(word) - } - - /// Replace all stop words with new stop words set. - pub fn set_stop_words(&mut self, stop_words: BTreeSet) { - self.stop_words = stop_words - } - - /// Get current set of stop words. - pub fn get_stop_words(&self) -> &BTreeSet { - &self.stop_words - } - - /// True if hmm is used during segmentation in `extract_tags`. - pub fn get_use_hmm(&self) -> bool { - self.use_hmm - } - - /// Sets whether or not to use hmm during segmentation in `extract_tags`. - pub fn set_use_hmm(&mut self, use_hmm: bool) { - self.use_hmm = use_hmm - } - - /// Gets the minimum number of Unicode Scalar Values required per keyword. - pub fn get_min_keyword_length(&self) -> usize { - self.min_keyword_length - } - - /// Sets the minimum number of Unicode Scalar Values required per keyword. - /// - /// The default is 2. There is likely not much reason to change this. 
- pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) { - self.min_keyword_length = min_keyword_length + pub fn config(&self) -> &KeywordExtractConfig { + &self.config } - #[inline] - fn filter(&self, s: &str) -> bool { - s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase()) + pub fn config_mut(&mut self) -> &mut KeywordExtractConfig { + &mut self.config } } @@ -209,13 +152,13 @@ impl Default for UnboundTfidf { /// 2 Unicode Scalar Value minimum for keywords, and no hmm in segmentation. fn default() -> Self { let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes()); - UnboundTfidf::new(Some(&mut default_dict), DEFAULT_STOP_WORDS.clone(), 2, false) + UnboundTfidf::new(Some(&mut default_dict), KeywordExtractConfig::default()) } } impl JiebaKeywordExtract for UnboundTfidf { fn extract_tags(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec) -> Vec { - let tags = jieba.tag(sentence, self.use_hmm); + let tags = jieba.tag(sentence, self.config.get_use_hmm()); let mut allowed_pos_set = BTreeSet::new(); for s in allowed_pos { @@ -228,7 +171,7 @@ impl JiebaKeywordExtract for UnboundTfidf { continue; } - if !self.filter(t.word) { + if !self.config.filter(t.word) { continue; } @@ -288,17 +231,17 @@ impl<'a> TFIDF<'a> { /// Add a new stop word pub fn add_stop_word(&mut self, word: String) -> bool { - self.unbound_tfidf.add_stop_word(word) + self.unbound_tfidf.config.add_stop_word(word) } /// Remove an existing stop word pub fn remove_stop_word(&mut self, word: &str) -> bool { - self.unbound_tfidf.remove_stop_word(word) + self.unbound_tfidf.config.remove_stop_word(word) } /// Replace all stop words with new stop words set pub fn set_stop_words(&mut self, stop_words: BTreeSet) { - self.unbound_tfidf.set_stop_words(stop_words) + self.unbound_tfidf.config.set_stop_words(stop_words) } } diff --git a/src/lib.rs b/src/lib.rs index 1b4b45c..c977a02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,13 +82,11 @@ 
pub(crate) type FxHashMap = HashMap; pub use crate::errors::Error; #[cfg(feature = "textrank")] -pub use crate::keywords::textrank::TextRank; -pub use crate::keywords::textrank::UnboundTextRank; -pub use crate::keywords::tfidf::UnboundTfidf; +pub use crate::keywords::textrank::{TextRank, UnboundTextRank}; #[cfg(feature = "tfidf")] -pub use crate::keywords::tfidf::TFIDF; +pub use crate::keywords::tfidf::{UnboundTfidf, TFIDF}; #[cfg(any(feature = "tfidf", feature = "textrank"))] -pub use crate::keywords::{JiebaKeywordExtract, Keyword, KeywordExtract, DEFAULT_STOP_WORDS}; +pub use crate::keywords::{JiebaKeywordExtract, Keyword, KeywordExtract, KeywordExtractConfig, DEFAULT_STOP_WORDS}; mod errors; mod hmm;