Skip to content

Commit

Permalink
new: Added matryoshka support for jina embedding v3
Browse files (browse the repository at this point in the history)
new: Added normalized mean pooling
  • Loading branch information
hh-space-invader committed Nov 4, 2024
1 parent 4ff34c4 commit 60d2432
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions fastembed/multi_task/jina_embedding_v3.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,12 @@
import numpy as np

from fastembed.common import OnnxProvider
from fastembed.common.utils import iter_batch, define_cache_dir
from fastembed.common.utils import (
iter_batch,
define_cache_dir,
normalize,
adjust_matryoshka_embedding,
)
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
from fastembed.multi_task.multi_task_embedding_base import (
Expand Down Expand Up @@ -156,7 +161,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.nd

embeddings = output.model_output
attn_mask = output.attention_mask
return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

def onnx_embed(
self,
Expand Down Expand Up @@ -204,6 +209,8 @@ def _embed_documents(
device_ids: Optional[List[int]] = None,
**kwargs,
) -> Iterable[MultiTaskEmbedding]:
embeddings_size = kwargs.get("embeddings_size", None)

is_small = False

if isinstance(documents, str):
Expand All @@ -219,6 +226,16 @@ def _embed_documents(
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
embeddings = self._post_process_onnx_output(self.onnx_embed(batch, task_id))
if embeddings_size is not None:
if not isinstance(self.model_description["dim"], list):
raise ValueError(
f"Model does not support Matryoshka embeddings. The only size supported is: {self.model_description['dim']}."
)
if embeddings_size not in self.model_description["dim"]:
raise ValueError(
f"Requested embeddings size {embeddings_size} is not supported by the model. Supported sizes: {self.model_description['dim']}."
)
embeddings = adjust_matryoshka_embedding(np.array(embeddings), embeddings_size)
for embedding in embeddings:
yield MultiTaskEmbedding(
embedding=embedding,
Expand Down Expand Up @@ -269,9 +286,6 @@ def task_embed(

task_id = self.get_task_types_dict()[task_type]

if isinstance(documents, str):
documents = [documents]

yield from self._embed_documents(
model_name=self.model_name,
cache_dir=str(self.cache_dir),
Expand Down

0 comments on commit 60d2432

Please sign in to comment.