From 60d24321479e5248ea20dadf33ebcfb08bf7fe59 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 4 Nov 2024 20:59:20 +0200 Subject: [PATCH] new: Added matryoshka support for jina embedding v3 new: Added normalized mean pooling --- fastembed/multi_task/jina_embedding_v3.py | 24 ++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/fastembed/multi_task/jina_embedding_v3.py b/fastembed/multi_task/jina_embedding_v3.py index 2ad93efe..09805291 100644 --- a/fastembed/multi_task/jina_embedding_v3.py +++ b/fastembed/multi_task/jina_embedding_v3.py @@ -5,7 +5,12 @@ import numpy as np from fastembed.common import OnnxProvider -from fastembed.common.utils import iter_batch, define_cache_dir +from fastembed.common.utils import ( + iter_batch, + define_cache_dir, + normalize, + adjust_matryoshka_embedding, +) from fastembed.common.onnx_model import OnnxOutputContext from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker from fastembed.multi_task.multi_task_embedding_base import ( @@ -156,7 +161,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.nd embeddings = output.model_output attn_mask = output.attention_mask - return self.mean_pooling(embeddings, attn_mask).astype(np.float32) + return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) def onnx_embed( self, @@ -204,6 +209,8 @@ def _embed_documents( device_ids: Optional[List[int]] = None, **kwargs, ) -> Iterable[MultiTaskEmbedding]: + embeddings_size = kwargs.get("embeddings_size", None) + is_small = False if isinstance(documents, str): @@ -219,6 +226,16 @@ def _embed_documents( self.load_onnx_model() for batch in iter_batch(documents, batch_size): embeddings = self._post_process_onnx_output(self.onnx_embed(batch, task_id)) + if embeddings_size is not None: + if not isinstance(self.model_description["dim"], list): + raise ValueError( + f"Model does not support Matryoshka embeddings. The only size supported is: {self.model_description['dim']}." + ) + if embeddings_size not in self.model_description["dim"]: + raise ValueError( + f"Requested embeddings size {embeddings_size} is not supported by the model. Supported sizes: {self.model_description['dim']}." + ) + embeddings = adjust_matryoshka_embedding(np.array(embeddings), embeddings_size) for embedding in embeddings: yield MultiTaskEmbedding( embedding=embedding, @@ -269,9 +286,6 @@ def task_embed( task_id = self.get_task_types_dict()[task_type] - if isinstance(documents, str): - documents = [documents] - yield from self._embed_documents( model_name=self.model_name, cache_dir=str(self.cache_dir),