From c578de22add2c5168fbbef82a3cfe0e0d6e5e355 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Mon, 11 Nov 2024 16:20:13 +0100 Subject: [PATCH] refactor --- fastembed/late_interaction/colbert.py | 9 ++------- fastembed/late_interaction/jina_colbert.py | 13 ++++--------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 1f33e3c6..1a2fbcd2 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -12,6 +12,7 @@ ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker + supported_colbert_models = [ { "model": "colbert-ir/colbertv2.0", @@ -41,7 +42,7 @@ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): QUERY_MARKER_TOKEN_ID = 1 DOCUMENT_MARKER_TOKEN_ID = 2 - MIN_QUERY_LENGTH = 32 + MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning MASK_TOKEN = "[MASK]" def _post_process_onnx_output( @@ -69,15 +70,9 @@ def _post_process_onnx_output( def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: - original_length = onnx_input["input_ids"].shape[1] marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID - onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1) onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1) - - if not is_doc: - onnx_input["input_ids"] = onnx_input["input_ids"][:, :original_length] - onnx_input["attention_mask"] = onnx_input["attention_mask"][:, :original_length] return onnx_input def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: diff --git a/fastembed/late_interaction/jina_colbert.py b/fastembed/late_interaction/jina_colbert.py index 19f03877..9f7c4b32 100644 --- a/fastembed/late_interaction/jina_colbert.py +++ b/fastembed/late_interaction/jina_colbert.py @@ -5,6 +5,7 @@ from fastembed.late_interaction.colbert import Colbert from fastembed.text.onnx_text_model import TextEmbeddingWorker + supported_jina_colbert_models = [ { "model": "jinaai/jina-colbert-v2", @@ -24,7 +25,7 @@ class JinaColbert(Colbert): QUERY_MARKER_TOKEN_ID = 250002 DOCUMENT_MARKER_TOKEN_ID = 250003 - MIN_QUERY_LENGTH = 32 + MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning MASK_TOKEN = "" @classmethod @@ -43,16 +44,10 @@ def list_supported_models(cls) -> List[Dict[str, Any]]: def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: - original_length = onnx_input["input_ids"].shape[1] - marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID - - onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1) - onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1) + onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc) + # the attention mask for jina-colbert-v2 is always 1 in queries if not is_doc: - onnx_input["input_ids"] = onnx_input["input_ids"][:, :original_length] - onnx_input["attention_mask"] = onnx_input["attention_mask"][:, :original_length] - # the attention mask for jina-colbert-v2 is always 1 in queries onnx_input["attention_mask"][:] = 1 return onnx_input