Jina colbert v2 #363

Closed
wants to merge 6 commits into from
Changes from 5 commits
16 changes: 14 additions & 2 deletions fastembed/common/preprocessor_utils.py
@@ -17,7 +17,7 @@ def load_special_tokens(model_dir: Path) -> dict:
return tokens_map


def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]:
def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict, str]:
config_path = model_dir / "config.json"
if not config_path.exists():
raise ValueError(f"Could not find config.json in {model_dir}")
@@ -68,7 +68,19 @@ def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]:
token_str = token.get("content", "")
special_token_to_id[token_str] = tokenizer.token_to_id(token_str)

return tokenizer, special_token_to_id
if tokenizer_config["tokenizer_class"] == "BertTokenizer":
query_marker = {"[Q]": 1}
document_marker = {"[D]": 2}
elif tokenizer_config["tokenizer_class"] == "XLMRobertaTokenizer":
query_marker = {"[QueryMarker]": 250002}
document_marker = {"[DocumentMarker]": 250003}
else:
query_marker = {}
document_marker = {}

special_token_to_id.update(query_marker)
special_token_to_id.update(document_marker)
return tokenizer, special_token_to_id, tokenizer_config["tokenizer_class"]


def load_preprocessor(model_dir: Path) -> Compose:
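
With this change, `load_tokenizer` returns a third value, the tokenizer class name read from the model's tokenizer config, and the special-token map is extended with the ColBERT query/document markers. A minimal standalone sketch of the marker-selection logic (the `resolve_colbert_markers` helper is illustrative only, not part of this PR):

```python
from typing import Dict, Tuple


def resolve_colbert_markers(tokenizer_class: str) -> Tuple[Dict[str, int], Dict[str, int]]:
    # BERT-based ColBERT checkpoints reserve vocabulary ids 1 and 2 for the
    # [Q] and [D] markers; the XLM-RoBERTa-based jina-colbert-v2 appends its
    # markers after the base vocabulary at ids 250002 and 250003.
    if tokenizer_class == "BertTokenizer":
        return {"[Q]": 1}, {"[D]": 2}
    if tokenizer_class == "XLMRobertaTokenizer":
        return {"[QueryMarker]": 250002}, {"[DocumentMarker]": 250003}
    return {}, {}


query_marker, document_marker = resolve_colbert_markers("XLMRobertaTokenizer")
print(query_marker, document_marker)
# {'[QueryMarker]': 250002} {'[DocumentMarker]': 250003}
```
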
62 changes: 51 additions & 11 deletions fastembed/late_interaction/colbert.py
@@ -35,14 +35,26 @@
},
"model_file": "vespa_colbert.onnx",
},
{
"model": "jinaai/jina-colbert-v2",
"dim": 1024,
"description": "New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year",
"license": "cc-by-nc-4.0",
"size_in_GB": 2.24,
"sources": {
"hf": "jinaai/jina-colbert-v2",
},
"model_file": "onnx/model.onnx",
"additional_files": ["onnx/model.onnx_data"],
},
]


class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
QUERY_MARKER_TOKEN_ID = 1
DOCUMENT_MARKER_TOKEN_ID = 2
MIN_QUERY_LENGTH = 32
MASK_TOKEN = "[MASK]"
MASK_TOKENS = ["[MASK]", "<mask>"]
QUERY_MARKER_TOKENS = ["[Q]", "[QueryMarker]"]
DOCUMENT_MARKER_TOKENS = ["[D]", "[DocumentMarker]"]

def _post_process_onnx_output(
self, output: OnnxOutputContext, is_doc: bool = True
@@ -70,9 +82,12 @@ def _preprocess_onnx_input(
self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
) -> Dict[str, np.ndarray]:
if is_doc:
onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
onnx_input["input_ids"][:, 1] = self.document_marker_token_id
else:
onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
onnx_input["input_ids"][:, 1] = self.query_marker_token_id

if self.tokenizer_class == "XLMRobertaTokenizer":
onnx_input["attention_mask"][:] = 1
return onnx_input

def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
@@ -83,16 +98,17 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
)

def _tokenize_query(self, query: str) -> List[Encoding]:
# ". " is added to a query to be replaced with a special query token
query = [f". {query}"]
# "@ " is added to a query to be replaced with a special query token
# please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
query = [f"@ {query}"]
encoded = self.tokenizer.encode_batch(query)
# colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
prev_padding = None
if self.tokenizer.padding:
prev_padding = self.tokenizer.padding
self.tokenizer.enable_padding(
pad_token=self.MASK_TOKEN,
pad_token=self.mask_token,
pad_id=self.mask_token_id,
length=self.MIN_QUERY_LENGTH,
)
@@ -104,8 +120,9 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
return encoded

def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
# ". " is added to a document to be replaced with a special document token
documents = [". " + doc for doc in documents]
# "@ " is added to a document to be replaced with a special document token
# please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
documents = ["@ " + doc for doc in documents]
encoded = self.tokenizer.encode_batch(documents)
return encoded

@@ -189,7 +206,30 @@ def load_onnx_model(self) -> None:
cuda=self.cuda,
device_id=self.device_id,
)
self.mask_token_id = self.special_token_to_id["[MASK]"]
self.mask_token_id, self.mask_token = next(
(
(self.special_token_to_id[token], token)
for token in self.MASK_TOKENS
if token in self.special_token_to_id
),
(None, None),
)
self.query_marker_token_id = next(
(
self.special_token_to_id[token]
for token in self.QUERY_MARKER_TOKENS
if token in self.special_token_to_id
),
None,
)
self.document_marker_token_id = next(
(
self.special_token_to_id[token]
for token in self.DOCUMENT_MARKER_TOKENS
if token in self.special_token_to_id
),
None,
)
self.pad_token_id = self.tokenizer.padding["pad_id"]
self.skip_list = {
self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
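
Assuming the public `LateInteractionTextEmbedding` wrapper picks up new entries in `supported_models` the same way it does for the existing ColBERT models, the new checkpoint should be usable roughly as below. This is a sketch rather than code from the PR, and the first call downloads about 2.2 GB of ONNX weights:

```python
from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding("jinaai/jina-colbert-v2")

documents = ["ColBERT performs late interaction over per-token embeddings."]
doc_embeddings = list(model.embed(documents))  # one matrix of per-token vectors per document
query_embeddings = list(model.query_embed("late interaction"))  # queries are padded to at least 32 tokens

print(doc_embeddings[0].shape, query_embeddings[0].shape)
```
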
2 changes: 1 addition & 1 deletion fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -29,7 +29,7 @@ def _load_onnx_model(
cuda=cuda,
device_id=device_id,
)
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
self.tokenizer, _, _ = load_tokenizer(model_dir=model_dir)

def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]:
return self.tokenizer.encode_batch([(query, doc) for doc in documents])
4 changes: 3 additions & 1 deletion fastembed/text/onnx_text_model.py
@@ -53,7 +53,9 @@ def _load_onnx_model(
cuda=cuda,
device_id=device_id,
)
self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
self.tokenizer, self.special_token_to_id, self.tokenizer_class = load_tokenizer(
model_dir=model_dir
)

def load_onnx_model(self) -> None:
raise NotImplementedError("Subclasses must implement this method")
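
The text ONNX loader now also stores `self.tokenizer_class`, so subclasses can branch on the tokenizer family the way `Colbert._preprocess_onnx_input` does above. A minimal sketch of that pattern (the `ExampleModel` class is illustrative, not part of this PR):

```python
import numpy as np


class ExampleModel:
    # In the real code this attribute is set by _load_onnx_model from load_tokenizer's return value.
    tokenizer_class = "XLMRobertaTokenizer"

    def _preprocess(self, onnx_input: dict) -> dict:
        # Mirror the PR's behaviour: XLM-RoBERTa-based checkpoints get a fully "on" attention mask.
        if self.tokenizer_class == "XLMRobertaTokenizer":
            onnx_input["attention_mask"][:] = 1
        return onnx_input


model = ExampleModel()
print(model._preprocess({"attention_mask": np.zeros((1, 4), dtype=np.int64)}))
```
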
45 changes: 45 additions & 0 deletions tests/test_late_interaction_embeddings.py
@@ -28,6 +28,15 @@
[-0.07281, 0.04633, -0.04711, 0.00762, -0.07374],
]
),
"jinaai/jina-colbert-v2": np.array(
[
[0.0744, 0.0590, -0.2404, -0.1776, 0.0198],
[0.1318, 0.0882, -0.1136, -0.2065, 0.1461],
[-0.0178, -0.1359, -0.0136, -0.1075, -0.0509],
[0.0004, -0.1198, -0.0696, -0.0482, -0.0650],
[0.0766, 0.0448, -0.2344, -0.1829, 0.0061],
]
),
}

CANONICAL_QUERY_VALUES = {
@@ -103,6 +112,42 @@
[-0.03473, 0.04792, -0.07033, 0.02196, -0.05314],
]
),
"jinaai/jina-colbert-v2": np.array(
[
[0.0475, 0.0250, -0.2225, -0.1087, -0.0297],
[0.0211, -0.0844, -0.0070, -0.1715, 0.0154],
[-0.0062, -0.0958, -0.0142, -0.1283, -0.0218],
[0.0490, -0.0500, -0.1613, 0.0193, 0.0280],
[0.0477, 0.0250, -0.2279, -0.1128, -0.0289],
[0.0597, -0.0676, -0.0955, -0.0756, 0.0234],
[0.0592, -0.0858, -0.0621, -0.1088, 0.0148],
[0.0870, -0.0715, -0.0769, -0.1414, 0.0365],
[0.1015, -0.0552, -0.0667, -0.1637, 0.0492],
[0.1135, -0.0469, -0.0573, -0.1702, 0.0535],
[0.1226, -0.0430, -0.0508, -0.1729, 0.0553],
[0.1287, -0.0387, -0.0425, -0.1757, 0.0567],
[0.1360, -0.0337, -0.0327, -0.1790, 0.0570],
[0.1434, -0.0267, -0.0242, -0.1831, 0.0569],
[0.1528, -0.0091, -0.0184, -0.1881, 0.0570],
[0.1547, 0.0185, -0.0231, -0.1803, 0.0538],
[0.1396, 0.0533, -0.0349, -0.1637, 0.0429],
[0.1074, 0.0851, -0.0418, -0.1461, 0.0231],
[0.0719, 0.1061, -0.0440, -0.1291, -0.0003],
[0.0456, 0.1146, -0.0457, -0.1118, -0.0192],
[0.0347, 0.1132, -0.0493, -0.0955, -0.0341],
[0.0357, 0.1074, -0.0491, -0.0821, -0.0449],
[0.0421, 0.1036, -0.0461, -0.0763, -0.0488],
[0.0479, 0.1019, -0.0434, -0.0721, -0.0483],
[0.0470, 0.0988, -0.0423, -0.0654, -0.0440],
[0.0439, 0.0947, -0.0418, -0.0591, -0.0349],
[0.0397, 0.0898, -0.0415, -0.0555, -0.0206],
[0.0434, 0.0815, -0.0411, -0.0543, 0.0057],
[0.0512, 0.0629, -0.0442, -0.0547, 0.0378],
[0.0584, 0.0483, -0.0528, -0.0607, 0.0568],
[0.0568, 0.0456, -0.0674, -0.0699, 0.0768],
[0.0205, -0.0859, -0.0385, -0.1231, -0.0331],
]
),
}

docs = ["Hello World"]