diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py index 7bfcac97..1d7ef36c 100644 --- a/fastembed/sparse/bm25.py +++ b/fastembed/sparse/bm25.py @@ -92,7 +92,7 @@ class Bm25(SparseTextEmbeddingBase): b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length. Defaults to 0.75. avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0. - language (str, optional): Specifies the language for the stemmer. + language (str): Specifies the language for the stemmer. disable_stemmer (bool): Disable the stemmer. Raises: ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en. @@ -105,16 +105,14 @@ def __init__( k: float = 1.2, b: float = 0.75, avg_len: float = 256.0, - language: Optional[str] = "english", + language: str = "english", token_max_length: int = 40, disable_stemmer: bool = False, **kwargs, ): super().__init__(model_name, cache_dir, **kwargs) - if language is None: - language = "english" - elif language not in supported_languages: + if language not in supported_languages: raise ValueError(f"{language} language is not supported") else: self.language = language @@ -153,7 +151,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]: return supported_bm25_models @classmethod - def _load_stopwords(cls, model_dir: Path, language: Optional[str]) -> list[str]: + def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]: stopwords_path = model_dir / f"{language}.txt" if not stopwords_path.exists(): return [] @@ -247,10 +245,7 @@ def _stem(self, tokens: list[str]) -> list[str]: if len(token) > self.token_max_length: continue - if self.stemmer: - stemmed_token = self.stemmer.stem_word(lower_token) - else: - stemmed_token = lower_token + stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token if stemmed_token: stemmed_tokens.append(stemmed_token)