new: Added jina embedding v3 #428

Open · wants to merge 8 commits into main
2 changes: 2 additions & 0 deletions NOTICE
@@ -7,6 +7,8 @@ This distribution includes the following Jina AI models, each with its respective
- License: cc-by-nc-4.0
- jinaai/jina-reranker-v2-base-multilingual
- License: cc-by-nc-4.0
- jinaai/jina-embeddings-v3
- License: cc-by-nc-4.0

These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.

96 changes: 96 additions & 0 deletions fastembed/text/multitask_embedding.py
@@ -0,0 +1,96 @@
from typing import Any, Type, Iterable, Union, Optional

import numpy as np

from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
from fastembed.text.onnx_text_model import TextEmbeddingWorker

supported_multitask_models = [
{
"model": "jinaai/jina-embeddings-v3",
"dim": 1024,
"tasks": {
"retrieval.query": 0,
"retrieval.passage": 1,
"separation": 2,
"classification": 3,
"text-matching": 4,
},
"description": "Multi-task, multi-lingual embedding model with Matryoshka architecture",
"license": "cc-by-nc-4.0",
"size_in_GB": 2.29,
"sources": {
"hf": "jinaai/jina-embeddings-v3",
},
"model_file": "onnx/model.onnx",
"additional_files": ["onnx/model.onnx_data"],
},
]


class JinaEmbeddingV3(PooledNormalizedEmbedding):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default to the "text-matching" task (task_id=4); see the tasks mapping above.
        self._current_task_id = 4

@classmethod
def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
return JinaEmbeddingV3Worker

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_multitask_models

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs
    ) -> dict[str, np.ndarray]:
        # The exported ONNX graph expects the selected task as a scalar int64 input.
        onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64)
        return onnx_input

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
        task_id: int = 4,  # defaults to "text-matching"
**kwargs,
) -> Iterable[np.ndarray]:
self._current_task_id = task_id
yield from super().embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
        self._current_task_id = 0  # "retrieval.query"

if isinstance(query, str):
query = [query]

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

for text in query:
yield from self._post_process_onnx_output(self.onnx_embed([text]))

    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
        self._current_task_id = 1  # "retrieval.passage"

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

for text in texts:
yield from self._post_process_onnx_output(self.onnx_embed([text]))


class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
def init_embedding(
self,
model_name: str,
cache_dir: str,
**kwargs,
) -> JinaEmbeddingV3:
return JinaEmbeddingV3(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
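
A minimal usage sketch of the new model through the `TextEmbedding` facade (registered in the next file), based on the task mapping defined above. This assumes the ~2.29 GB ONNX model downloads from Hugging Face on first use; `task_id=4` ("text-matching") is the default:

```python
from fastembed import TextEmbedding

# First use downloads jinaai/jina-embeddings-v3 (~2.29 GB) from Hugging Face.
model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")

docs = ["Hello World", "Follow the white rabbit."]

# Default task is "text-matching" (task_id=4).
matching = list(model.embed(docs))

# Any task from the mapping above can be selected explicitly,
# e.g. "classification" (task_id=3).
classification = list(model.embed(docs, task_id=3))

# The asymmetric retrieval helpers pin the task internally:
# query_embed -> task_id=0 ("retrieval.query"),
# passage_embed -> task_id=1 ("retrieval.passage").
queries = list(model.query_embed("Who follows the white rabbit?"))
passages = list(model.passage_embed(docs))
```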
29 changes: 29 additions & 0 deletions fastembed/text/text_embedding.py
@@ -7,6 +7,7 @@
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase

@@ -18,6 +19,7 @@ class TextEmbedding(TextEmbeddingBase):
CLIPOnnxEmbedding,
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
]

@classmethod
@@ -105,3 +107,30 @@ def embed(
List of embeddings, one per document
"""
yield from self.model.embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
        """
        Embeds queries.

        Args:
            query (Union[str, Iterable[str]]): The query to embed, or an iterable (e.g. a list) of queries.

        Returns:
            Iterable[np.ndarray]: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.query_embed(query, **kwargs)

    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
        """
        Embeds a list of text passages into a list of embeddings.

        Args:
            texts (Iterable[str]): The list of texts to embed.
            **kwargs: Additional keyword arguments to pass to the embed method.

        Yields:
            Iterable[np.ndarray]: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.passage_embed(texts, **kwargs)
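
Because `JinaEmbeddingV3` inherits from `PooledNormalizedEmbedding`, its outputs should already be L2-normalized, so a plain dot product gives cosine similarity. A small retrieval sketch using the two facade methods added above (the query string is invented for illustration):

```python
import numpy as np

from fastembed import TextEmbedding

model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")

passages = ["Hello World", "Follow the white rabbit."]
passage_vecs = np.stack(list(model.passage_embed(passages)))  # task_id=1
query_vec = next(model.query_embed("white rabbit"))           # task_id=0

# Normalized vectors: dot product == cosine similarity.
scores = passage_vecs @ query_vec
for passage, score in sorted(zip(passages, scores), key=lambda pair: -pair[1]):
    print(f"{score:.3f}  {passage}")
```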
231 changes: 231 additions & 0 deletions tests/test_text_multitask_embeddings.py
@@ -0,0 +1,231 @@
import os

import numpy as np
import pytest

from fastembed import TextEmbedding
from tests.utils import delete_model_cache


CANONICAL_VECTOR_VALUES = {
"jinaai/jina-embeddings-v3": [
{
"task_id": 0,
"vectors": np.array(
[
[0.0623, -0.0402, 0.1706, -0.0143, 0.0617],
[-0.1064, -0.0733, 0.0353, 0.0096, 0.0667],
]
),
},
{
"task_id": 1,
"vectors": np.array(
[
[0.0513, -0.0247, 0.1751, -0.0075, 0.0679],
[-0.0987, -0.0786, 0.09, 0.0087, 0.0577],
]
),
},
{
"task_id": 2,
"vectors": np.array(
[
[0.094, -0.1065, 0.1305, 0.0547, 0.0556],
[0.0315, -0.1468, 0.065, 0.0568, 0.0546],
]
),
},
{
"task_id": 3,
"vectors": np.array(
[
[0.0606, -0.0877, 0.1384, 0.0065, 0.0722],
[-0.0502, -0.119, 0.032, 0.0514, 0.0689],
]
),
},
{
"task_id": 4,
"vectors": np.array(
[
[0.0911, -0.0341, 0.1305, -0.026, 0.0576],
[-0.1432, -0.05, 0.0133, 0.0464, 0.0789],
]
),
},
]
}
docs = ["Hello World", "Follow the white rabbit."]


def test_batch_embedding():
is_ci = os.getenv("CI")
docs_to_embed = docs * 10
default_task = 4

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

        if model_name not in CANONICAL_VECTOR_VALUES:
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} default task")

embeddings = list(model.embed(documents=docs_to_embed, batch_size=6))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs_to_embed), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_single_embedding():
is_ci = os.getenv("CI")

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

        if model_name not in CANONICAL_VECTOR_VALUES:
continue

model = TextEmbedding(model_name=model_name)

for task in CANONICAL_VECTOR_VALUES[model_name]:
print(f"evaluating {model_name} task_id: {task['task_id']}")

embeddings = list(model.embed(documents=docs, task_id=task["task_id"]))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = task["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_single_embedding_query():
is_ci = os.getenv("CI")
task_id = 0

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

        if model_name not in CANONICAL_VECTOR_VALUES:
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} query_embed task_id: {task_id}")

embeddings = list(model.query_embed(query=docs))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_single_embedding_passage():
is_ci = os.getenv("CI")
task_id = 1

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

        if model_name not in CANONICAL_VECTOR_VALUES:
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} passage_embed task_id: {task_id}")

embeddings = list(model.passage_embed(texts=docs))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_parallel_processing():
is_ci = os.getenv("CI")

docs = ["Hello World", "Follow the white rabbit."] * 100

model_name = "jinaai/jina-embeddings-v3"
dim = 1024

if is_ci:
model = TextEmbedding(model_name=model_name)

embeddings = list(model.embed(docs, batch_size=10, parallel=2))
embeddings = np.stack(embeddings, axis=0)

embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
embeddings_2 = np.stack(embeddings_2, axis=0)

embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
embeddings_3 = np.stack(embeddings_3, axis=0)

assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim
assert np.allclose(embeddings, embeddings_2, atol=1e-4)
assert np.allclose(embeddings, embeddings_3, atol=1e-4)

delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
def test_lazy_load(model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")

list(model.embed(docs))
assert hasattr(model.model, "model")

if is_ci:
delete_model_cache(model.model._model_dir)