
Commit

improve: Changed the way we are adding the query and document markers in colbert
hh-space-invader committed Nov 8, 2024
1 parent 5b89542 commit 2623ab8
Showing 2 changed files with 16 additions and 10 deletions.
16 changes: 8 additions & 8 deletions fastembed/late_interaction/colbert.py
@@ -70,9 +70,15 @@ def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
         if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
+            onnx_input["input_ids"] = np.insert(
+                onnx_input["input_ids"], 1, self.DOCUMENT_MARKER_TOKEN_ID, axis=1
+            )
         else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
+            onnx_input["input_ids"] = np.insert(
+                onnx_input["input_ids"], 1, self.QUERY_MARKER_TOKEN_ID, axis=1
+            )
+
+        onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
         return onnx_input
 
     def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
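To see why the marker handling changed, here is a minimal sketch (with made-up token ids and a placeholder marker id, not the library's real values) contrasting the old in-place overwrite with the new `np.insert` call:

```python
import numpy as np

# Toy batch of one sequence: [CLS] tok tok [SEP]; all ids are made up.
input_ids = np.array([[101, 7592, 2088, 102]])
attention_mask = np.ones_like(input_ids)

MARKER_TOKEN_ID = 1  # placeholder value for illustration

# Old approach: overwrite position 1. This destroys the first real token
# unless a throwaway "@ " placeholder was prepended during tokenization.
overwritten = input_ids.copy()
overwritten[:, 1] = MARKER_TOKEN_ID
print(overwritten)  # [[ 101    1 2088  102]] -- token 7592 is lost

# New approach: insert the marker at position 1, shifting tokens right.
inserted = np.insert(input_ids, 1, MARKER_TOKEN_ID, axis=1)
print(inserted)  # [[ 101    1 7592 2088  102]]

# The sequence is now one token longer, so the attention mask must grow too.
attention_mask = np.insert(attention_mask, 1, 1, axis=1)
assert inserted.shape == attention_mask.shape
```

This is also why the `"@ "` placeholder logic disappears from the tokenize helpers in the next two hunks: tokenization no longer has to reserve a slot for the marker to overwrite.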
@@ -83,9 +89,6 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
         )
 
     def _tokenize_query(self, query: str) -> List[Encoding]:
-        # "@ " is added to a query to be replaced with a special query token
-        # make sure that "@ " is considered as a single token
-        query = f"@ {query}"
         encoded = self.tokenizer.encode_batch([query])
         # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
         if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
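The `[MASK]` padding referenced in the comment above is collapsed out of this diff; the sketch below is only a hypothetical reconstruction of that query-augmentation step (the helper name and exact logic are assumptions, not the repository's code):

```python
from typing import List

def pad_query_with_masks(
    token_ids: List[int], mask_token_id: int, min_query_length: int
) -> List[int]:
    """Pad short queries with [MASK] ids up to a minimum length.

    The ColBERT authors recommend this "query augmentation": the model can
    repurpose the extra [MASK] positions to re-weigh existing query terms.
    """
    n_missing = max(0, min_query_length - len(token_ids))
    return token_ids + [mask_token_id] * n_missing

# e.g. pad_query_with_masks([101, 7592, 102], 103, 6)
#      -> [101, 7592, 102, 103, 103, 103]
```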
@@ -105,9 +108,6 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
         return encoded
 
     def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
-        # "@ " is added to a document to be replaced with a special document token
-        # make sure that "@ " is considered as a single token
-        documents = ["@ " + doc for doc in documents]
         encoded = self.tokenizer.encode_batch(documents)
         return encoded

10 changes: 8 additions & 2 deletions fastembed/late_interaction/jina_colbert.py
@@ -44,11 +44,17 @@ def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
         if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
+            onnx_input["input_ids"] = np.insert(
+                onnx_input["input_ids"], 1, self.DOCUMENT_MARKER_TOKEN_ID, axis=1
+            )
         else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
+            onnx_input["input_ids"] = np.insert(
+                onnx_input["input_ids"], 1, self.QUERY_MARKER_TOKEN_ID, axis=1
+            )
             # the attention mask for jina-colbert-v2 is always 1 in queries
             onnx_input["attention_mask"][:] = 1
 
+        onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
         return onnx_input
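For the Jina variant, the query branch keeps its mask all ones before the insert. A quick trace of the query path under the new code, again with made-up ids and a placeholder marker value:

```python
import numpy as np

# Toy query already padded with [MASK]-style ids (values are made up).
input_ids = np.array([[0, 5, 9, 7, 7]])        # last two are [MASK] padding
attention_mask = np.array([[1, 1, 1, 0, 0]])   # tokenizer zeroed the padding

QUERY_MARKER_TOKEN_ID = 2  # placeholder value for illustration

# Insert the query marker at position 1, shifting the real tokens right.
input_ids = np.insert(input_ids, 1, QUERY_MARKER_TOKEN_ID, axis=1)

# jina-colbert-v2 attends to [MASK] padding in queries, so force all ones...
attention_mask[:] = 1

# ...then extend the mask to cover the inserted marker position.
attention_mask = np.insert(attention_mask, 1, 1, axis=1)

assert input_ids.shape == attention_mask.shape == (1, 6)
```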


