Skip to content

Commit

Permalink
DRY-up API/impl by extracting a KeywordExtractConfig struct
Browse files Browse the repository at this point in the history
  • Loading branch information
awong-dev committed Apr 8, 2024
1 parent d6f8fb7 commit 350fa11
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 161 deletions.
73 changes: 73 additions & 0 deletions src/keywords/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,79 @@ pub struct Keyword {
pub weight: f64,
}

/// Filter criteria and segmentation settings shared by keyword extraction
/// implementations.
///
/// Stop words are matched case-insensitively: [`KeywordExtractConfig::filter`]
/// lowercases every candidate before the lookup, so every stop word is
/// lowercased when it is stored. Without that normalization a stop word
/// containing an uppercase letter could never match anything.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeywordExtractConfig {
    /// Lowercased words that must never be reported as keywords.
    stop_words: BTreeSet<String>,
    /// Minimum number of Unicode Scalar Values a keyword must contain.
    min_keyword_length: usize,
    /// Whether hmm is used during segmentation.
    use_hmm: bool,
}

impl KeywordExtractConfig {
    /// Creates a KeywordExtractConfig state that contains filter criteria as
    /// well as segmentation configuration for use by keyword extraction
    /// implementations.
    ///
    /// Every entry of `stop_words` is lowercased before being stored.
    pub fn new(stop_words: BTreeSet<String>, min_keyword_length: usize, use_hmm: bool) -> Self {
        KeywordExtractConfig {
            stop_words: Self::normalize_stop_words(stop_words),
            min_keyword_length,
            use_hmm,
        }
    }

    /// Lowercases every word so the set agrees with the lookup done in `filter`.
    fn normalize_stop_words(stop_words: BTreeSet<String>) -> BTreeSet<String> {
        stop_words.into_iter().map(|word| word.to_lowercase()).collect()
    }

    /// Add a new stop word.
    ///
    /// The word is lowercased before insertion. Returns `true` if the
    /// lowercased word was not already present.
    pub fn add_stop_word(&mut self, word: String) -> bool {
        self.stop_words.insert(word.to_lowercase())
    }

    /// Remove an existing stop word.
    ///
    /// The lookup is case-insensitive. Returns `true` if the word was present.
    pub fn remove_stop_word(&mut self, word: &str) -> bool {
        self.stop_words.remove(word.to_lowercase().as_str())
    }

    /// Replace all stop words with new stop words set.
    ///
    /// Every entry of `stop_words` is lowercased before being stored.
    pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) {
        self.stop_words = Self::normalize_stop_words(stop_words);
    }

    /// Get current set of stop words (stored lowercased).
    pub fn get_stop_words(&self) -> &BTreeSet<String> {
        &self.stop_words
    }

    /// True if hmm is used during segmentation in `extract_tags`.
    pub fn get_use_hmm(&self) -> bool {
        self.use_hmm
    }

    /// Sets whether or not to use hmm during segmentation in `extract_tags`.
    pub fn set_use_hmm(&mut self, use_hmm: bool) {
        self.use_hmm = use_hmm;
    }

    /// Gets the minimum number of Unicode Scalar Values required per keyword.
    pub fn get_min_keyword_length(&self) -> usize {
        self.min_keyword_length
    }

    /// Sets the minimum number of Unicode Scalar Values required per keyword.
    ///
    /// The default is 2. There is likely not much reason to change this.
    pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) {
        self.min_keyword_length = min_keyword_length;
    }

    /// Returns `true` if `s` is acceptable as a keyword candidate.
    ///
    /// A candidate is accepted when it has at least `min_keyword_length`
    /// Unicode Scalar Values and its lowercased form is not a stop word.
    #[inline]
    pub fn filter(&self, s: &str) -> bool {
        // The length check runs first so short candidates skip the
        // allocation made by `to_lowercase`.
        s.chars().count() >= self.min_keyword_length
            && !self.stop_words.contains(s.to_lowercase().as_str())
    }
}

impl Default for KeywordExtractConfig {
    /// Builds the stock configuration: a copy of `DEFAULT_STOP_WORDS`, a
    /// minimum keyword length of 2, and hmm disabled.
    fn default() -> Self {
        let stop_words = DEFAULT_STOP_WORDS.clone();
        let min_keyword_length = 2;
        let use_hmm = false;
        Self::new(stop_words, min_keyword_length, use_hmm)
    }
}

/// Interface for extracting weighted keywords from a piece of text.
pub trait KeywordExtract {
/// Extracts keywords from `sentence` and returns them as `Keyword` values.
///
/// NOTE(review): `top_k` presumably caps how many keywords are returned, and
/// `allowed_pos` presumably restricts candidates to the listed part-of-speech
/// tags — confirm both against the implementations.
fn extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword>;
}
Expand Down
83 changes: 13 additions & 70 deletions src/keywords/textrank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::{BTreeSet, BinaryHeap};

