Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: Added jina embedding v3 #428

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
83 changes: 83 additions & 0 deletions fastembed/text/multitask_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Any, Type, Iterable, Union, Optional

import numpy as np

from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
from fastembed.text.onnx_text_model import TextEmbeddingWorker

supported_multitask_models = [
{
"model": "jinaai/jina-embeddings-v3",
"dim": [32, 64, 128, 256, 512, 768, 1024],
hh-space-invader marked this conversation as resolved.
Show resolved Hide resolved
"tasks": {
"retrieval.query": 0,
"retrieval.passage": 1,
"separation": 2,
"classification": 3,
"text-matching": 4,
},
"description": "Multi-task, multi-lingual embedding model with Matryoshka architecture",
"license": "cc-by-nc-4.0",
"size_in_GB": 2.29,
"sources": {
"hf": "jinaai/jina-embeddings-v3",
},
"model_file": "onnx/model.onnx",
"additional_files": ["onnx/model.onnx_data"],
},
]


class JinaEmbeddingV3(PooledNormalizedEmbedding):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._current_task_id = 4

@classmethod
def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
return JinaEmbeddingV3Worker

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_multitask_models

def _preprocess_onnx_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64)
return onnx_input

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
task_id: int = 4,
**kwargs,
) -> Iterable[np.ndarray]:
self._current_task_id = task_id
yield from super().embed(documents, batch_size, parallel, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
self._current_task_id = 0
yield from super().query_embed(query, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
self._current_task_id = 1
yield from super().passage_embed(texts, **kwargs)


class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
def init_embedding(
self,
model_name: str,
cache_dir: str,
**kwargs,
) -> JinaEmbeddingV3:
return JinaEmbeddingV3(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
2 changes: 2 additions & 0 deletions fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase

Expand All @@ -18,6 +19,7 @@ class TextEmbedding(TextEmbeddingBase):
CLIPOnnxEmbedding,
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
]

@classmethod
Expand Down
Loading