From 197b381e8f19b52fb9f272d038c5a5ae4bfd66ac Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Mon, 23 Dec 2024 09:38:04 +0200
Subject: [PATCH] tests: Added test for multitask embeddings

---
 tests/test_text_multitask_embeddings.py | 231 ++++++++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 tests/test_text_multitask_embeddings.py

diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py
new file mode 100644
index 00000000..606b4138
--- /dev/null
+++ b/tests/test_text_multitask_embeddings.py
@@ -0,0 +1,231 @@
+import os
+
+import numpy as np
+import pytest
+
+from fastembed import TextEmbedding
+from tests.utils import delete_model_cache
+
+
+CANONICAL_VECTOR_VALUES = {
+    "jinaai/jina-embeddings-v3": [
+        {
+            "task_id": 0,
+            "vectors": np.array(
+                [
+                    [0.0623, -0.0402, 0.1706, -0.0143, 0.0617],
+                    [-0.1064, -0.0733, 0.0353, 0.0096, 0.0667],
+                ]
+            ),
+        },
+        {
+            "task_id": 1,
+            "vectors": np.array(
+                [
+                    [0.0513, -0.0247, 0.1751, -0.0075, 0.0679],
+                    [-0.0987, -0.0786, 0.09, 0.0087, 0.0577],
+                ]
+            ),
+        },
+        {
+            "task_id": 2,
+            "vectors": np.array(
+                [
+                    [0.094, -0.1065, 0.1305, 0.0547, 0.0556],
+                    [0.0315, -0.1468, 0.065, 0.0568, 0.0546],
+                ]
+            ),
+        },
+        {
+            "task_id": 3,
+            "vectors": np.array(
+                [
+                    [0.0606, -0.0877, 0.1384, 0.0065, 0.0722],
+                    [-0.0502, -0.119, 0.032, 0.0514, 0.0689],
+                ]
+            ),
+        },
+        {
+            "task_id": 4,
+            "vectors": np.array(
+                [
+                    [0.0911, -0.0341, 0.1305, -0.026, 0.0576],
+                    [-0.1432, -0.05, 0.0133, 0.0464, 0.0789],
+                ]
+            ),
+        },
+    ]
+}
+docs = ["Hello World", "Follow the white rabbit."]
+
+
+def test_batch_embedding():
+    is_ci = os.getenv("CI")
+    docs_to_embed = docs * 10
+    default_task = 4
+
+    for model_desc in TextEmbedding.list_supported_models():
+        # if not is_ci and model_desc["size_in_GB"] > 1:
+        #     continue
+
+        model_name = model_desc["model"]
+        dim = model_desc["dim"]
+
+        if model_name not in CANONICAL_VECTOR_VALUES.keys():
+            continue
+
+        model = TextEmbedding(model_name=model_name, cache_dir="models")
+
+        print(f"evaluating {model_name} default task")
+
+        embeddings = list(model.embed(documents=docs_to_embed, batch_size=6))
+        embeddings = np.stack(embeddings, axis=0)
+
+        assert embeddings.shape == (len(docs_to_embed), dim)
+
+        canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"]
+        assert np.allclose(
+            embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
+        ), model_desc["model"]
+
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
+
+
+def test_single_embedding():
+    is_ci = os.getenv("CI")
+
+    for model_desc in TextEmbedding.list_supported_models():
+        # if not is_ci and model_desc["size_in_GB"] > 1:
+        #     continue
+
+        model_name = model_desc["model"]
+        dim = model_desc["dim"]
+
+        if model_name not in CANONICAL_VECTOR_VALUES.keys():
+            continue
+
+        model = TextEmbedding(model_name=model_name, cache_dir="models")
+
+        for task in CANONICAL_VECTOR_VALUES[model_name]:
+            print(f"evaluating {model_name} task_id: {task['task_id']}")
+
+            embeddings = list(model.embed(documents=docs, task_id=task["task_id"]))
+            embeddings = np.stack(embeddings, axis=0)
+
+            assert embeddings.shape == (len(docs), dim)
+
+            canonical_vector = task["vectors"]
+            assert np.allclose(
+                embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
+            ), model_desc["model"]
+
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
+
+
+def test_single_embedding_query():
+    is_ci = os.getenv("CI")
+    task_id = 0
+
+    for model_desc in TextEmbedding.list_supported_models():
+        # if not is_ci and model_desc["size_in_GB"] > 1:
+        #     continue
+
+        model_name = model_desc["model"]
+        dim = model_desc["dim"]
+
+        if model_name not in CANONICAL_VECTOR_VALUES.keys():
+            continue
+
+        model = TextEmbedding(model_name=model_name, cache_dir="models")
+
+        print(f"evaluating {model_name} query_embed task_id: {task_id}")
+
+        embeddings = list(model.query_embed(query=docs))
+        embeddings = np.stack(embeddings, axis=0)
+
+        assert embeddings.shape == (len(docs), dim)
+
+        canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
+        assert np.allclose(
+            embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
+        ), model_desc["model"]
+
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
+
+
+def test_single_embedding_passage():
+    is_ci = os.getenv("CI")
+    task_id = 1
+
+    for model_desc in TextEmbedding.list_supported_models():
+        # if not is_ci and model_desc["size_in_GB"] > 1:
+        #     continue
+
+        model_name = model_desc["model"]
+        dim = model_desc["dim"]
+
+        if model_name not in CANONICAL_VECTOR_VALUES.keys():
+            continue
+
+        model = TextEmbedding(model_name=model_name, cache_dir="models")
+
+        print(f"evaluating {model_name} passage_embed task_id: {task_id}")
+
+        embeddings = list(model.passage_embed(texts=docs))
+        embeddings = np.stack(embeddings, axis=0)
+
+        assert embeddings.shape == (len(docs), dim)
+
+        canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
+        assert np.allclose(
+            embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
+        ), model_desc["model"]
+
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
+
+
+def test_parallel_processing():
+    is_ci = os.getenv("CI")
+
+    docs = ["Hello World", "Follow the white rabbit."] * 100
+
+    model_name = "jinaai/jina-embeddings-v3"
+    dim = 1024
+
+    model = TextEmbedding(model_name=model_name, cache_dir="models")
+
+    embeddings = list(model.embed(docs, batch_size=10, parallel=2))
+    embeddings = np.stack(embeddings, axis=0)
+
+    embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
+    embeddings_2 = np.stack(embeddings_2, axis=0)
+
+    embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
+    embeddings_3 = np.stack(embeddings_3, axis=0)
+
+    assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim
+    assert np.allclose(embeddings, embeddings_2, atol=1e-4)
+    assert np.allclose(embeddings, embeddings_3, atol=1e-4)
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    ["jinaai/jina-embeddings-v3"],
+)
+def test_lazy_load(model_name):
+    is_ci = os.getenv("CI")
+    model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models")
+    assert not hasattr(model.model, "model")
+
+    list(model.embed(docs))
+    assert hasattr(model.model, "model")
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
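
Below is a minimal usage sketch of the multitask surface these tests exercise,
using only calls that appear in the tests themselves. The mapping of task ids
to tasks (0 for queries via query_embed, 1 for passages via passage_embed,
4 as the default when embed() is called without task_id) is inferred from the
test bodies above, not confirmed against fastembed's documentation.

    import numpy as np

    from fastembed import TextEmbedding

    docs = ["Hello World", "Follow the white rabbit."]

    # jinaai/jina-embeddings-v3 produces 1024-dimensional vectors (dim used
    # in test_parallel_processing above).
    model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")

    # Explicit task selection; the canonical values above cover task ids 0..4.
    vectors = np.stack(list(model.embed(documents=docs, task_id=2)))

    # Convenience wrappers: per the tests, query_embed output matches the
    # task_id 0 canonical vectors and passage_embed matches task_id 1.
    queries = np.stack(list(model.query_embed(query=docs)))
    passages = np.stack(list(model.passage_embed(texts=docs)))

    assert vectors.shape == queries.shape == passages.shape == (len(docs), 1024)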