Skip to content

Commit

Permalink
feat: Added a toggle to disable stemmer in bm25 (#416)
Browse files Browse the repository at this point in the history
* feat: Added a toggle to disable stemmer in bm25

* refactor: Refactored how to disable stemming in bm25

* refactor: Refactored the way of disabling stemmer in bm25

* new: Added english fallback if language = None

* tests: Added test case for disable stemmer

* fix: Fix language to be only string

* tests: Updated bm25 toggle stemmer tests

* refactor: fix stopwords type

* fix: fix param propagation in parallel embed in bm25

---------

Co-authored-by: George Panchuk <[email protected]>
  • Loading branch information
hh-space-invader and joein authored Dec 10, 2024
1 parent 2ef9c38 commit 0f79d3f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
22 changes: 18 additions & 4 deletions fastembed/sparse/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class Bm25(SparseTextEmbeddingBase):
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
Defaults to 0.75.
avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
language (str, optional): Specifies the language for the stemmer. Defaults to "english".
disable_stemmer (bool, optional): If True, disables stemming and skips stopword filtering. Defaults to False.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""
Expand All @@ -105,6 +107,7 @@ def __init__(
avg_len: float = 256.0,
language: str = "english",
token_max_length: int = 40,
disable_stemmer: bool = False,
**kwargs,
):
super().__init__(model_name, cache_dir, **kwargs)
Expand All @@ -127,9 +130,15 @@ def __init__(

self.token_max_length = token_max_length
self.punctuation = set(get_all_punctuation())
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
self.disable_stemmer = disable_stemmer

if disable_stemmer:
self.stopwords = set()
self.stemmer = None
else:
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
self.stemmer = SnowballStemmer(language)

self.stemmer = SnowballStemmer(language)
self.tokenizer = SimpleTokenizer

@classmethod
Expand Down Expand Up @@ -182,6 +191,9 @@ def _embed_documents(
"k": self.k,
"b": self.b,
"avg_len": self.avg_len,
"language": self.language,
"token_max_length": self.token_max_length,
"disable_stemmer": self.disable_stemmer,
}
pool = ParallelWorkerPool(
num_workers=parallel or 1,
Expand Down Expand Up @@ -225,16 +237,18 @@ def embed(
def _stem(self, tokens: list[str]) -> list[str]:
stemmed_tokens = []
for token in tokens:
lower_token = token.lower()

if token in self.punctuation:
continue

if token.lower() in self.stopwords:
if lower_token in self.stopwords:
continue

if len(token) > self.token_max_length:
continue

stemmed_token = self.stemmer.stem_word(token.lower())
stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token

if stemmed_token:
stemmed_tokens.append(stemmed_token)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,27 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize("disable_stemmer", [True, False])
def test_disable_stemmer_behavior(disable_stemmer):
    """_stem should only lowercase tokens when stemming is disabled, and stem them otherwise."""
    model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
    # Override the filters so the expectations don't depend on the packaged stopword list.
    model.stopwords = {"the", "is", "a"}
    model.punctuation = {".", ",", "!"}

    tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]

    result = model._stem(tokens)

    # With the stemmer off, surviving tokens are lowercased as-is; with it on,
    # Snowball truncates "sentence" -> "sentenc".
    expected = (
        ["quick", "brown", "fox", "test", "sentence"]
        if disable_stemmer
        else ["quick", "brown", "fox", "test", "sentenc"]
    )
    assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize(
"model_name",
["prithivida/Splade_PP_en_v1"],
Expand Down

0 comments on commit 0f79d3f

Please sign in to comment.