
Commit

tests: Updated bm25 toggle stemmer tests
hh-space-invader committed Dec 10, 2024
1 parent 8bdbca4 commit 25ff9e3
Showing 1 changed file with 142 additions and 128 deletions.
270 changes: 142 additions & 128 deletions tests/test_sparse_embeddings.py
@@ -1,124 +1,124 @@
import os

import pytest
import numpy as np

from fastembed.sparse.bm25 import Bm25
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
from tests.utils import delete_model_cache

# CANONICAL_COLUMN_VALUES = {
#     "prithvida/Splade_PP_en_v1": {
#         "indices": [
#             2040,
#             2047,
#             2088,
#             2299,
#             2748,
#             3011,
#             3376,
#             3795,
#             4774,
#             5304,
#             5798,
#             6160,
#             7592,
#             7632,
#             8484,
#         ],
#         "values": [
#             0.4219532012939453,
#             0.4320072531700134,
#             2.766580104827881,
#             0.3314574658870697,
#             1.395172119140625,
#             0.021595917642116547,
#             0.43770670890808105,
#             0.0008370947907678783,
#             0.5187209844589233,
#             0.17124654352664948,
#             0.14742016792297363,
#             0.8142819404602051,
#             2.803262710571289,
#             2.1904349327087402,
#             1.0531445741653442,
#         ],
#     }
# }

# docs = ["Hello World"]


# def test_batch_embedding():
#     is_ci = os.getenv("CI")
#     docs_to_embed = docs * 10

#     for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
#         model = SparseTextEmbedding(model_name=model_name)
#         result = next(iter(model.embed(docs_to_embed, batch_size=6)))
#         assert result.indices.tolist() == expected_result["indices"]

#         for i, value in enumerate(result.values):
#             assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
#         if is_ci:
#             delete_model_cache(model.model._model_dir)


# def test_single_embedding():
#     is_ci = os.getenv("CI")
#     for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
#         model = SparseTextEmbedding(model_name=model_name)

#         passage_result = next(iter(model.embed(docs, batch_size=6)))
#         query_result = next(iter(model.query_embed(docs)))
#         for result in [passage_result, query_result]:
#             assert result.indices.tolist() == expected_result["indices"]

#             for i, value in enumerate(result.values):
#                 assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
#         if is_ci:
#             delete_model_cache(model.model._model_dir)


# def test_parallel_processing():
#     is_ci = os.getenv("CI")
#     model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
#     docs = ["hello world", "flag embedding"] * 30
#     sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
#     sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
#     sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))

#     assert (
#         len(sparse_embeddings)
#         == len(sparse_embeddings_duo)
#         == len(sparse_embeddings_all)
#         == len(docs)
#     )

#     for sparse_embedding, sparse_embedding_duo, sparse_embedding_all in zip(
#         sparse_embeddings, sparse_embeddings_duo, sparse_embeddings_all
#     ):
#         assert (
#             sparse_embedding.indices.tolist()
#             == sparse_embedding_duo.indices.tolist()
#             == sparse_embedding_all.indices.tolist()
#         )
#         assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
#         assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)

#     if is_ci:
#         delete_model_cache(model.model._model_dir)
CANONICAL_COLUMN_VALUES = {
    "prithvida/Splade_PP_en_v1": {
        "indices": [
            2040,
            2047,
            2088,
            2299,
            2748,
            3011,
            3376,
            3795,
            4774,
            5304,
            5798,
            6160,
            7592,
            7632,
            8484,
        ],
        "values": [
            0.4219532012939453,
            0.4320072531700134,
            2.766580104827881,
            0.3314574658870697,
            1.395172119140625,
            0.021595917642116547,
            0.43770670890808105,
            0.0008370947907678783,
            0.5187209844589233,
            0.17124654352664948,
            0.14742016792297363,
            0.8142819404602051,
            2.803262710571289,
            2.1904349327087402,
            1.0531445741653442,
        ],
    }
}

docs = ["Hello World"]


def test_batch_embedding():
    is_ci = os.getenv("CI")
    docs_to_embed = docs * 10

    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
        model = SparseTextEmbedding(model_name=model_name)
        result = next(iter(model.embed(docs_to_embed, batch_size=6)))
        assert result.indices.tolist() == expected_result["indices"]

        for i, value in enumerate(result.values):
            assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
        if is_ci:
            delete_model_cache(model.model._model_dir)


