refactor: Refactored the way of disabling stemmer in bm25
hh-space-invader committed Dec 6, 2024
1 parent be1da05 commit d521eec
Showing 1 changed file with 15 additions and 6 deletions: fastembed/sparse/bm25.py
@@ -92,7 +92,8 @@ class Bm25(SparseTextEmbeddingBase):
         b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
             Defaults to 0.75.
         avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
-        language (str, optional): Specifies the language for the stemmer. Set to None to disable stemming.
+        language (str, optional): Specifies the language for the stemmer.
+        disable_stemmer (bool): Disable the stemmer.
     Raises:
         ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
     """
@@ -106,6 +107,7 @@ def __init__(
         avg_len: float = 256.0,
         language: Optional[str] = "english",
         token_max_length: int = 40,
+        disable_stemmer: bool = False,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, **kwargs)
@@ -128,9 +130,14 @@ def __init__(
 
         self.token_max_length = token_max_length
         self.punctuation = set(get_all_punctuation())
-        self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
-
-        self.stemmer = SnowballStemmer(language)
+        if disable_stemmer:
+            self.stopwords = []
+            self.stemmer = None
+        else:
+            self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
+            self.stemmer = SnowballStemmer(language)
 
         self.tokenizer = SimpleTokenizer
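As a usage note (not part of the commit): a minimal sketch of how the new flag would be passed, based only on the parameters visible in this diff. The import path follows the file changed here; the model name "Qdrant/bm25" is an assumption, not taken from this change.

from fastembed.sparse.bm25 import Bm25

# Default behaviour: Snowball stemmer plus stopword filtering for the given language.
bm25 = Bm25("Qdrant/bm25", language="english")

# With the flag added in this commit (assumed call): skip stemming and stopword
# removal, keeping only lowercasing, punctuation filtering, and the length cap.
bm25_plain = Bm25("Qdrant/bm25", disable_stemmer=True)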
@classmethod
Expand Down Expand Up @@ -226,19 +233,21 @@ def embed(
def _stem(self, tokens: list[str]) -> list[str]:
stemmed_tokens = []
for token in tokens:
lower_token = token.lower()

if token in self.punctuation:
continue

if token.lower() in self.stopwords:
if lower_token in self.stopwords:
continue

if len(token) > self.token_max_length:
continue

if self.stemmer:
stemmed_token = self.stemmer.stem_word(token.lower())
stemmed_token = self.stemmer.stem_word(lower_token)
else:
stemmed_token = token.lower()
stemmed_token = lower_token

if stemmed_token:
stemmed_tokens.append(stemmed_token)
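For clarity, a standalone sketch of the token pipeline in _stem after this change. The helper name, parameters, and sample tokens are illustrative, not from the repository; stem stands in for SnowballStemmer.stem_word, and an empty stopword set mimics disable_stemmer=True.

def filter_and_stem(tokens, punctuation, stopwords, token_max_length=40, stem=None):
    out = []
    for token in tokens:
        lower_token = token.lower()        # lowercase once per token
        if token in punctuation:           # drop punctuation tokens
            continue
        if lower_token in stopwords:       # drop stopwords (empty when stemming is disabled)
            continue
        if len(token) > token_max_length:  # drop overlong tokens
            continue
        out.append(stem(lower_token) if stem else lower_token)
    return out

# With the stemmer disabled, only lowercasing and the two remaining filters apply:
print(filter_and_stem(["Running", ",", "the", "tests"], {","}, set()))
# -> ['running', 'the', 'tests']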
