From 38e003c78460463acd11306ad8e6ceed7cba0e00 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Sun, 13 Oct 2024 19:35:49 +0300 Subject: [PATCH 1/6] feat: Added support for jina-colbert-v2 --- .github/workflows/python-tests.yml | 4 +- fastembed/late_interaction/colbert.py | 24 ++++++++++-- tests/test_late_interaction_embeddings.py | 45 +++++++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index c42ac41a..72993f35 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -2,8 +2,8 @@ name: Tests on: push: - branches: [ master, main, gpu ] - pull_request: + branches: [ master, main, gpu, jina-colbert-v2 ] + workflow_dispatch: env: CARGO_TERM_COLOR: always diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 6e7f34d9..60264aa4 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -35,6 +35,17 @@ }, "model_file": "vespa_colbert.onnx", }, + { + "model": "jinaai/jina-colbert-v2", + "dim": 1024, + "description": "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, 2024 year", + "size_in_GB": 2.24, + "sources": { + "hf": "jinaai/jina-colbert-v2", + }, + "model_file": "onnx/model.onnx", + "additional_files": ["onnx/model.onnx_data"], + }, ] @@ -42,7 +53,7 @@ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): QUERY_MARKER_TOKEN_ID = 1 DOCUMENT_MARKER_TOKEN_ID = 2 MIN_QUERY_LENGTH = 32 - MASK_TOKEN = "[MASK]" + MASK_TOKENS = ["[MASK]", ""] def _post_process_onnx_output( self, output: OnnxOutputContext, is_doc: bool = True @@ -92,7 +103,7 @@ def _tokenize_query(self, query: str) -> List[Encoding]: if self.tokenizer.padding: prev_padding = self.tokenizer.padding self.tokenizer.enable_padding( - pad_token=self.MASK_TOKEN, + pad_token=self.MASK_TOKENS[0], pad_id=self.mask_token_id, length=self.MIN_QUERY_LENGTH, ) @@ -189,7 +200,14 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) - self.mask_token_id = self.special_token_to_id["[MASK]"] + self.mask_token_id = next( + ( + self.special_token_to_id[token] + for token in self.MASK_TOKENS + if token in self.special_token_to_id + ), + None, + ) self.pad_token_id = self.tokenizer.padding["pad_id"] self.skip_list = { self.tokenizer.encode(symbol, add_special_tokens=False).ids[0] diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index c8d62cac..ce0287e6 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -28,6 +28,15 @@ [-0.07281, 0.04633, -0.04711, 0.00762, -0.07374], ] ), + "jinaai/jina-colbert-v2": np.array( + [ + [0.0744, 0.0590, -0.2404, -0.1776, 0.0198], + [0.1318, 0.0882, -0.1136, -0.2065, 0.1461], + [-0.0178, -0.1359, -0.0136, -0.1075, -0.0509], + [0.0004, -0.1198, -0.0696, -0.0482, -0.0650], + [0.0766, 0.0448, -0.2344, -0.1829, 0.0061], + ] + ), } CANONICAL_QUERY_VALUES = { @@ -103,6 +112,42 @@ [-0.03473, 0.04792, -0.07033, 0.02196, -0.05314], ] ), + "jinaai/jina-colbert-v2": np.array( + [ + [0.0475, 0.0250, -0.2225, -0.1087, -0.0297], + [0.0211, -0.0844, -0.0070, -0.1715, 0.0154], + [-0.0062, -0.0958, -0.0142, -0.1283, -0.0218], + [0.0490, -0.0500, -0.1613, 0.0193, 0.0280], + [0.0477, 0.0250, -0.2279, -0.1128, -0.0289], + [0.0597, -0.0676, -0.0955, -0.0756, 0.0234], + [0.0592, -0.0858, -0.0621, -0.1088, 0.0148], + [0.0870, -0.0715, -0.0769, -0.1414, 0.0365], + [0.1015, -0.0552, -0.0667, -0.1637, 0.0492], + [0.1135, -0.0469, -0.0573, -0.1702, 0.0535], + [0.1226, -0.0430, -0.0508, -0.1729, 0.0553], + [0.1287, -0.0387, -0.0425, -0.1757, 0.0567], + [0.1360, -0.0337, -0.0327, -0.1790, 0.0570], + [0.1434, -0.0267, -0.0242, -0.1831, 0.0569], + [0.1528, -0.0091, -0.0184, -0.1881, 0.0570], + [0.1547, 0.0185, -0.0231, -0.1803, 0.0538], + [0.1396, 0.0533, -0.0349, -0.1637, 0.0429], + [0.1074, 0.0851, -0.0418, -0.1461, 0.0231], + [0.0719, 0.1061, -0.0440, -0.1291, -0.0003], + [0.0456, 0.1146, -0.0457, -0.1118, -0.0192], + [0.0347, 0.1132, -0.0493, -0.0955, -0.0341], + [0.0357, 0.1074, -0.0491, -0.0821, -0.0449], + [0.0421, 0.1036, -0.0461, -0.0763, -0.0488], + [0.0479, 0.1019, -0.0434, -0.0721, -0.0483], + [0.0470, 0.0988, -0.0423, -0.0654, -0.0440], + [0.0439, 0.0947, -0.0418, -0.0591, -0.0349], + [0.0397, 0.0898, -0.0415, -0.0555, -0.0206], + [0.0434, 0.0815, -0.0411, -0.0543, 0.0057], + [0.0512, 0.0629, -0.0442, -0.0547, 0.0378], + [0.0584, 0.0483, -0.0528, -0.0607, 0.0568], + [0.0568, 0.0456, -0.0674, -0.0699, 0.0768], + [0.0205, -0.0859, -0.0385, -0.1231, -0.0331], + ] + ), } docs = ["Hello World"] From f74ce95676addf73f1d39b93bb27690c3f7574a1 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 14 Oct 2024 17:24:36 +0300 Subject: [PATCH 2/6] chore: Generalized query marker and document marker --- fastembed/common/preprocessor_utils.py | 12 ++++++++ fastembed/late_interaction/colbert.py | 40 +++++++++++++++++++------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index db2432f3..7dc55e29 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -68,6 +68,18 @@ def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]: token_str = token.get("content", "") special_token_to_id[token_str] = tokenizer.token_to_id(token_str) + if tokenizer_config["tokenizer_class"] == "BertTokenizer": + query_marker = {"[Q]": 1} + document_marker = {"[D]": 2} + elif tokenizer_config["tokenizer_class"] == "XLMRobertaTokenizer": + query_marker = {"[QueryMarker]": 250002} + document_marker = {"[DocumentMarker]": 250003} + else: + query_marker = {} + document_marker = {} + + special_token_to_id.update(query_marker) + special_token_to_id.update(document_marker) return tokenizer, special_token_to_id diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 60264aa4..4279be89 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -50,10 +50,10 @@ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): - QUERY_MARKER_TOKEN_ID = 1 - DOCUMENT_MARKER_TOKEN_ID = 2 MIN_QUERY_LENGTH = 32 MASK_TOKENS = ["[MASK]", ""] + QUERY_MARKER_TOKENS = ["[Q]", "[QueryMarker]"] + DOCUMENT_MARKER_TOKENS = ["[D]", "[DocumentMarker]"] def _post_process_onnx_output( self, output: OnnxOutputContext, is_doc: bool = True @@ -81,9 +81,9 @@ def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: if is_doc: - onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID + onnx_input["input_ids"][:, 1] = self.document_marker_token_id else: - onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID + onnx_input["input_ids"][:, 1] = self.query_marker_token_id return onnx_input def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: @@ -94,8 +94,9 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: ) def _tokenize_query(self, query: str) -> List[Encoding]: - # ". " is added to a query to be replaced with a special query token - query = [f". {query}"] + # "@ " is added to a query to be replaced with a special query token + # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models + query = [f"@ {query}"] encoded = self.tokenizer.encode_batch(query) # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance if len(encoded[0].ids) < self.MIN_QUERY_LENGTH: @@ -103,7 +104,7 @@ def _tokenize_query(self, query: str) -> List[Encoding]: if self.tokenizer.padding: prev_padding = self.tokenizer.padding self.tokenizer.enable_padding( - pad_token=self.MASK_TOKENS[0], + pad_token=self.mask_token, pad_id=self.mask_token_id, length=self.MIN_QUERY_LENGTH, ) @@ -115,8 +116,9 @@ def _tokenize_query(self, query: str) -> List[Encoding]: return encoded def _tokenize_documents(self, documents: List[str]) -> List[Encoding]: - # ". " is added to a document to be replaced with a special document token - documents = [". " + doc for doc in documents] + # "@ " is added to a document to be replaced with a special document token + # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models + documents = ["@ " + doc for doc in documents] encoded = self.tokenizer.encode_batch(documents) return encoded @@ -200,12 +202,28 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) - self.mask_token_id = next( + self.mask_token_id, self.mask_token = next( ( - self.special_token_to_id[token] + (self.special_token_to_id[token], token) for token in self.MASK_TOKENS if token in self.special_token_to_id ), + (None, None), + ) + self.query_marker_token_id = next( + ( + self.special_token_to_id[token] + for token in self.QUERY_MARKER_TOKENS + if token in self.special_token_to_id + ), + None, + ) + self.document_marker_token_id = next( + ( + self.special_token_to_id[token] + for token in self.DOCUMENT_MARKER_TOKENS + if token in self.special_token_to_id + ), None, ) self.pad_token_id = self.tokenizer.padding["pad_id"] From f0a6f2097008222a271b7d6ba04a06d8801a3117 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 14 Oct 2024 17:33:38 +0300 Subject: [PATCH 3/6] nit: remove github action on dispatch --- .github/workflows/python-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 72993f35..c42ac41a 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -2,8 +2,8 @@ name: Tests on: push: - branches: [ master, main, gpu, jina-colbert-v2 ] - workflow_dispatch: + branches: [ master, main, gpu ] + pull_request: env: CARGO_TERM_COLOR: always From b59dc47c5833bbd9e88420d452762375aa85f893 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 28 Oct 2024 09:00:39 +0300 Subject: [PATCH 4/6] chore: updated license --- fastembed/late_interaction/colbert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 4279be89..32f62b77 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -38,7 +38,8 @@ { "model": "jinaai/jina-colbert-v2", "dim": 1024, - "description": "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, 2024 year", + "description": "New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year", + "license": "cc-by-nc-4.0", "size_in_GB": 2.24, "sources": { "hf": "jinaai/jina-colbert-v2", From 4e90914b4c2e84067b1804a4eb88b8ceb44e7804 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 28 Oct 2024 18:05:03 +0300 Subject: [PATCH 5/6] fix: Fix attention mask to be all 1 in xlmrobertatokenizer --- fastembed/common/preprocessor_utils.py | 4 ++-- fastembed/late_interaction/colbert.py | 3 +++ fastembed/rerank/cross_encoder/onnx_text_model.py | 2 +- fastembed/text/onnx_text_model.py | 4 +++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index 7dc55e29..9f5ba5e6 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -17,7 +17,7 @@ def load_special_tokens(model_dir: Path) -> dict: return tokens_map -def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]: +def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict, str]: config_path = model_dir / "config.json" if not config_path.exists(): raise ValueError(f"Could not find config.json in {model_dir}") @@ -80,7 +80,7 @@ def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]: special_token_to_id.update(query_marker) special_token_to_id.update(document_marker) - return tokenizer, special_token_to_id + return tokenizer, special_token_to_id, tokenizer_config["tokenizer_class"] def load_preprocessor(model_dir: Path) -> Compose: diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 32f62b77..e830a704 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -85,6 +85,9 @@ def _preprocess_onnx_input( onnx_input["input_ids"][:, 1] = self.document_marker_token_id else: onnx_input["input_ids"][:, 1] = self.query_marker_token_id + + if self.tokenizer_class == "XLMRobertaTokenizer": + onnx_input["attention_mask"][:] = 1 return onnx_input def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 85f9420c..2ba1e299 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -29,7 +29,7 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) - self.tokenizer, _ = load_tokenizer(model_dir=model_dir) + self.tokenizer, _, _ = load_tokenizer(model_dir=model_dir) def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]: return self.tokenizer.encode_batch([(query, doc) for doc in documents]) diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 36bfcb08..3939c5a4 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -53,7 +53,9 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) - self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + self.tokenizer, self.special_token_to_id, self.tokenizer_class = load_tokenizer( + model_dir=model_dir + ) def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") From 7a2f04d22c921e09266f591ff902f316d6579d4f Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 30 Oct 2024 06:38:05 +0300 Subject: [PATCH 6/6] feat: Added class for JinaColbertV2 --- fastembed/common/preprocessor_utils.py | 16 +- fastembed/late_interaction/colbert.py | 66 +---- fastembed/late_interaction/jina_colbert_v2.py | 252 ++++++++++++++++++ .../rerank/cross_encoder/onnx_text_model.py | 2 +- fastembed/text/onnx_text_model.py | 4 +- 5 files changed, 269 insertions(+), 71 deletions(-) create mode 100644 fastembed/late_interaction/jina_colbert_v2.py diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index 9f5ba5e6..db2432f3 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -17,7 +17,7 @@ def load_special_tokens(model_dir: Path) -> dict: return tokens_map -def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict, str]: +def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]: config_path = model_dir / "config.json" if not config_path.exists(): raise ValueError(f"Could not find config.json in {model_dir}") @@ -68,19 +68,7 @@ def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict, str]: token_str = token.get("content", "") special_token_to_id[token_str] = tokenizer.token_to_id(token_str) - if tokenizer_config["tokenizer_class"] == "BertTokenizer": - query_marker = {"[Q]": 1} - document_marker = {"[D]": 2} - elif tokenizer_config["tokenizer_class"] == "XLMRobertaTokenizer": - query_marker = {"[QueryMarker]": 250002} - document_marker = {"[DocumentMarker]": 250003} - else: - query_marker = {} - document_marker = {} - - special_token_to_id.update(query_marker) - special_token_to_id.update(document_marker) - return tokenizer, special_token_to_id, tokenizer_config["tokenizer_class"] + return tokenizer, special_token_to_id def load_preprocessor(model_dir: Path) -> Compose: diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index e830a704..d8e29e11 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -35,26 +35,14 @@ }, "model_file": "vespa_colbert.onnx", }, - { - "model": "jinaai/jina-colbert-v2", - "dim": 1024, - "description": "New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year", - "license": "cc-by-nc-4.0", - "size_in_GB": 2.24, - "sources": { - "hf": "jinaai/jina-colbert-v2", - }, - "model_file": "onnx/model.onnx", - "additional_files": ["onnx/model.onnx_data"], - }, ] class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): + QUERY_MARKER_TOKEN_ID = 1 + DOCUMENT_MARKER_TOKEN_ID = 2 MIN_QUERY_LENGTH = 32 - MASK_TOKENS = ["[MASK]", ""] - QUERY_MARKER_TOKENS = ["[Q]", "[QueryMarker]"] - DOCUMENT_MARKER_TOKENS = ["[D]", "[DocumentMarker]"] + MASK_TOKEN = "[MASK]" def _post_process_onnx_output( self, output: OnnxOutputContext, is_doc: bool = True @@ -82,12 +70,9 @@ def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: if is_doc: - onnx_input["input_ids"][:, 1] = self.document_marker_token_id + onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID else: - onnx_input["input_ids"][:, 1] = self.query_marker_token_id - - if self.tokenizer_class == "XLMRobertaTokenizer": - onnx_input["attention_mask"][:] = 1 + onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID return onnx_input def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: @@ -98,21 +83,20 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: ) def _tokenize_query(self, query: str) -> List[Encoding]: - # "@ " is added to a query to be replaced with a special query token - # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models - query = [f"@ {query}"] - encoded = self.tokenizer.encode_batch(query) + # ". " is added to a query to be replaced with a special query token + query = f". {query}" + encoded = self.tokenizer.encode_batch([query]) # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance if len(encoded[0].ids) < self.MIN_QUERY_LENGTH: prev_padding = None if self.tokenizer.padding: prev_padding = self.tokenizer.padding self.tokenizer.enable_padding( - pad_token=self.mask_token, + pad_token=self.MASK_TOKEN, pad_id=self.mask_token_id, length=self.MIN_QUERY_LENGTH, ) - encoded = self.tokenizer.encode_batch(query) + encoded = self.tokenizer.encode_batch([query]) if prev_padding is None: self.tokenizer.no_padding() else: @@ -120,9 +104,8 @@ def _tokenize_query(self, query: str) -> List[Encoding]: return encoded def _tokenize_documents(self, documents: List[str]) -> List[Encoding]: - # "@ " is added to a document to be replaced with a special document token - # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models - documents = ["@ " + doc for doc in documents] + # ". " is added to a document to be replaced with a special document token + documents = [". " + doc for doc in documents] encoded = self.tokenizer.encode_batch(documents) return encoded @@ -206,30 +189,7 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) - self.mask_token_id, self.mask_token = next( - ( - (self.special_token_to_id[token], token) - for token in self.MASK_TOKENS - if token in self.special_token_to_id - ), - (None, None), - ) - self.query_marker_token_id = next( - ( - self.special_token_to_id[token] - for token in self.QUERY_MARKER_TOKENS - if token in self.special_token_to_id - ), - None, - ) - self.document_marker_token_id = next( - ( - self.special_token_to_id[token] - for token in self.DOCUMENT_MARKER_TOKENS - if token in self.special_token_to_id - ), - None, - ) + self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN] self.pad_token_id = self.tokenizer.padding["pad_id"] self.skip_list = { self.tokenizer.encode(symbol, add_special_tokens=False).ids[0] diff --git a/fastembed/late_interaction/jina_colbert_v2.py b/fastembed/late_interaction/jina_colbert_v2.py new file mode 100644 index 00000000..ec0f18dd --- /dev/null +++ b/fastembed/late_interaction/jina_colbert_v2.py @@ -0,0 +1,252 @@ +import string +from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union + +import numpy as np +from tokenizers import Encoding + +from fastembed.common import OnnxProvider +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import define_cache_dir +from fastembed.late_interaction.late_interaction_embedding_base import ( + LateInteractionTextEmbeddingBase, +) +from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker + +supported_colbert_models = [ + { + "model": "jinaai/jina-colbert-v2", + "dim": 1024, + "description": "New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year", + "license": "cc-by-nc-4.0", + "size_in_GB": 2.24, + "sources": { + "hf": "jinaai/jina-colbert-v2", + }, + "model_file": "onnx/model.onnx", + "additional_files": ["onnx/model.onnx_data"], + }, +] + + +class JinaColbertV2(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): + QUERY_MARKER_TOKEN_ID = 250002 + DOCUMENT_MARKER_TOKEN_ID = 250003 + MIN_QUERY_LENGTH = 32 + MASK_TOKEN = "" + + def _post_process_onnx_output( + self, output: OnnxOutputContext, is_doc: bool = True + ) -> Iterable[np.ndarray]: + if not is_doc: + return output.model_output.astype(np.float32) + + if output.input_ids is None or output.attention_mask is None: + raise ValueError( + "input_ids and attention_mask must be provided for document post-processing" + ) + + for i, token_sequence in enumerate(output.input_ids): + for j, token_id in enumerate(token_sequence): + if token_id in self.skip_list or token_id == self.pad_token_id: + output.attention_mask[i, j] = 0 + + output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32) + norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) + norm_clamped = np.maximum(norm, 1e-12) + output.model_output /= norm_clamped + return output.model_output.astype(np.float32) + + def _preprocess_onnx_input( + self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True + ) -> Dict[str, np.ndarray]: + if is_doc: + onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID + else: + onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID + + # the attention mask for jina-colbert-v2 is always 1 + onnx_input["attention_mask"][:] = 1 + return onnx_input + + def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: + return ( + self._tokenize_documents(documents=documents) + if is_doc + else self._tokenize_query(query=next(iter(documents))) + ) + + def _tokenize_query(self, query: str) -> List[Encoding]: + # "@ " is added to a query to be replaced with a special query token + # "@ " is considered as one token in jina-colbert-v2 tokenizer + query = f"@ {query}" + encoded = self.tokenizer.encode_batch([query]) + # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance + if len(encoded[0].ids) < self.MIN_QUERY_LENGTH: + prev_padding = None + if self.tokenizer.padding: + prev_padding = self.tokenizer.padding + self.tokenizer.enable_padding( + pad_token=self.MASK_TOKEN, + pad_id=self.mask_token_id, + length=self.MIN_QUERY_LENGTH, + ) + encoded = self.tokenizer.encode_batch([query]) + if prev_padding is None: + self.tokenizer.no_padding() + else: + self.tokenizer.enable_padding(**prev_padding) + return encoded + + def _tokenize_documents(self, documents: List[str]) -> List[Encoding]: + # "@ " is added to a document to be replaced with a special document token + # "@ " is considered as one token in jina-colbert-v2 tokenizer + documents = ["@ " + doc for doc in documents] + encoded = self.tokenizer.encode_batch(documents) + return encoded + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_colbert_models + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[List[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + **kwargs, + ): + """ + Args: + model_name (str): The name of the model to use. + cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. + providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use. + Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None. + cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers` + Defaults to False. + device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in + workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None. + lazy_load (bool, optional): Whether to load the model during class initialization or on demand. + Should be set to True when using multiple-gpu and parallel encoding. Defaults to False. + device_id (Optional[int], optional): The device id to use for loading the model in the worker process. + + Raises: + ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. + """ + + super().__init__(model_name, cache_dir, threads, **kwargs) + self.providers = providers + self.lazy_load = lazy_load + + # List of device ids, that can be used for data parallel processing in workers + self.device_ids = device_ids + self.cuda = cuda + + # This device_id will be used if we need to load model in current process + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + else: + self.device_id = None + + self.model_description = self._get_model_description(model_name) + self.cache_dir = define_cache_dir(cache_dir) + + self._model_dir = self.download_model( + self.model_description, self.cache_dir, local_files_only=self._local_files_only + ) + self.mask_token_id = None + self.pad_token_id = None + self.skip_list = set() + + if not self.lazy_load: + self.load_onnx_model() + + def load_onnx_model(self) -> None: + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description["model_file"], + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + ) + self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN] + self.pad_token_id = self.tokenizer.padding["pad_id"] + self.skip_list = { + self.tokenizer.encode(symbol, add_special_tokens=False).ids[0] + for symbol in string.punctuation + } + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=documents, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]: + if isinstance(query, str): + query = [query] + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for text in query: + yield from self._post_process_onnx_output( + self.onnx_embed([text], is_doc=False), is_doc=False + ) + + @classmethod + def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: + return ColbertEmbeddingWorker + + +class ColbertEmbeddingWorker(TextEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> JinaColbertV2: + return JinaColbertV2( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 2ba1e299..85f9420c 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -29,7 +29,7 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) - self.tokenizer, _, _ = load_tokenizer(model_dir=model_dir) + self.tokenizer, _ = load_tokenizer(model_dir=model_dir) def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]: return self.tokenizer.encode_batch([(query, doc) for doc in documents]) diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 3939c5a4..36bfcb08 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -53,9 +53,7 @@ def _load_onnx_model( cuda=cuda, device_id=device_id, ) - self.tokenizer, self.special_token_to_id, self.tokenizer_class = load_tokenizer( - model_dir=model_dir - ) + self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method")