Jina colbert v2 #363

Closed
wants to merge 6 commits
8 changes: 4 additions & 4 deletions fastembed/late_interaction/colbert.py
@@ -84,8 +84,8 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:

def _tokenize_query(self, query: str) -> List[Encoding]:
# ". " is added to a query to be replaced with a special query token
-query = [f". {query}"]
-encoded = self.tokenizer.encode_batch(query)
+query = f". {query}"
+encoded = self.tokenizer.encode_batch([query])
# colbert authors recommend padding queries with [MASK] tokens for query augmentation to improve performance
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
prev_padding = None
@@ -96,7 +96,7 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
pad_id=self.mask_token_id,
length=self.MIN_QUERY_LENGTH,
)
-encoded = self.tokenizer.encode_batch(query)
+encoded = self.tokenizer.encode_batch([query])
if prev_padding is None:
self.tokenizer.no_padding()
else:
@@ -189,7 +189,7 @@ def load_onnx_model(self) -> None:
cuda=self.cuda,
device_id=self.device_id,
)
-self.mask_token_id = self.special_token_to_id["[MASK]"]
+self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
self.pad_token_id = self.tokenizer.padding["pad_id"]
self.skip_list = {
self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
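For context, the hunks above refactor how the query is wrapped for `encode_batch` and switch the mask lookup to the `MASK_TOKEN` constant so subclasses can override it. Below is a minimal standalone sketch of the underlying query-augmentation mechanism, assuming an illustrative tokenizer, query string, and 32-token minimum; none of these specifics are part of this PR.

from tokenizers import Tokenizer

MIN_QUERY_LENGTH = 32   # assumed minimum query length, mirroring the class constant above
MASK_TOKEN = "[MASK]"   # original ColBERT mask token; jina-colbert-v2 uses "<mask>" instead

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative tokenizer choice
mask_id = tokenizer.token_to_id(MASK_TOKEN)

query = ". What is late interaction?"  # the ". " prefix is later replaced by the query marker token
encoded = tokenizer.encode_batch([query])
if len(encoded[0].ids) < MIN_QUERY_LENGTH:
    prev_padding = tokenizer.padding  # remember any existing padding configuration
    tokenizer.enable_padding(pad_token=MASK_TOKEN, pad_id=mask_id, length=MIN_QUERY_LENGTH)
    encoded = tokenizer.encode_batch([query])  # re-encode; the query is now padded with [MASK] tokens
    if prev_padding is None:
        tokenizer.no_padding()  # restore the original no-padding state
    else:
        tokenizer.enable_padding(**prev_padding)

print(len(encoded[0].ids))  # 32 -- short queries are augmented up to the minimum length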
252 changes: 252 additions & 0 deletions fastembed/late_interaction/jina_colbert_v2.py
@@ -0,0 +1,252 @@
import string
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union

import numpy as np
from tokenizers import Encoding

from fastembed.common import OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.late_interaction.late_interaction_embedding_base import (
LateInteractionTextEmbeddingBase,
)
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker

supported_colbert_models = [
{
"model": "jinaai/jina-colbert-v2",
"dim": 1024,
"description": "New model that expands the capabilities of colbert-v1 with multilingual support and a context length of 8192, 2024 year",
"license": "cc-by-nc-4.0",
"size_in_GB": 2.24,
"sources": {
"hf": "jinaai/jina-colbert-v2",
},
"model_file": "onnx/model.onnx",
"additional_files": ["onnx/model.onnx_data"],
},
]


class JinaColbertV2(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
QUERY_MARKER_TOKEN_ID = 250002
DOCUMENT_MARKER_TOKEN_ID = 250003
MIN_QUERY_LENGTH = 32
MASK_TOKEN = "<mask>"

def _post_process_onnx_output(
self, output: OnnxOutputContext, is_doc: bool = True
) -> Iterable[np.ndarray]:
if not is_doc:
return output.model_output.astype(np.float32)

if output.input_ids is None or output.attention_mask is None:
raise ValueError(
"input_ids and attention_mask must be provided for document post-processing"
)

for i, token_sequence in enumerate(output.input_ids):
for j, token_id in enumerate(token_sequence):
if token_id in self.skip_list or token_id == self.pad_token_id:
output.attention_mask[i, j] = 0

output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
norm_clamped = np.maximum(norm, 1e-12)
output.model_output /= norm_clamped
return output.model_output.astype(np.float32)

def _preprocess_onnx_input(
self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
) -> Dict[str, np.ndarray]:
if is_doc:
onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
else:
onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID

# the attention mask for jina-colbert-v2 is always 1
onnx_input["attention_mask"][:] = 1
return onnx_input

def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
return (
self._tokenize_documents(documents=documents)
if is_doc
else self._tokenize_query(query=next(iter(documents)))
)

def _tokenize_query(self, query: str) -> List[Encoding]:
# "@ " is prepended to the query and is later replaced with the special query marker token
# "@ " is treated as a single token by the jina-colbert-v2 tokenizer
query = f"@ {query}"
encoded = self.tokenizer.encode_batch([query])
# the colbert authors recommend padding queries with mask tokens for query augmentation to improve performance
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
prev_padding = None
if self.tokenizer.padding:
prev_padding = self.tokenizer.padding
self.tokenizer.enable_padding(
pad_token=self.MASK_TOKEN,
pad_id=self.mask_token_id,
length=self.MIN_QUERY_LENGTH,
)
encoded = self.tokenizer.encode_batch([query])
if prev_padding is None:
self.tokenizer.no_padding()
else:
self.tokenizer.enable_padding(**prev_padding)
return encoded

def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
# "@ " is prepended to each document and is later replaced with the special document marker token
# "@ " is treated as a single token by the jina-colbert-v2 tokenizer
documents = ["@ " + doc for doc in documents]
encoded = self.tokenizer.encode_batch(documents)
return encoded

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
"""Lists the supported models.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_colbert_models

def __init__(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
cuda: bool = False,
device_ids: Optional[List[int]] = None,
lazy_load: bool = False,
device_id: Optional[int] = None,
**kwargs,
):
"""
Args:
model_name (str): The name of the model to use.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
Defaults to False.
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""

super().__init__(model_name, cache_dir, threads, **kwargs)
self.providers = providers
self.lazy_load = lazy_load

# List of device ids that can be used for data-parallel processing in workers
self.device_ids = device_ids
self.cuda = cuda

# This device_id will be used if we need to load the model in the current process
if device_id is not None:
self.device_id = device_id
elif self.device_ids is not None:
self.device_id = self.device_ids[0]
else:
self.device_id = None

self.model_description = self._get_model_description(model_name)
self.cache_dir = define_cache_dir(cache_dir)

self._model_dir = self.download_model(
self.model_description, self.cache_dir, local_files_only=self._local_files_only
)
self.mask_token_id = None
self.pad_token_id = None
self.skip_list = set()

if not self.lazy_load:
self.load_onnx_model()

def load_onnx_model(self) -> None:
self._load_onnx_model(
model_dir=self._model_dir,
model_file=self.model_description["model_file"],
threads=self.threads,
providers=self.providers,
cuda=self.cuda,
device_id=self.device_id,
)
self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
self.pad_token_id = self.tokenizer.padding["pad_id"]
self.skip_list = {
self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
for symbol in string.punctuation
}

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
"""
Encode a list of documents into a list of embeddings.
Each document produces a matrix of per-token embeddings (late interaction), so variable-length inputs are handled without pooling.

Args:
documents: Iterator of documents or single document to embed
batch_size: Batch size for encoding -- higher values will use more memory, but be faster
parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.

Returns:
List of per-token embedding matrices, one per document
"""
yield from self._embed_documents(
model_name=self.model_name,
cache_dir=str(self.cache_dir),
documents=documents,
batch_size=batch_size,
parallel=parallel,
providers=self.providers,
cuda=self.cuda,
device_ids=self.device_ids,
**kwargs,
)

def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]:
if isinstance(query, str):
query = [query]

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

for text in query:
yield from self._post_process_onnx_output(
self.onnx_embed([text], is_doc=False), is_doc=False
)

@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return ColbertEmbeddingWorker


class ColbertEmbeddingWorker(TextEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> JinaColbertV2:
return JinaColbertV2(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
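Below the new module, a hedged usage sketch: how jina-colbert-v2 could be called through fastembed's existing `LateInteractionTextEmbedding` wrapper once this registry entry lands, and how the per-token matrices returned by `embed` and `query_embed` are typically scored with MaxSim. The documents, the query, and the `maxsim` helper are illustrative assumptions, not part of this PR.

import numpy as np
from fastembed import LateInteractionTextEmbedding

# Assumes the "jinaai/jina-colbert-v2" entry added above is picked up by the registry.
model = LateInteractionTextEmbedding("jinaai/jina-colbert-v2")

documents = [
    "ColBERT scores queries against documents token by token.",
    "Jina ColBERT v2 adds multilingual support and an 8192-token context window.",
]
doc_embeddings = list(model.embed(documents))  # one (num_tokens, dim) matrix per document
query_embedding = next(iter(model.query_embed("what is late interaction?")))


def maxsim(query_matrix: np.ndarray, doc_matrix: np.ndarray) -> float:
    # Late-interaction (MaxSim) score: for each query token take its best-matching
    # document token, then sum over query tokens. Document rows are L2-normalized
    # (and punctuation/padding rows zeroed) in _post_process_onnx_output above.
    similarities = query_matrix @ doc_matrix.T  # (query_tokens, doc_tokens)
    return float(similarities.max(axis=1).sum())


scores = [maxsim(query_embedding, d) for d in doc_embeddings]
print(max(zip(scores, documents)))  # best-scoring document for the query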
45 changes: 45 additions & 0 deletions tests/test_late_interaction_embeddings.py
@@ -28,6 +28,15 @@
[-0.07281, 0.04633, -0.04711, 0.00762, -0.07374],
]
),
"jinaai/jina-colbert-v2": np.array(
[
[0.0744, 0.0590, -0.2404, -0.1776, 0.0198],
[0.1318, 0.0882, -0.1136, -0.2065, 0.1461],
[-0.0178, -0.1359, -0.0136, -0.1075, -0.0509],
[0.0004, -0.1198, -0.0696, -0.0482, -0.0650],
[0.0766, 0.0448, -0.2344, -0.1829, 0.0061],
]
),
}

CANONICAL_QUERY_VALUES = {
@@ -103,6 +112,42 @@
[-0.03473, 0.04792, -0.07033, 0.02196, -0.05314],
]
),
"jinaai/jina-colbert-v2": np.array(
[
[0.0475, 0.0250, -0.2225, -0.1087, -0.0297],
[0.0211, -0.0844, -0.0070, -0.1715, 0.0154],
[-0.0062, -0.0958, -0.0142, -0.1283, -0.0218],
[0.0490, -0.0500, -0.1613, 0.0193, 0.0280],
[0.0477, 0.0250, -0.2279, -0.1128, -0.0289],
[0.0597, -0.0676, -0.0955, -0.0756, 0.0234],
[0.0592, -0.0858, -0.0621, -0.1088, 0.0148],
[0.0870, -0.0715, -0.0769, -0.1414, 0.0365],
[0.1015, -0.0552, -0.0667, -0.1637, 0.0492],
[0.1135, -0.0469, -0.0573, -0.1702, 0.0535],
[0.1226, -0.0430, -0.0508, -0.1729, 0.0553],
[0.1287, -0.0387, -0.0425, -0.1757, 0.0567],
[0.1360, -0.0337, -0.0327, -0.1790, 0.0570],
[0.1434, -0.0267, -0.0242, -0.1831, 0.0569],
[0.1528, -0.0091, -0.0184, -0.1881, 0.0570],
[0.1547, 0.0185, -0.0231, -0.1803, 0.0538],
[0.1396, 0.0533, -0.0349, -0.1637, 0.0429],
[0.1074, 0.0851, -0.0418, -0.1461, 0.0231],
[0.0719, 0.1061, -0.0440, -0.1291, -0.0003],
[0.0456, 0.1146, -0.0457, -0.1118, -0.0192],
[0.0347, 0.1132, -0.0493, -0.0955, -0.0341],
[0.0357, 0.1074, -0.0491, -0.0821, -0.0449],
[0.0421, 0.1036, -0.0461, -0.0763, -0.0488],
[0.0479, 0.1019, -0.0434, -0.0721, -0.0483],
[0.0470, 0.0988, -0.0423, -0.0654, -0.0440],
[0.0439, 0.0947, -0.0418, -0.0591, -0.0349],
[0.0397, 0.0898, -0.0415, -0.0555, -0.0206],
[0.0434, 0.0815, -0.0411, -0.0543, 0.0057],
[0.0512, 0.0629, -0.0442, -0.0547, 0.0378],
[0.0584, 0.0483, -0.0528, -0.0607, 0.0568],
[0.0568, 0.0456, -0.0674, -0.0699, 0.0768],
[0.0205, -0.0859, -0.0385, -0.1231, -0.0331],
]
),
}

docs = ["Hello World"]
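Finally, a minimal sketch of how a canonical slice such as `CANONICAL_QUERY_VALUES` above is typically compared against a freshly computed query embedding. The query string and tolerance are assumptions; the PR's actual test assertions are not shown in this diff excerpt.

import numpy as np
from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding("jinaai/jina-colbert-v2")
query_embedding = next(iter(model.query_embed("Hello World")))  # one row per query token (mask-padded)
expected = CANONICAL_QUERY_VALUES["jinaai/jina-colbert-v2"]  # the reference slice shown above
rows, cols = expected.shape
assert np.allclose(query_embedding[:rows, :cols], expected, atol=2e-3)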