use ordered_float::OrderedFloat;

use super::{JiebaKeywordExtract, Keyword, KeywordExtract, DEFAULT_STOP_WORDS};
use super::{JiebaKeywordExtract, Keyword, KeywordExtract, KeywordExtractConfig};
use crate::FxHashMap as HashMap;
use crate::Jieba;

Expand Down Expand Up @@ -74,9 +74,7 @@ impl StateDiagram {
#[derive(Debug)]
pub struct UnboundTextRank {
span: usize,
stop_words: BTreeSet<String>,
min_keyword_length: usize,
use_hmm: bool,
config: KeywordExtractConfig,
}

impl UnboundTextRank {
Expand All @@ -93,78 +91,23 @@ impl UnboundTextRank {
/// BTreeSet::from(["a", "the", "of"].map(|s| s.to_string()));
/// jieba_rs::UnboundTextRank::new(
/// 5,
/// stop_words,
/// 2,
/// false);
/// KeywordExtractConfig::default());
/// ```
pub fn new(span: usize, stop_words: BTreeSet<String>, min_keyword_length: usize, use_hmm: bool) -> Self {
UnboundTextRank {
stop_words,
span,
min_keyword_length,
use_hmm,
}
}

/// Add a new stop word.
pub fn add_stop_word(&mut self, word: String) -> bool {
self.stop_words.insert(word)
}

/// Remove an existing stop word.
pub fn remove_stop_word(&mut self, word: &str) -> bool {
self.stop_words.remove(word)
}

/// Replace all stop words with new stop words set.
pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) {
self.stop_words = stop_words
}

/// Get current set of stop words.
pub fn get_stop_words(&self) -> &BTreeSet<String> {
&self.stop_words
}

/// True if hmm is used during segmentation in `extract_tags`.
pub fn get_use_hmm(&self) -> bool {
self.use_hmm
}

/// Sets whether or not to use hmm during segmentation in `extract_tags`.
pub fn set_use_hmm(&mut self, use_hmm: bool) {
self.use_hmm = use_hmm
}

/// Gets the minimum number of Unicode Scalar Values required per keyword.
pub fn get_min_keyword_length(&self) -> usize {
self.min_keyword_length
}

/// Sets the minimum number of Unicode Scalar Values required per keyword.
///
/// The default is 2. There is likely not much reason to change this.
pub fn set_min_keyword_length(&mut self, min_keyword_length: usize) {
self.min_keyword_length = min_keyword_length
}

#[inline]
fn filter(&self, s: &str) -> bool {
s.chars().count() >= self.min_keyword_length && !self.stop_words.contains(&s.to_lowercase())
pub fn new(span: usize, config: KeywordExtractConfig) -> Self {
UnboundTextRank { span, config }
}
}

impl Default for UnboundTextRank {
/// Creates UnboundTextRank with 5 Unicode Scalar Value spans,
/// DEFAULT_STOP_WORDS, and no hmm in segmentation.
/// Creates UnboundTextRank with 5 Unicode Scalar Value spans
fn default() -> Self {
UnboundTextRank::new(5, DEFAULT_STOP_WORDS.clone(), 2, false)
UnboundTextRank::new(5, KeywordExtractConfig::default())
}
}

impl JiebaKeywordExtract for UnboundTextRank {
fn extract_tags(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
let tags = jieba.tag(sentence, self.use_hmm);
let tags = jieba.tag(sentence, self.config.get_use_hmm());
let mut allowed_pos_set = BTreeSet::new();

for s in allowed_pos {
Expand All @@ -190,7 +133,7 @@ impl JiebaKeywordExtract for UnboundTextRank {
continue;
}

if !self.filter(t.word) {
if !self.config.filter(t.word) {
continue;
}

Expand All @@ -203,7 +146,7 @@ impl JiebaKeywordExtract for UnboundTextRank {
continue;
}

if !self.filter(tags[j].word) {
if !self.config.filter(tags[j].word) {
continue;
}

Expand Down Expand Up @@ -267,17 +210,17 @@ impl<'a> TextRank<'a> {

/// Add a new stop word
pub fn add_stop_word(&mut self, word: String) -> bool {
self.unbound_text_rank.add_stop_word(word)
self.unbound_text_rank.config.add_stop_word(word)
}

/// Remove an existing stop word
pub fn remove_stop_word(&mut self, word: &str) -> bool {
self.unbound_text_rank.remove_stop_word(word)
self.unbound_text_rank.config.remove_stop_word(word)
}

/// Replace all stop words with new stop words set
pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) {
self.unbound_text_rank.set_stop_words(stop_words)
self.unbound_text_rank.config.set_stop_words(stop_words)
}
}

Expand Down
Loading

0 comments on commit 350fa11

Please sign in to comment.