Skip to content

Commit

Permalink
refactor: Change how stemming is disabled in BM25
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Dec 4, 2024
1 parent 530ea5a commit 8be10f8
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions fastembed/sparse/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Bm25(SparseTextEmbeddingBase):
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
Defaults to 0.75.
avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
language (str, optional): Language used for the stemmer and the stopword list. Set to None to disable stemming. Defaults to "english".
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""
Expand All @@ -103,14 +104,13 @@ def __init__(
k: float = 1.2,
b: float = 0.75,
avg_len: float = 256.0,
language: str = "english",
language: Optional[str] = "english",
token_max_length: int = 40,
disable_stemmer: bool = False,
**kwargs,
):
super().__init__(model_name, cache_dir, **kwargs)

if language not in supported_languages:
if language is not None and language not in supported_languages:
raise ValueError(f"{language} language is not supported")
else:
self.language = language
Expand All @@ -130,8 +130,7 @@ def __init__(
self.punctuation = set(get_all_punctuation())
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))

self.disable_stemmer = disable_stemmer
self.stemmer = SnowballStemmer(language) if not disable_stemmer else None
self.stemmer = SnowballStemmer(language)
self.tokenizer = SimpleTokenizer

@classmethod
Expand All @@ -144,7 +143,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_bm25_models

@classmethod
def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
def _load_stopwords(cls, model_dir: Path, language: Optional[str]) -> list[str]:
stopwords_path = model_dir / f"{language}.txt"
if not stopwords_path.exists():
return []
Expand Down Expand Up @@ -225,9 +224,6 @@ def embed(
)

def _stem(self, tokens: list[str]) -> list[str]:
if self.disable_stemmer:
return tokens

stemmed_tokens = []
for token in tokens:
if token in self.punctuation:
Expand All @@ -239,7 +235,10 @@ def _stem(self, tokens: list[str]) -> list[str]:
if len(token) > self.token_max_length:
continue

stemmed_token = self.stemmer.stem_word(token.lower())
if self.stemmer:
stemmed_token = self.stemmer.stem_word(token.lower())
else:
stemmed_token = token.lower()

if stemmed_token:
stemmed_tokens.append(stemmed_token)
Expand Down

0 comments on commit 8be10f8

Please sign in to comment.