Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Nov 11, 2024
1 parent d63262d commit c578de2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
9 changes: 2 additions & 7 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker


supported_colbert_models = [
{
"model": "colbert-ir/colbertv2.0",
Expand Down Expand Up @@ -41,7 +42,7 @@
class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
QUERY_MARKER_TOKEN_ID = 1
DOCUMENT_MARKER_TOKEN_ID = 2
MIN_QUERY_LENGTH = 32
MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning
MASK_TOKEN = "[MASK]"

def _post_process_onnx_output(
Expand Down Expand Up @@ -69,15 +70,9 @@ def _post_process_onnx_output(
def _preprocess_onnx_input(
self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
) -> Dict[str, np.ndarray]:
original_length = onnx_input["input_ids"].shape[1]
marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID

onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1)
onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)

if not is_doc:
onnx_input["input_ids"] = onnx_input["input_ids"][:, :original_length]
onnx_input["attention_mask"] = onnx_input["attention_mask"][:, :original_length]
return onnx_input

def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
Expand Down
13 changes: 4 additions & 9 deletions fastembed/late_interaction/jina_colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastembed.late_interaction.colbert import Colbert
from fastembed.text.onnx_text_model import TextEmbeddingWorker


supported_jina_colbert_models = [
{
"model": "jinaai/jina-colbert-v2",
Expand All @@ -24,7 +25,7 @@
class JinaColbert(Colbert):
QUERY_MARKER_TOKEN_ID = 250002
DOCUMENT_MARKER_TOKEN_ID = 250003
MIN_QUERY_LENGTH = 32
MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning
MASK_TOKEN = "<mask>"

@classmethod
Expand All @@ -43,16 +44,10 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
def _preprocess_onnx_input(
self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
) -> Dict[str, np.ndarray]:
original_length = onnx_input["input_ids"].shape[1]
marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID

onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1)
onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)

# the attention mask for jina-colbert-v2 is always 1 in queries
if not is_doc:
onnx_input["input_ids"] = onnx_input["input_ids"][:, :original_length]
onnx_input["attention_mask"] = onnx_input["attention_mask"][:, :original_length]
# the attention mask for jina-colbert-v2 is always 1 in queries
onnx_input["attention_mask"][:] = 1
return onnx_input

Expand Down

0 comments on commit c578de2

Please sign in to comment.