From 0f79d3f9d85af8aa9f97b8bb8308ec6706292853 Mon Sep 17 00:00:00 2001 From: Hossam Hagag <90828745+hh-space-invader@users.noreply.github.com> Date: Tue, 10 Dec 2024 22:12:12 +0200 Subject: [PATCH] feat: Added a toggle to disable stemmer in bm25 (#416) * feat: Added a toggle to disable stemmer in bm25 * refactor: Refactored how to disable stemming in bm25 * refactor: Refactored the way of disabling stemmer in bm25 * new: Added english fallback if language = None * tests: Added test case for disable stemmer * fix: Fix language to be only string * tests: Updated bm25 toggle stemmer tests * refactor: fix stopwords type * fix: fix param propagation in parallel embed in bm25 --------- Co-authored-by: George Panchuk --- fastembed/sparse/bm25.py | 22 ++++++++++++++++++---- tests/test_sparse_embeddings.py | 21 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index 485d476d..9db6c8da 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -92,6 +92,8 @@ class Bm25(SparseTextEmbeddingBase): b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length. Defaults to 0.75. avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0. + language (str): Specifies the language for the stemmer. + disable_stemmer (bool): Disable the stemmer. Raises: ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en. 
""" @@ -105,6 +107,7 @@ def __init__( avg_len: float = 256.0, language: str = "english", token_max_length: int = 40, + disable_stemmer: bool = False, **kwargs, ): super().__init__(model_name, cache_dir, **kwargs) @@ -127,9 +130,15 @@ def __init__( self.token_max_length = token_max_length self.punctuation = set(get_all_punctuation()) - self.stopwords = set(self._load_stopwords(self._model_dir, self.language)) + self.disable_stemmer = disable_stemmer + + if disable_stemmer: + self.stopwords = set() + self.stemmer = None + else: + self.stopwords = set(self._load_stopwords(self._model_dir, self.language)) + self.stemmer = SnowballStemmer(language) - self.stemmer = SnowballStemmer(language) self.tokenizer = SimpleTokenizer @classmethod @@ -182,6 +191,9 @@ def _embed_documents( "k": self.k, "b": self.b, "avg_len": self.avg_len, + "language": self.language, + "token_max_length": self.token_max_length, + "disable_stemmer": self.disable_stemmer, } pool = ParallelWorkerPool( num_workers=parallel or 1, @@ -225,16 +237,18 @@ def embed( def _stem(self, tokens: list[str]) -> list[str]: stemmed_tokens = [] for token in tokens: + lower_token = token.lower() + if token in self.punctuation: continue - if token.lower() in self.stopwords: + if lower_token in self.stopwords: continue if len(token) > self.token_max_length: continue - stemmed_token = self.stemmer.stem_word(token.lower()) + stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token if stemmed_token: stemmed_tokens.append(stemmed_token) diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 8040b46f..236b1de4 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -151,6 +151,27 @@ def test_stem_case_insensitive_stopwords(bm25_instance): assert result == expected, f"Expected {expected}, but got {result}" +@pytest.mark.parametrize("disable_stemmer", [True, False]) +def test_disable_stemmer_behavior(disable_stemmer): + # Setup + model = 
Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer) + model.stopwords = {"the", "is", "a"} + model.punctuation = {".", ",", "!"} + + # Test data + tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"] + + # Execute + result = model._stem(tokens) + + # Assert + if disable_stemmer: + expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only + else: + expected = ["quick", "brown", "fox", "test", "sentenc"] + assert result == expected, f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( "model_name", ["prithivida/Splade_PP_en_v1"],