From 8be10f856789f4869581bcac6b71b0cd13acac6e Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Wed, 4 Dec 2024 10:09:38 +0200
Subject: [PATCH] refactor: Refactored how to disable stemming in bm25

Stemming is now disabled by passing language=None instead of using the
separate disable_stemmer flag. SnowballStemmer is only constructed when a
language is given, so language=None leaves self.stemmer as None and
_stem() falls back to plain lowercasing of the tokens.

---
Reviewer notes (ignored by git am):
- The previous revision of this patch built SnowballStemmer(language)
  unconditionally, so the documented language=None path never left
  self.stemmer unset and passed None to the stemmer constructor. The
  assignment is now guarded with "if language is not None".
- disable_stemmer is no longer a named parameter; callers that still pass
  it will have it forwarded through **kwargs to the base class. Confirm
  that this is intended.
- Line breaks and indentation were reconstructed from a
  whitespace-collapsed copy of this patch. Verify the context lines
  (including the "in the format / e.g." docstring line, which looks like
  it lost a placeholder) against fastembed/sparse/bm25.py before applying.
  The post-image blob id on the index line predates the stemmer guard.

 fastembed/sparse/bm25.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py
index 170c517b..eb9658ec 100644
--- a/fastembed/sparse/bm25.py
+++ b/fastembed/sparse/bm25.py
@@ -92,6 +92,7 @@ class Bm25(SparseTextEmbeddingBase):
         b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
             Defaults to 0.75.
         avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
+        language (str, optional): Specifies the language for the stemmer. Set to None to disable stemming.
     Raises:
         ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en.
     """
@@ -103,14 +104,13 @@ def __init__(
         k: float = 1.2,
         b: float = 0.75,
         avg_len: float = 256.0,
-        language: str = "english",
+        language: Optional[str] = "english",
         token_max_length: int = 40,
-        disable_stemmer: bool = False,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, **kwargs)
 
-        if language not in supported_languages:
+        if language is not None and language not in supported_languages:
             raise ValueError(f"{language} language is not supported")
         else:
             self.language = language
@@ -130,8 +130,7 @@ def __init__(
         self.punctuation = set(get_all_punctuation())
         self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
 
-        self.disable_stemmer = disable_stemmer
-        self.stemmer = SnowballStemmer(language) if not disable_stemmer else None
+        self.stemmer = SnowballStemmer(language) if language is not None else None
         self.tokenizer = SimpleTokenizer
 
     @classmethod
@@ -144,7 +143,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         return supported_bm25_models
 
     @classmethod
-    def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
+    def _load_stopwords(cls, model_dir: Path, language: Optional[str]) -> list[str]:
         stopwords_path = model_dir / f"{language}.txt"
         if not stopwords_path.exists():
             return []
@@ -225,9 +224,6 @@ def embed(
         )
 
     def _stem(self, tokens: list[str]) -> list[str]:
-        if self.disable_stemmer:
-            return tokens
-
         stemmed_tokens = []
         for token in tokens:
             if token in self.punctuation:
@@ -239,7 +235,10 @@ def _stem(self, tokens: list[str]) -> list[str]:
             if len(token) > self.token_max_length:
                 continue
 
-            stemmed_token = self.stemmer.stem_word(token.lower())
+            if self.stemmer:
+                stemmed_token = self.stemmer.stem_word(token.lower())
+            else:
+                stemmed_token = token.lower()
 
             if stemmed_token:
                 stemmed_tokens.append(stemmed_token)