Skip to content

Commit

Permalink
fix: Fix language to be only string
Browse files (browse the repository at this point in the history)
  • Loading branch information
hh-space-invader committed Dec 10, 2024
1 parent ea7278a commit 8bdbca4
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions fastembed/sparse/bm25.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -92,7 +92,7 @@ class Bm25(SparseTextEmbeddingBase):
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
Defaults to 0.75.
avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
- language (str, optional): Specifies the language for the stemmer.
+ language (str): Specifies the language for the stemmer.
disable_stemmer (bool): Disable the stemmer.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
Expand All @@ -105,16 +105,14 @@ def __init__(
k: float = 1.2,
b: float = 0.75,
avg_len: float = 256.0,
- language: Optional[str] = "english",
+ language: str = "english",
token_max_length: int = 40,
disable_stemmer: bool = False,
**kwargs,
):
super().__init__(model_name, cache_dir, **kwargs)

- if language is None:
- language = "english"
- elif language not in supported_languages:
+ if language not in supported_languages:
raise ValueError(f"{language} language is not supported")
else:
self.language = language
Expand Down Expand Up @@ -153,7 +151,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_bm25_models

@classmethod
- def _load_stopwords(cls, model_dir: Path, language: Optional[str]) -> list[str]:
+ def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
stopwords_path = model_dir / f"{language}.txt"
if not stopwords_path.exists():
return []
Expand Down Expand Up @@ -247,10 +245,7 @@ def _stem(self, tokens: list[str]) -> list[str]:
if len(token) > self.token_max_length:
continue

- if self.stemmer:
- stemmed_token = self.stemmer.stem_word(lower_token)
- else:
- stemmed_token = lower_token
+ stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token

if stemmed_token:
stemmed_tokens.append(stemmed_token)
Expand Down

0 comments on commit 8bdbca4

Please sign in to comment.