diff --git a/Cargo.toml b/Cargo.toml index da92697..4b6eb47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,12 +27,13 @@ harness = false required-features = ["tfidf", "textrank"] [dependencies] -regex = "1.0" -lazy_static = "1.0" -phf = "0.11" cedarwood = "0.4" -ordered-float = { version = "4.0", optional = true } +derive_builder = "0.20.0" fxhash = "0.2.1" +lazy_static = "1.0" +ordered-float = { version = "4.0", optional = true } +phf = "0.11" +regex = "1.0" [build-dependencies] phf_codegen = "0.11" diff --git a/src/keywords/mod.rs b/src/keywords/mod.rs index 548dd65..ed9a544 100644 --- a/src/keywords/mod.rs +++ b/src/keywords/mod.rs @@ -32,76 +32,110 @@ pub struct Keyword { pub weight: f64, } -#[derive(Debug, Clone)] +/// Creates a KeywordExtractConfig state that contains filter criteria as +/// well as segmentation configuration for use by keyword extraction +/// implementations. +/// +/// Use KeywordExtractConfigBuilder to change the defaults. +/// +/// # Examples +/// ``` +/// use jieba_rs::KeywordExtractConfig; +/// +/// let mut config = KeywordExtractConfig::default(); +/// assert!(config.stop_words().contains("the")); +/// assert!(!config.stop_words().contains("FakeWord")); +/// assert!(!config.use_hmm()); +/// assert_eq!(2, config.min_keyword_length()); +/// +/// let built_default = KeywordExtractConfig::builder().build().unwrap(); +/// assert_eq!(config, built_default); +/// +/// let changed = KeywordExtractConfig::builder() +/// .add_stop_word("FakeWord".to_string()) +/// .remove_stop_word("the") +/// .use_hmm(true) +/// .min_keyword_length(10) +/// .build().unwrap(); +/// +/// assert!(!changed.stop_words().contains("the")); +/// assert!(changed.stop_words().contains("FakeWord")); +/// assert!(changed.use_hmm()); +/// assert_eq!(10, changed.min_keyword_length()); +/// ``` +#[derive(Builder, Debug, Clone, PartialEq)] pub struct KeywordExtractConfig { + #[builder(default = "self.default_stop_words()?", setter(custom))] stop_words: BTreeSet, + + #[builder(default = "2")] + #[doc = r"Any segments less than this length will not be considered a Keyword"] min_keyword_length: usize, + + #[builder(default = "false")] + #[doc = r"If true, fall back to hmm model if segment cannot be found in the dictionary"] use_hmm: bool, } impl KeywordExtractConfig { - /// Creates a KeywordExtractConfig state that contains filter criteria as - /// well as segmentation configuration for use by keyword extraction - /// implementations. - pub fn new(stop_words: BTreeSet, min_keyword_length: usize, use_hmm: bool) -> Self { - KeywordExtractConfig { - stop_words, - min_keyword_length, - use_hmm, - } - } - - /// Add a new stop word. - pub fn add_stop_word(&mut self, word: String) -> bool { - self.stop_words.insert(word) - } - - /// Remove an existing stop word. - pub fn remove_stop_word(&mut self, word: &str) -> bool { - self.stop_words.remove(word) - } - - /// Replace all stop words with new stop words set. - pub fn set_stop_words(&mut self, stop_words: BTreeSet) { - self.stop_words = stop_words + pub fn builder() -> KeywordExtractConfigBuilder { + KeywordExtractConfigBuilder::default() } /// Get current set of stop words. - pub fn get_stop_words(&self) -> &BTreeSet { + pub fn stop_words(&self) -> &BTreeSet { &self.stop_words } /// True if hmm is used during segmentation in `extract_tags`. - pub fn get_use_hmm(&self) -> bool { + pub fn use_hmm(&self) -> bool { self.use_hmm } - /// Sets whether or not to use hmm during segmentation in `extract_tags`. - pub fn set_use_hmm(&mut self, use_hmm: bool) { - self.use_hmm = use_hmm - } - /// Gets the minimum number of Unicode Scalar Values required per keyword. - pub fn get_min_keyword_length(&self) -> usize { + pub fn min_keyword_length(&self) -> usize { self.min_keyword_length } - /// Sets the minimum number of Unicode Scalar Values required per keyword. - /// - /// The default is 2. There is likely not much reason to change this. - pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) { - self.min_keyword_length = min_keyword_length - } - #[inline] pub(crate) fn filter(&self, s: &str) -> bool { - s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase()) + s.chars().count() >= self.min_keyword_length() && !self.stop_words.contains(&s.to_lowercase()) + } +} + +impl KeywordExtractConfigBuilder { + fn default_stop_words(&self) -> Result, KeywordExtractConfigBuilderError> { + Ok(DEFAULT_STOP_WORDS.clone()) + } + + /// Add a new stop word. + pub fn add_stop_word(&mut self, word: String) -> &mut Self { + if self.stop_words.is_none() { + self.stop_words = Some(self.default_stop_words().unwrap()); + } + self.stop_words.as_mut().unwrap().insert(word); + self + } + + /// Remove an existing stop word. + pub fn remove_stop_word(&mut self, word: &str) -> &mut Self { + if self.stop_words.is_none() { + self.stop_words = Some(self.default_stop_words().unwrap()); + } + self.stop_words.as_mut().unwrap().remove(word); + self + } + + /// Replace all stop words with new stop words set. + pub fn set_stop_words(&mut self, stop_words: BTreeSet) -> &mut Self { + self.stop_words = Some(stop_words); + self } } impl Default for KeywordExtractConfig { - fn default() -> Self { - KeywordExtractConfig::new(DEFAULT_STOP_WORDS.clone(), 2, false) + fn default() -> KeywordExtractConfig { + KeywordExtractConfigBuilder::default().build().unwrap() } } diff --git a/src/keywords/textrank.rs b/src/keywords/textrank.rs index 9359f9f..1e2e42b 100644 --- a/src/keywords/textrank.rs +++ b/src/keywords/textrank.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeSet, BinaryHeap}; use ordered_float::OrderedFloat; -use super::{Keyword, KeywordExtract, KeywordExtractConfig}; +use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder}; use crate::FxHashMap as HashMap; use crate::Jieba; @@ -102,7 +102,7 @@ impl TextRank { impl Default for TextRank { /// Creates TextRank with 5 Unicode Scalar Value spans fn default() -> Self { - TextRank::new(5, KeywordExtractConfig::default()) + TextRank::new(5, KeywordExtractConfigBuilder::default().build().unwrap()) } } @@ -142,7 +142,7 @@ impl KeywordExtract for TextRank { /// ); /// ``` fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec) -> Vec { - let tags = jieba.tag(sentence, self.config.get_use_hmm()); + let tags = jieba.tag(sentence, self.config.use_hmm()); let mut allowed_pos_set = BTreeSet::new(); for s in allowed_pos { @@ -156,7 +156,7 @@ impl KeywordExtract for TextRank { continue; } - if word2id.get(t.word).is_none() { + if !word2id.contains_key(t.word) { unique_words.push(String::from(t.word)); word2id.insert(String::from(t.word), unique_words.len() - 1); } diff --git a/src/keywords/tfidf.rs b/src/keywords/tfidf.rs index 3ba4f9f..4097a55 100644 --- a/src/keywords/tfidf.rs +++ b/src/keywords/tfidf.rs @@ -4,7 +4,7 @@ use std::io::{self, BufRead, BufReader}; use ordered_float::OrderedFloat; -use super::{Keyword, KeywordExtract, KeywordExtractConfig}; +use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder}; use crate::FxHashMap as HashMap; use crate::Jieba; @@ -159,7 +159,10 @@ impl Default for TfIdf { /// 2 Unicode Scalar Value minimum for keywords, and no hmm in segmentation. fn default() -> Self { let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes()); - TfIdf::new(Some(&mut default_dict), KeywordExtractConfig::default()) + TfIdf::new( + Some(&mut default_dict), + KeywordExtractConfigBuilder::default().build().unwrap(), + ) } } @@ -209,7 +212,7 @@ impl KeywordExtract for TfIdf { /// ); /// ``` fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec) -> Vec { - let tags = jieba.tag(sentence, self.config.get_use_hmm()); + let tags = jieba.tag(sentence, self.config.use_hmm()); let mut allowed_pos_set = BTreeSet::new(); for s in allowed_pos { diff --git a/src/lib.rs b/src/lib.rs index 224376c..efc1a75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,10 @@ pub use crate::keywords::tfidf::TfIdf; #[cfg(any(feature = "tfidf", feature = "textrank"))] pub use crate::keywords::{Keyword, KeywordExtract, KeywordExtractConfig, DEFAULT_STOP_WORDS}; +#[cfg(any(feature = "tfidf", feature = "textrank"))] +#[macro_use] +extern crate derive_builder; + mod errors; mod hmm; #[cfg(any(feature = "tfidf", feature = "textrank"))]