Skip to content

Commit

Permalink
add token embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
generall committed Dec 16, 2024
1 parent 3b5e4c8 commit 4c2f997
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions fastembed/late_interaction/token_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Union, Iterable, Optional, List, Dict, Any

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.late_interaction.late_interaction_embedding_base import LateInteractionTextEmbeddingBase
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.onnx_text_model import TextEmbeddingWorker

# Registry of models supported by this module; returned verbatim by
# TokenEmbeddingsModel.list_supported_models(). Each entry mirrors the
# metadata-dict shape used by the other fastembed model registries
# (model id, dim, description, license, size, download sources, file).
supported_token_embeddings_models: List[Dict[str, Any]] = [
    {
        "model": "jinaai/jina-embeddings-v2-small-en-tokens",
        "dim": 512,
        "description": "Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
        " Prefixes for queries/documents: not necessary, 2023 year.",
        "license": "apache-2.0",
        "size_in_GB": 0.12,
        # NOTE(review): weights are fetched from the xenova mirror, not the
        # jinaai org the model name suggests — presumably an ONNX re-export.
        "sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
        "model_file": "onnx/model.onnx",
    },
]


class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
    """Late-interaction model that yields one embedding per non-padding token.

    Unlike pooled text embeddings, every document is represented as a
    ``(num_tokens, hidden_size)`` matrix with padding positions removed.
    """

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_token_embeddings_models

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        """Strip padding positions from the raw ONNX token embeddings.

        Args:
            output: ONNX inference result. ``model_output`` has shape
                (batch_size, sequence_length, hidden_size) and
                ``attention_mask`` has shape (batch_size, sequence_length).

        Yields:
            np.ndarray: One (num_real_tokens, hidden_size) array per document,
                keeping only the positions where the attention mask is 1.
        """
        # Size: (batch_size, sequence_length, hidden_size)
        embeddings = output.model_output
        # Size: (batch_size, sequence_length)
        masks = output.attention_mask

        # For each document we only select those embeddings that are not masked out
        for i in range(embeddings.shape[0]):
            yield embeddings[i, masks[i] == 1]

    def embed(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """Embed documents, yielding a per-token embedding matrix for each.

        Args:
            documents: A single document or an iterable of documents.
            batch_size: Number of documents encoded per ONNX forward pass.
            parallel: Number of worker processes; None disables parallelism.
            **kwargs: Forwarded to the underlying ONNX embedding pipeline.

        Yields:
            np.ndarray: A (num_tokens, hidden_size) array per document.
        """
        # Delegate explicitly to OnnxTextEmbedding.embed (rather than super())
        # to pin the implementation regardless of MRO changes in the bases.
        yield from OnnxTextEmbedding.embed(
            self, documents, batch_size=batch_size, parallel=parallel, **kwargs
        )

    def tokenize_docs(self, documents: List[str]) -> List[List[int]]:
        """Tokenize documents into token-id sequences.

        Args:
            documents: The texts to tokenize.

        Returns:
            List[List[int]]: One list of token ids per document.
            (Fix: ``Encoding.ids`` is a plain list of ints, so the previous
            ``List[np.ndarray]`` annotation did not match the actual return.)
        """
        encoded = self.tokenizer.encode_batch(documents)
        return [e.ids for e in encoded]


class TokensEmbeddingWorker(TextEmbeddingWorker):
    """Worker used for parallel embedding: owns one TokenEmbeddingsModel instance."""

    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> TokenEmbeddingsModel:
        """Build the model this worker process will embed with.

        Each worker is pinned to a single ONNX thread — parallelism comes
        from running several workers, not from threads within one model.
        """
        return TokenEmbeddingsModel(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)


if __name__ == "__main__":
    # Smoke-test demo: embed a few short documents and inspect the
    # per-token embedding matrix produced for each one.
    demo_model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
    sample_docs = ["Hello, world!", "hello", "hello hello"]

    for token_matrix in demo_model.embed(sample_docs):
        # Expect (num_tokens, hidden_size) — token count varies per document.
        print(token_matrix.shape)

    print(demo_model.tokenize_docs(sample_docs))

0 comments on commit 4c2f997

Please sign in to comment.