From 8b4eb26d4fc790d00796a0074f2dd9f026ef81a7 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Thu, 19 Dec 2024 12:57:14 +0200 Subject: [PATCH 1/8] new: Added jina embedding v3 --- fastembed/text/multitask_embedding.py | 83 +++++++++++++++++++++++++++ fastembed/text/text_embedding.py | 2 + 2 files changed, 85 insertions(+) create mode 100644 fastembed/text/multitask_embedding.py diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py new file mode 100644 index 00000000..f910b282 --- /dev/null +++ b/fastembed/text/multitask_embedding.py @@ -0,0 +1,83 @@ +from typing import Any, Type, Iterable, Union, Optional + +import numpy as np + +from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding +from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker +from fastembed.text.onnx_text_model import TextEmbeddingWorker + +supported_multitask_models = [ + { + "model": "jinaai/jina-embeddings-v3", + "dim": [32, 64, 128, 256, 512, 768, 1024], + "tasks": { + "retrieval.query": 0, + "retrieval.passage": 1, + "separation": 2, + "classification": 3, + "text-matching": 4, + }, + "description": "Multi-task, multi-lingual embedding model with Matryoshka architecture", + "license": "cc-by-nc-4.0", + "size_in_GB": 2.29, + "sources": { + "hf": "jinaai/jina-embeddings-v3", + }, + "model_file": "onnx/model.onnx", + "additional_files": ["onnx/model.onnx_data"], + }, +] + + +class JinaEmbeddingV3(PooledNormalizedEmbedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._current_task_id = 4 + + @classmethod + def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]: + return JinaEmbeddingV3Worker + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + return supported_multitask_models + + def _preprocess_onnx_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64) + return onnx_input + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + task_id: int = 4, + **kwargs, + ) -> Iterable[np.ndarray]: + self._current_task_id = task_id + yield from super().embed(documents, batch_size, parallel, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + self._current_task_id = 0 + yield from super().query_embed(query, **kwargs) + + def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: + self._current_task_id = 1 + yield from super().passage_embed(texts, **kwargs) + + +class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs, + ) -> JinaEmbeddingV3: + return JinaEmbeddingV3( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 960d68f7..f7e44775 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -7,6 +7,7 @@ from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding +from fastembed.text.multitask_embedding import JinaEmbeddingV3 from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase @@ -18,6 +19,7 @@ class 
TextEmbedding(TextEmbeddingBase):
         CLIPOnnxEmbedding,
         PooledNormalizedEmbedding,
         PooledEmbedding,
+        JinaEmbeddingV3,
     ]
 
     @classmethod

From 64127fcf3ae8bd0cfc663de0b4dcd6709a3ee39c Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Mon, 23 Dec 2024 08:27:52 +0200
Subject: [PATCH 2/8] refactor: Changed dim to int value

---
 fastembed/text/multitask_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py
index f910b282..8a9148cf 100644
--- a/fastembed/text/multitask_embedding.py
+++ b/fastembed/text/multitask_embedding.py
@@ -9,7 +9,7 @@
 supported_multitask_models = [
     {
         "model": "jinaai/jina-embeddings-v3",
-        "dim": [32, 64, 128, 256, 512, 768, 1024],
+        "dim": 1024,
         "tasks": {
             "retrieval.query": 0,
             "retrieval.passage": 1,

From e48f64731629e59fce969722d8c70833895a577e Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Mon, 23 Dec 2024 08:28:08 +0200
Subject: [PATCH 3/8] new: Updated notice

---
 NOTICE | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/NOTICE b/NOTICE
index bfa9618d..caa664b7 100644
--- a/NOTICE
+++ b/NOTICE
@@ -7,6 +7,8 @@ This distribution includes the following Jina AI models, each with its respectiv
   - License: cc-by-nc-4.0
 - jinaai/jina-reranker-v2-base-multilingual
   - License: cc-by-nc-4.0
+- jinaai/jina-embeddings-v3
+  - License: cc-by-nc-4.0
 
 These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.

From eb475d5474c66644020602b234ff05f3619a8ab6 Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Mon, 23 Dec 2024 08:28:41 +0200
Subject: [PATCH 4/8] new: Extended text embedding with query embed and passage embed

---
 fastembed/text/text_embedding.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index f7e44775..a8def42e 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -107,3 +107,30 @@ def embed(
             List of embeddings, one per document
         """
         yield from self.model.embed(documents, batch_size, parallel, **kwargs)
+
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
+        """
+        Embeds queries.
+
+        Args:
+            query (Union[str, Iterable[str]]): The query to embed, or an iterable, e.g. a list of queries.
+
+        Yields:
+            Iterable[np.ndarray]: The embeddings.
+        """
+        # This is model-specific, so that different models can have specialized implementations
+        yield from self.model.query_embed(query, **kwargs)
+
+    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
+        """
+        Embeds a list of text passages into a list of embeddings.
+
+        Args:
+            texts (Iterable[str]): The list of texts to embed.
+            **kwargs: Additional keyword arguments to pass to the embed method.
+
+        Yields:
+            Iterable[np.ndarray]: The embeddings, one per passage.
+ """ + # This is model-specific, so that different models can have specialized implementations + yield from self.model.passage_embed(texts, **kwargs) From 1650252e07f6abc998bcf7259a3d8793ee7aa13b Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:37:30 +0200 Subject: [PATCH 5/8] fix: Fix lazy load in query and passage embed --- fastembed/text/multitask_embedding.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 8a9148cf..3ae31a82 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -61,11 +61,24 @@ def embed( def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = 0 - yield from super().query_embed(query, **kwargs) + + if isinstance(query, str): + query = [query] + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for text in query: + yield from self._post_process_onnx_output(self.onnx_embed([text])) def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = 1 - yield from super().passage_embed(texts, **kwargs) + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for text in texts: + yield from self._post_process_onnx_output(self.onnx_embed([text])) class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): From 197b381e8f19b52fb9f272d038c5a5ae4bfd66ac Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:38:04 +0200 Subject: [PATCH 6/8] tests: Added test for multitask embeddings --- tests/test_text_multitask_embeddings.py | 231 ++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/test_text_multitask_embeddings.py diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py new file mode 100644 index 00000000..606b4138 --- /dev/null +++ b/tests/test_text_multitask_embeddings.py @@ -0,0 +1,231 @@ +import os + +import numpy as np +import pytest + +from fastembed import TextEmbedding +from tests.utils import delete_model_cache + + +CANONICAL_VECTOR_VALUES = { + "jinaai/jina-embeddings-v3": [ + { + "task_id": 0, + "vectors": np.array( + [ + [0.0623, -0.0402, 0.1706, -0.0143, 0.0617], + [-0.1064, -0.0733, 0.0353, 0.0096, 0.0667], + ] + ), + }, + { + "task_id": 1, + "vectors": np.array( + [ + [0.0513, -0.0247, 0.1751, -0.0075, 0.0679], + [-0.0987, -0.0786, 0.09, 0.0087, 0.0577], + ] + ), + }, + { + "task_id": 2, + "vectors": np.array( + [ + [0.094, -0.1065, 0.1305, 0.0547, 0.0556], + [0.0315, -0.1468, 0.065, 0.0568, 0.0546], + ] + ), + }, + { + "task_id": 3, + "vectors": np.array( + [ + [0.0606, -0.0877, 0.1384, 0.0065, 0.0722], + [-0.0502, -0.119, 0.032, 0.0514, 0.0689], + ] + ), + }, + { + "task_id": 4, + "vectors": np.array( + [ + [0.0911, -0.0341, 0.1305, -0.026, 0.0576], + [-0.1432, -0.05, 0.0133, 0.0464, 0.0789], + ] + ), + }, + ] +} +docs = ["Hello World", "Follow the white rabbit."] + + +def test_batch_embedding(): + is_ci = os.getenv("CI") + docs_to_embed = docs * 10 + default_task = 4 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} default 
task") + + embeddings = list(model.embed(documents=docs_to_embed, batch_size=6)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs_to_embed), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding(): + is_ci = os.getenv("CI") + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + for task in CANONICAL_VECTOR_VALUES[model_name]: + print(f"evaluating {model_name} task_id: {task['task_id']}") + + embeddings = list(model.embed(documents=docs, task_id=task["task_id"])) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = task["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_query(): + is_ci = os.getenv("CI") + task_id = 0 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} query_embed task_id: {task_id}") + + embeddings = list(model.query_embed(query=docs)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_passage(): + is_ci = os.getenv("CI") + task_id = 1 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} passage_embed task_id: {task_id}") + + embeddings = list(model.passage_embed(texts=docs)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_parallel_processing(): + is_ci = os.getenv("CI") + + docs = ["Hello World", "Follow the white rabbit."] * 100 + + model_name = "jinaai/jina-embeddings-v3" + dim = 1024 + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) + + embeddings_2 = list(model.embed(docs, 
batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) + + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim + assert np.allclose(embeddings, embeddings_2, atol=1e-4) + assert np.allclose(embeddings, embeddings_3, atol=1e-4) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize( + "model_name", + ["jinaai/jina-embeddings-v3"], +) +def test_lazy_load(model_name): + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models") + assert not hasattr(model.model, "model") + + list(model.embed(docs)) + assert hasattr(model.model, "model") + + if is_ci: + delete_model_cache(model.model._model_dir) From c9172015298247342bd9c2f545756d9e9aedd00a Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:51:53 +0200 Subject: [PATCH 7/8] nit: Remove cache dir from tests --- tests/test_text_multitask_embeddings.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 606b4138..67964f62 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -74,7 +74,7 @@ def test_batch_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} default task") @@ -105,7 +105,7 @@ def test_single_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) for task in CANONICAL_VECTOR_VALUES[model_name]: print(f"evaluating {model_name} task_id: {task['task_id']}") @@ -138,7 +138,7 @@ def test_single_embedding_query(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} query_embed task_id: {task_id}") @@ -170,7 +170,7 @@ def test_single_embedding_passage(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} passage_embed task_id: {task_id}") @@ -196,7 +196,7 @@ def test_parallel_processing(): model_name = "jinaai/jina-embeddings-v3" dim = 1024 - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) embeddings = list(model.embed(docs, batch_size=10, parallel=2)) embeddings = np.stack(embeddings, axis=0) @@ -221,7 +221,7 @@ def test_parallel_processing(): ) def test_lazy_load(model_name): is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models") + model = TextEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") list(model.embed(docs)) From 1ed62e96922bb8c01662548e14c6a86fae0a28fc Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 13:35:06 +0200 Subject: [PATCH 8/8] tests: Updated tests --- tests/test_text_multitask_embeddings.py | 38 ++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py 
b/tests/test_text_multitask_embeddings.py index 67964f62..5e2d241e 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -65,8 +65,8 @@ def test_batch_embedding(): default_task = 4 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -96,8 +96,8 @@ def test_single_embedding(): is_ci = os.getenv("CI") for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -129,8 +129,8 @@ def test_single_embedding_query(): task_id = 0 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -161,8 +161,8 @@ def test_single_embedding_passage(): task_id = 1 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -196,22 +196,22 @@ def test_parallel_processing(): model_name = "jinaai/jina-embeddings-v3" dim = 1024 - model = TextEmbedding(model_name=model_name) + if is_ci: + model = TextEmbedding(model_name=model_name) - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings = np.stack(embeddings, axis=0) + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - embeddings_2 = np.stack(embeddings_2, axis=0) + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) - embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) - embeddings_3 = np.stack(embeddings_3, axis=0) + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) - assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim - assert np.allclose(embeddings, embeddings_2, atol=1e-4) - assert np.allclose(embeddings, embeddings_3, atol=1e-4) + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim + assert np.allclose(embeddings, embeddings_2, atol=1e-4) + assert np.allclose(embeddings, embeddings_3, atol=1e-4) - if is_ci: delete_model_cache(model.model._model_dir)
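
Usage sketch: a minimal, illustrative example of the multitask API added in
this series, not a definitive reference. It assumes the
jinaai/jina-embeddings-v3 weights can be fetched from HuggingFace; the task
ids come from the "tasks" mapping in supported_multitask_models above.

    from fastembed import TextEmbedding

    model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")
    docs = ["Hello World", "Follow the white rabbit."]

    # embed() defaults to task_id=4 ("text-matching"); any of the five task
    # ids declared in supported_multitask_models can be passed explicitly.
    matching_vectors = list(model.embed(docs, task_id=4))

    # query_embed() pins task_id to 0 ("retrieval.query") and accepts either
    # a single string or an iterable of strings. The query string here is
    # only example data.
    query_vectors = list(model.query_embed("Who is the white rabbit?"))

    # passage_embed() pins task_id to 1 ("retrieval.passage").
    passage_vectors = list(model.passage_embed(docs))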