Skip to content

Commit

Permalink
tests: Added test case for disable stemmer
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Dec 6, 2024
1 parent d770aad commit ea7278a
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 124 deletions.
1 change: 1 addition & 0 deletions fastembed/sparse/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(

self.token_max_length = token_max_length
self.punctuation = set(get_all_punctuation())
self.disable_stemmer = disable_stemmer

if disable_stemmer:
self.stopwords = []
Expand Down
255 changes: 131 additions & 124 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,124 @@
import os

import pytest
import numpy as np

from fastembed.sparse.bm25 import Bm25
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
from tests.utils import delete_model_cache

# Ground-truth sparse embedding (token indices and matching weights) that each
# listed model is expected to produce for the document "Hello World" (see
# `docs` below); the tests compare live model output against these values.
# NOTE(review): the key spelling "prithvida" differs from the
# "prithivida/Splade_PP_en_v1" name used elsewhere in this file — presumably an
# accepted alias; confirm against the model registry.
CANONICAL_COLUMN_VALUES = {
    "prithvida/Splade_PP_en_v1": {
        "indices": [
            2040,
            2047,
            2088,
            2299,
            2748,
            3011,
            3376,
            3795,
            4774,
            5304,
            5798,
            6160,
            7592,
            7632,
            8484,
        ],
        "values": [
            0.4219532012939453,
            0.4320072531700134,
            2.766580104827881,
            0.3314574658870697,
            1.395172119140625,
            0.021595917642116547,
            0.43770670890808105,
            0.0008370947907678783,
            0.5187209844589233,
            0.17124654352664948,
            0.14742016792297363,
            0.8142819404602051,
            2.803262710571289,
            2.1904349327087402,
            1.0531445741653442,
        ],
    }
}

# Single canonical document the expected values above were computed from.
docs = ["Hello World"]


def test_batch_embedding():
    """Batched embedding of repeated documents must reproduce the canonical
    indices exactly and the canonical values to within 1e-3."""
    is_ci = os.getenv("CI")
    batch = docs * 10

    for model_name, expected in CANONICAL_COLUMN_VALUES.items():
        model = SparseTextEmbedding(model_name=model_name)
        embedding = next(iter(model.embed(batch, batch_size=6)))

        assert embedding.indices.tolist() == expected["indices"]
        for i, got in enumerate(embedding.values):
            assert pytest.approx(got, abs=0.001) == expected["values"][i]

        if is_ci:
            # Keep CI runners from accumulating downloaded model files.
            delete_model_cache(model.model._model_dir)


def test_single_embedding():
    """Passage-style and query-style embeddings of a single document must both
    match the canonical output for every known model."""
    is_ci = os.getenv("CI")

    for model_name, expected in CANONICAL_COLUMN_VALUES.items():
        model = SparseTextEmbedding(model_name=model_name)
        outputs = [
            next(iter(model.embed(docs, batch_size=6))),
            next(iter(model.query_embed(docs))),
        ]

        for embedding in outputs:
            assert embedding.indices.tolist() == expected["indices"]
            for i, got in enumerate(embedding.values):
                assert pytest.approx(got, abs=0.001) == expected["values"][i]

        if is_ci:
            # Keep CI runners from accumulating downloaded model files.
            delete_model_cache(model.model._model_dir)


def test_parallel_processing():
    """Embedding results must be identical (indices exactly, values to 1e-3)
    regardless of the `parallel` setting (2, 0, or None)."""
    is_ci = os.getenv("CI")
    model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
    docs = ["hello world", "flag embedding"] * 30

    duo = list(model.embed(docs, batch_size=10, parallel=2))
    all_workers = list(model.embed(docs, batch_size=10, parallel=0))
    sequential = list(model.embed(docs, batch_size=10, parallel=None))

    assert len(sequential) == len(duo) == len(all_workers) == len(docs)

    for seq_emb, duo_emb, all_emb in zip(sequential, duo, all_workers):
        assert (
            seq_emb.indices.tolist()
            == duo_emb.indices.tolist()
            == all_emb.indices.tolist()
        )
        assert np.allclose(seq_emb.values, duo_emb.values, atol=1e-3)
        assert np.allclose(seq_emb.values, all_emb.values, atol=1e-3)

    if is_ci:
        # Keep CI runners from accumulating downloaded model files.
        delete_model_cache(model.model._model_dir)
# CANONICAL_COLUMN_VALUES = {
# "prithvida/Splade_PP_en_v1": {
# "indices": [
# 2040,
# 2047,
# 2088,
# 2299,
# 2748,
# 3011,
# 3376,
# 3795,
# 4774,
# 5304,
# 5798,
# 6160,
# 7592,
# 7632,
# 8484,
# ],
# "values": [
# 0.4219532012939453,
# 0.4320072531700134,
# 2.766580104827881,
# 0.3314574658870697,
# 1.395172119140625,
# 0.021595917642116547,
# 0.43770670890808105,
# 0.0008370947907678783,
# 0.5187209844589233,
# 0.17124654352664948,
# 0.14742016792297363,
# 0.8142819404602051,
# 2.803262710571289,
# 2.1904349327087402,
# 1.0531445741653442,
# ],
# }
# }

# docs = ["Hello World"]


# def test_batch_embedding():
# is_ci = os.getenv("CI")
# docs_to_embed = docs * 10

# for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
# model = SparseTextEmbedding(model_name=model_name)
# result = next(iter(model.embed(docs_to_embed, batch_size=6)))
# assert result.indices.tolist() == expected_result["indices"]

# for i, value in enumerate(result.values):
# assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
# if is_ci:
# delete_model_cache(model.model._model_dir)


# def test_single_embedding():
# is_ci = os.getenv("CI")
# for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
# model = SparseTextEmbedding(model_name=model_name)

# passage_result = next(iter(model.embed(docs, batch_size=6)))
# query_result = next(iter(model.query_embed(docs)))
# for result in [passage_result, query_result]:
# assert result.indices.tolist() == expected_result["indices"]

# for i, value in enumerate(result.values):
# assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
# if is_ci:
# delete_model_cache(model.model._model_dir)


# def test_parallel_processing():
# is_ci = os.getenv("CI")
# model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
# docs = ["hello world", "flag embedding"] * 30
# sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
# sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
# sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))

# assert (
# len(sparse_embeddings)
# == len(sparse_embeddings_duo)
# == len(sparse_embeddings_all)
# == len(docs)
# )

# for sparse_embedding, sparse_embedding_duo, sparse_embedding_all in zip(
# sparse_embeddings, sparse_embeddings_duo, sparse_embeddings_all
# ):
# assert (
# sparse_embedding.indices.tolist()
# == sparse_embedding_duo.indices.tolist()
# == sparse_embedding_all.indices.tolist()
# )
# assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
# assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)

# if is_ci:
# delete_model_cache(model.model._model_dir)


@pytest.fixture
def bm25_instance(request):
    """Yield a Bm25 model for the stemming tests.

    When parametrized indirectly (``indirect=True``), ``request.param`` is a
    bool forwarded as ``disable_stemmer``; without parametrization the stemmer
    stays enabled (default False).

    The diff artifact left two ``def`` lines and two ``model = Bm25(...)``
    lines here; this is the reconstructed post-commit version, matching the
    ``@pytest.mark.parametrize(..., indirect=True)`` usage below.
    """
    # NOTE(review): the default True makes `ci` truthy even off CI, so the
    # cache is always deleted on teardown — confirm this is intentional.
    ci = os.getenv("CI", True)
    disable_stemmer = getattr(request, "param", False)
    model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
    yield model
    if ci:
        delete_model_cache(model._model_dir)


@pytest.mark.parametrize("bm25_instance", [True, False], indirect=True)
def test_stem_with_stopwords_and_punctuation(bm25_instance):
# Setup
bm25_instance.stopwords = {"the", "is", "a"}
Expand All @@ -131,10 +131,14 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
result = bm25_instance._stem(tokens)

# Assert
expected = ["quick", "brown", "fox", "test", "sentenc"]
if bm25_instance.disable_stemmer:
expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only
else:
expected = ["quick", "brown", "fox", "test", "sentenc"]
assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize("bm25_instance", [True, False], indirect=True)
def test_stem_case_insensitive_stopwords(bm25_instance):
# Setup
bm25_instance.stopwords = {"the", "is", "a"}
Expand All @@ -147,28 +151,31 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
result = bm25_instance._stem(tokens)

# Assert
expected = ["quick", "brown", "fox", "test", "sentenc"]
if bm25_instance.disable_stemmer:
expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only
else:
expected = ["quick", "brown", "fox", "test", "sentenc"]
assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize(
    "model_name",
    ["prithivida/Splade_PP_en_v1"],
)
def test_lazy_load(model_name):
    """With ``lazy_load=True`` the inner model must not exist until the first
    embed call; embed/query_embed/passage_embed must all work afterwards.

    Reconstructed from the diff residue, which interleaved this function with
    a commented-out duplicate of itself line by line.
    """
    is_ci = os.getenv("CI")
    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    # Before any embedding call the inner `model` attribute is absent.
    assert not hasattr(model.model, "model")

    docs = ["hello world", "flag embedding"]
    list(model.embed(docs))
    # First embed call materializes the inner model.
    assert hasattr(model.model, "model")

    # Exercise query_embed and passage_embed on fresh lazy instances as well.
    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    list(model.query_embed(docs))

    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    list(model.passage_embed(docs))

    if is_ci:
        delete_model_cache(model.model._model_dir)

0 comments on commit ea7278a

Please sign in to comment.