def test_single_embedding():
    is_ci = os.getenv("CI")
    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
        model = SparseTextEmbedding(model_name=model_name)

        passage_result = next(iter(model.embed(docs, batch_size=6)))
        query_result = next(iter(model.query_embed(docs)))
        for result in [passage_result, query_result]:
            assert result.indices.tolist() == expected_result["indices"]

            for i, value in enumerate(result.values):
                assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
        if is_ci:
            delete_model_cache(model.model._model_dir)


def test_parallel_processing():
    is_ci = os.getenv("CI")
    model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
    docs = ["hello world", "flag embedding"] * 30
    sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
    sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
    sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))

    assert (
        len(sparse_embeddings)
        == len(sparse_embeddings_duo)
        == len(sparse_embeddings_all)
        == len(docs)
    )

    for sparse_embedding, sparse_embedding_duo, sparse_embedding_all in zip(
        sparse_embeddings, sparse_embeddings_duo, sparse_embeddings_all
    ):
        assert (
            sparse_embedding.indices.tolist()
            == sparse_embedding_duo.indices.tolist()
            == sparse_embedding_all.indices.tolist()
        )
        assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
        assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)

    if is_ci:
        delete_model_cache(model.model._model_dir)


@pytest.fixture
def bm25_instance(request):
def bm25_instance():
    ci = os.getenv("CI", True)
    disable_stemmer = getattr(request, "param", False)
    model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
    model = Bm25("Qdrant/bm25", language="english")
    yield model
    if ci:
        delete_model_cache(model._model_dir)


@pytest.mark.parametrize("bm25_instance", [True, False], indirect=True)
def test_stem_with_stopwords_and_punctuation(bm25_instance):
    # Setup
    bm25_instance.stopwords = {"the", "is", "a"}
@@ -131,14 +131,10 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
    result = bm25_instance._stem(tokens)

    # Assert
    if bm25_instance.disable_stemmer:
        expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only
    else:
        expected = ["quick", "brown", "fox", "test", "sentenc"]
    expected = ["quick", "brown", "fox", "test", "sentenc"]
    assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize("bm25_instance", [True, False], indirect=True)
def test_stem_case_insensitive_stopwords(bm25_instance):
    # Setup
    bm25_instance.stopwords = {"the", "is", "a"}
@@ -151,31 +147,49 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
    result = bm25_instance._stem(tokens)

    # Assert
    if bm25_instance.disable_stemmer:
    expected = ["quick", "brown", "fox", "test", "sentenc"]
    assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize("disable_stemmer", [True, False])
def test_disable_stemmer_behavior(disable_stemmer):
    # Setup
    model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
    model.stopwords = {"the", "is", "a"}
    model.punctuation = {".", ",", "!"}

    # Test data
    tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]

    # Execute
    result = model._stem(tokens)

    # Assert
    if disable_stemmer:
        expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only
    else:
        expected = ["quick", "brown", "fox", "test", "sentenc"]
    assert result == expected, f"Expected {expected}, but got {result}"


# @pytest.mark.parametrize(
#     "model_name",
#     ["prithivida/Splade_PP_en_v1"],
# )
# def test_lazy_load(model_name):
#     is_ci = os.getenv("CI")
#     model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
#     assert not hasattr(model.model, "model")
@pytest.mark.parametrize(
    "model_name",
    ["prithivida/Splade_PP_en_v1"],
)
def test_lazy_load(model_name):
    is_ci = os.getenv("CI")
    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    assert not hasattr(model.model, "model")

#     docs = ["hello world", "flag embedding"]
#     list(model.embed(docs))
#     assert hasattr(model.model, "model")
    docs = ["hello world", "flag embedding"]
    list(model.embed(docs))
    assert hasattr(model.model, "model")

#     model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
#     list(model.query_embed(docs))
    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    list(model.query_embed(docs))

#     model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
#     list(model.passage_embed(docs))
    model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
    list(model.passage_embed(docs))

#     if is_ci:
#         delete_model_cache(model.model._model_dir)
    if is_ci:
        delete_model_cache(model.model._model_dir)
