diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index eb9658ec..7d0b393f 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -92,7 +92,8 @@ class Bm25(SparseTextEmbeddingBase): b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length. Defaults to 0.75. avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0. - language (str, optional): Specifies the language for the stemmer. Set to None to disable stemming. + language (str, optional): Specifies the language for the stemmer. + disable_stemmer (bool): Disable the stemmer. Raises: ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. """ @@ -106,6 +107,7 @@ def __init__( avg_len: float = 256.0, language: Optional[str] = "english", token_max_length: int = 40, + disable_stemmer: bool = False, **kwargs, ): super().__init__(model_name, cache_dir, **kwargs) @@ -128,9 +130,14 @@ def __init__( self.token_max_length = token_max_length self.punctuation = set(get_all_punctuation()) - self.stopwords = set(self._load_stopwords(self._model_dir, self.language)) - self.stemmer = SnowballStemmer(language) + if disable_stemmer: + self.stopwords = [] + self.stemmer = None + else: + self.stopwords = set(self._load_stopwords(self._model_dir, self.language)) + self.stemmer = SnowballStemmer(language) + self.tokenizer = SimpleTokenizer @classmethod @@ -226,19 +233,21 @@ def embed( def _stem(self, tokens: list[str]) -> list[str]: stemmed_tokens = [] for token in tokens: + lower_token = token.lower() + if token in self.punctuation: continue - if token.lower() in self.stopwords: + if lower_token in self.stopwords: continue if len(token) > self.token_max_length: continue if self.stemmer: - stemmed_token = self.stemmer.stem_word(token.lower()) + stemmed_token = self.stemmer.stem_word(lower_token) else: - stemmed_token = token.lower() + stemmed_token = lower_token if stemmed_token: stemmed_tokens.append(stemmed_token)