Skip to content

Commit

Permalink
Add KeywordExtractConfigBuilder (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
awong-dev authored Apr 13, 2024
1 parent 18dfe4e commit 1391b3d
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 54 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ harness = false
required-features = ["tfidf", "textrank"]

[dependencies]
regex = "1.0"
lazy_static = "1.0"
phf = "0.11"
cedarwood = "0.4"
ordered-float = { version = "4.0", optional = true }
derive_builder = "0.20.0"
fxhash = "0.2.1"
lazy_static = "1.0"
ordered-float = { version = "4.0", optional = true }
phf = "0.11"
regex = "1.0"

[build-dependencies]
phf_codegen = "0.11"
Expand Down
120 changes: 77 additions & 43 deletions src/keywords/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,76 +32,110 @@ pub struct Keyword {
pub weight: f64,
}

#[derive(Debug, Clone)]
/// Filter criteria and segmentation settings shared by the keyword
/// extraction implementations.
///
/// A default configuration uses a copy of `DEFAULT_STOP_WORDS`, a minimum
/// keyword length of 2, and does not use hmm. Use
/// `KeywordExtractConfigBuilder` (see `KeywordExtractConfig::builder`) to
/// change any of these defaults.
///
/// # Examples
/// ```
/// use jieba_rs::KeywordExtractConfig;
///
/// let config = KeywordExtractConfig::default();
/// assert!(config.stop_words().contains("the"));
/// assert!(!config.stop_words().contains("FakeWord"));
/// assert!(!config.use_hmm());
/// assert_eq!(2, config.min_keyword_length());
///
/// let built_default = KeywordExtractConfig::builder().build().unwrap();
/// assert_eq!(config, built_default);
///
/// let changed = KeywordExtractConfig::builder()
///     .add_stop_word("FakeWord".to_string())
///     .remove_stop_word("the")
///     .use_hmm(true)
///     .min_keyword_length(10)
///     .build().unwrap();
///
/// assert!(!changed.stop_words().contains("the"));
/// assert!(changed.stop_words().contains("FakeWord"));
/// assert!(changed.use_hmm());
/// assert_eq!(10, changed.min_keyword_length());
/// ```
#[derive(Builder, Debug, Clone, PartialEq)]
pub struct KeywordExtractConfig {
/// Words that are never reported as keywords. Has no generated setter;
/// the builder edits it through its hand-written stop-word methods.
#[builder(default = "self.default_stop_words()?", setter(custom))]
stop_words: BTreeSet<String>,

#[builder(default = "2")]
#[doc = r"Any segments less than this length will not be considered a Keyword"]
min_keyword_length: usize,

#[builder(default = "false")]
#[doc = r"If true, fall back to hmm model if segment cannot be found in the dictionary"]
use_hmm: bool,
}

impl KeywordExtractConfig {
/// Creates a KeywordExtractConfig state that contains filter criteria as
/// well as segmentation configuration for use by keyword extraction
/// implementations.
pub fn new(stop_words: BTreeSet<String>, min_keyword_length: usize, use_hmm: bool) -> Self {
KeywordExtractConfig {
stop_words,
min_keyword_length,
use_hmm,
}
}

/// Add a new stop word.
pub fn add_stop_word(&mut self, word: String) -> bool {
self.stop_words.insert(word)
}

/// Remove an existing stop word.
pub fn remove_stop_word(&mut self, word: &str) -> bool {
self.stop_words.remove(word)
}

/// Replace all stop words with new stop words set.
pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) {
self.stop_words = stop_words
pub fn builder() -> KeywordExtractConfigBuilder {
KeywordExtractConfigBuilder::default()
}

/// Get current set of stop words.
pub fn get_stop_words(&self) -> &BTreeSet<String> {
pub fn stop_words(&self) -> &BTreeSet<String> {
&self.stop_words
}

/// True if hmm is used during segmentation in `extract_tags`.
pub fn get_use_hmm(&self) -> bool {
pub fn use_hmm(&self) -> bool {
self.use_hmm
}

/// Sets whether or not to use hmm during segmentation in `extract_tags`.
pub fn set_use_hmm(&mut self, use_hmm: bool) {
self.use_hmm = use_hmm
}

/// Gets the minimum number of Unicode Scalar Values required per keyword.
pub fn get_min_keyword_length(&self) -> usize {
pub fn min_keyword_length(&self) -> usize {
self.min_keyword_length
}

/// Sets the minimum number of Unicode Scalar Values required per keyword.
///
/// The default is 2. There is likely not much reason to change this.
pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) {
self.min_keyword_length = min_keyword_length
}

#[inline]
pub(crate) fn filter(&self, s: &str) -> bool {
s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase())
s.chars().count() >= self.min_keyword_length() && !self.stop_words.contains(&s.to_lowercase())
}
}

impl KeywordExtractConfigBuilder {
    /// Supplies the stop-word set used when the caller never configures one.
    ///
    /// This is invoked from the `#[builder(default = "...")]` expression on
    /// `KeywordExtractConfig::stop_words`, which applies `?` to the result;
    /// that is why it returns a `Result` even though it always succeeds.
    fn default_stop_words(&self) -> Result<BTreeSet<String>, KeywordExtractConfigBuilderError> {
        Ok(DEFAULT_STOP_WORDS.clone())
    }

    /// Returns the stop-word set under construction, seeding it with a copy
    /// of `DEFAULT_STOP_WORDS` the first time it is touched so that
    /// incremental edits start from the defaults rather than an empty set.
    fn stop_words_or_default(&mut self) -> &mut BTreeSet<String> {
        self.stop_words.get_or_insert_with(|| DEFAULT_STOP_WORDS.clone())
    }

    /// Add a new stop word.
    ///
    /// If no stop-word set has been configured yet, the default set is
    /// loaded first and `word` is added to it.
    ///
    /// Note that `KeywordExtractConfig::filter` lowercases each candidate
    /// before looking it up, so a stop word containing uppercase characters
    /// is stored but will not match any candidate.
    pub fn add_stop_word(&mut self, word: String) -> &mut Self {
        self.stop_words_or_default().insert(word);
        self
    }

    /// Remove an existing stop word.
    ///
    /// If no stop-word set has been configured yet, the default set is
    /// loaded first and `word` is removed from it. Removing a word that is
    /// not present is a no-op.
    pub fn remove_stop_word(&mut self, word: &str) -> &mut Self {
        self.stop_words_or_default().remove(word);
        self
    }

    /// Replace all stop words with new stop words set.
    ///
    /// Any words previously added or removed through this builder are
    /// discarded, and the default set is not merged in.
    pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) -> &mut Self {
        self.stop_words = Some(stop_words);
        self
    }
}

impl Default for KeywordExtractConfig {
fn default() -> Self {
KeywordExtractConfig::new(DEFAULT_STOP_WORDS.clone(), 2, false)
fn default() -> KeywordExtractConfig {
KeywordExtractConfigBuilder::default().build().unwrap()
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/keywords/textrank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::{BTreeSet, BinaryHeap};

use ordered_float::OrderedFloat;

use super::{Keyword, KeywordExtract, KeywordExtractConfig};
use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
use crate::FxHashMap as HashMap;
use crate::Jieba;

Expand Down Expand Up @@ -102,7 +102,7 @@ impl TextRank {
impl Default for TextRank {
/// Creates TextRank with 5 Unicode Scalar Value spans
fn default() -> Self {
TextRank::new(5, KeywordExtractConfig::default())
TextRank::new(5, KeywordExtractConfigBuilder::default().build().unwrap())
}
}

Expand Down Expand Up @@ -142,7 +142,7 @@ impl KeywordExtract for TextRank {
/// );
/// ```
fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
let tags = jieba.tag(sentence, self.config.get_use_hmm());
let tags = jieba.tag(sentence, self.config.use_hmm());
let mut allowed_pos_set = BTreeSet::new();

for s in allowed_pos {
Expand All @@ -156,7 +156,7 @@ impl KeywordExtract for TextRank {
continue;
}

if word2id.get(t.word).is_none() {
if !word2id.contains_key(t.word) {
unique_words.push(String::from(t.word));
word2id.insert(String::from(t.word), unique_words.len() - 1);
}
Expand Down
9 changes: 6 additions & 3 deletions src/keywords/tfidf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::io::{self, BufRead, BufReader};

use ordered_float::OrderedFloat;

use super::{Keyword, KeywordExtract, KeywordExtractConfig};
use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
use crate::FxHashMap as HashMap;
use crate::Jieba;

Expand Down Expand Up @@ -159,7 +159,10 @@ impl Default for TfIdf {
/// 2 Unicode Scalar Value minimum for keywords, and no hmm in segmentation.
fn default() -> Self {
let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes());
TfIdf::new(Some(&mut default_dict), KeywordExtractConfig::default())
TfIdf::new(
Some(&mut default_dict),
KeywordExtractConfigBuilder::default().build().unwrap(),
)
}
}

Expand Down Expand Up @@ -209,7 +212,7 @@ impl KeywordExtract for TfIdf {
/// );
/// ```
fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
let tags = jieba.tag(sentence, self.config.get_use_hmm());
let tags = jieba.tag(sentence, self.config.use_hmm());
let mut allowed_pos_set = BTreeSet::new();

for s in allowed_pos {
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub use crate::keywords::tfidf::TfIdf;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
pub use crate::keywords::{Keyword, KeywordExtract, KeywordExtractConfig, DEFAULT_STOP_WORDS};

#[cfg(any(feature = "tfidf", feature = "textrank"))]
#[macro_use]
extern crate derive_builder;

mod errors;
mod hmm;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
Expand Down

0 comments on commit 1391b3d

Please sign in to comment.