Commit
Fix type hint (#392)
* fix: Fix return of onnx_embed

* fix: Fix type hint of start method in worker class

* fix: Fix not passing kwargs in _preprocess_onnx_input and tokenize as base class

* fix: Fix not passing kwargs in _preprocess_onnx_input as base class

* fix: Change tokenize in SimpleTokenizer to a staticmethod

* chore: Changed query argument to Iterable to match base class

* chore: Changed mask token id and pad token id to int

* review suggestions (#398)

---------

Co-authored-by: George <[email protected]>
hh-space-invader and joein authored Nov 13, 2024
1 parent 7c93571 commit d141e29
Showing 6 changed files with 11 additions and 20 deletions.
12 changes: 1 addition & 11 deletions fastembed/common/onnx_model.py
@@ -1,17 +1,7 @@
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (
-    Any,
-    Dict,
-    Generic,
-    Iterable,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    TypeVar,
-)
+from typing import Any, Dict, Generic, Iterable, Optional, Sequence, Tuple, Type, TypeVar
 
 import numpy as np
 import onnxruntime as ort
6 changes: 3 additions & 3 deletions fastembed/late_interaction/colbert.py
@@ -68,14 +68,14 @@ def _post_process_onnx_output(
         return output.model_output.astype(np.float32)
 
     def _preprocess_onnx_input(
-        self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
+        self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any
     ) -> Dict[str, np.ndarray]:
         marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
         onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1)
         onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
         return onnx_input
 
-    def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
+    def tokenize(self, documents: List[str], is_doc: bool = True, **kwargs: Any) -> List[Encoding]:
         return (
             self._tokenize_documents(documents=documents)
             if is_doc
@@ -226,7 +226,7 @@ def embed(
             **kwargs,
         )
 
-    def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
         if isinstance(query, str):
             query = [query]
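
The point of the added **kwargs: these overrides must accept whatever keyword arguments the base pipeline forwards into its hooks. A minimal sketch of the failure mode this prevents — Base and Child are hypothetical stand-ins, not the actual fastembed classes:

from typing import Any, Dict

class Base:
    def _preprocess_onnx_input(self, onnx_input: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        return onnx_input

    def embed(self, onnx_input: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        # the pipeline forwards whatever kwargs the caller supplied
        return self._preprocess_onnx_input(onnx_input, **kwargs)

class Child(Base):
    # Before the fix this override had no **kwargs, so Child().embed({}, parallel=2)
    # would raise: TypeError: _preprocess_onnx_input() got an unexpected keyword argument
    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, Any], is_doc: bool = True, **kwargs: Any
    ) -> Dict[str, Any]:
        return onnx_input

The query_embed change is related but separate: the body only special-cases a bare str and then iterates, so Iterable[str] (matching the base class) is the accurate hint and lets callers pass generators as well as lists.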
2 changes: 1 addition & 1 deletion fastembed/late_interaction/jina_colbert.py
@@ -42,7 +42,7 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         return supported_jina_colbert_models
 
     def _preprocess_onnx_input(
-        self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
+        self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any
     ) -> Dict[str, np.ndarray]:
         onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)
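
The same signature rule applies one level down: JinaColBERT layers its own preprocessing on top of Colbert's through super(). A sketch of that cooperative-override shape, with simplified stand-in bodies rather than the real ones:

from typing import Any, Dict

class Colbert:
    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, Any], is_doc: bool = True, **kwargs: Any
    ) -> Dict[str, Any]:
        onnx_input["marker"] = "doc" if is_doc else "query"  # stand-in for marker-token insertion
        return onnx_input

class JinaColBERT(Colbert):
    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, Any], is_doc: bool = True, **kwargs: Any
    ) -> Dict[str, Any]:
        # delegate the shared work, then apply model-specific tweaks
        onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)
        return onnx_input

Note that accepting **kwargs keeps the signature call-compatible even when an override, as here, does not forward them further.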
2 changes: 1 addition & 1 deletion fastembed/parallel_processor.py
@@ -24,7 +24,7 @@ class QueueSignals(str, Enum):
 
 class Worker:
     @classmethod
-    def start(cls, **kwargs: Any) -> "Worker":
+    def start(cls, *args: Any, **kwargs: Any) -> "Worker":
         raise NotImplementedError()
 
     def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
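
Worker.start is the factory hook each parallel worker subclass overrides; type checkers such as mypy treat a base method declared as (*args: Any, **kwargs: Any) as compatible with any concrete override signature, which is why the positional catch-all was added. A hedged sketch of a conforming subclass — EchoWorker is illustrative, not part of fastembed:

from typing import Any, Iterable, Tuple

class Worker:
    @classmethod
    def start(cls, *args: Any, **kwargs: Any) -> "Worker":
        raise NotImplementedError()

    def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
        raise NotImplementedError()

class EchoWorker(Worker):
    @classmethod
    def start(cls, *args: Any, **kwargs: Any) -> "EchoWorker":
        # construct the worker inside the spawned process
        return cls()

    def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
        for idx, item in items:  # echo items back with their ordering index intact
            yield idx, item

worker = EchoWorker.start()
print(list(worker.process([(0, "a"), (1, "b")])))  # [(0, 'a'), (1, 'b')]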
8 changes: 4 additions & 4 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -4,7 +4,7 @@
 import numpy as np
 from tokenizers import Encoding
 
-from fastembed.common.onnx_model import OnnxModel, OnnxProvider
+from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext
 from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.utils import iter_batch
 
@@ -34,7 +34,7 @@ def _load_onnx_model(
     def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]:
         return self.tokenizer.encode_batch([(query, doc) for doc in documents])
 
-    def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]:
+    def onnx_embed(self, query: str, documents: List[str], **kwargs) -> OnnxOutputContext:
         tokenized_input = self.tokenize(query, documents, **kwargs)
 
         inputs = {
@@ -51,15 +51,15 @@ def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]:
 
         onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
         outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
-        return outputs[0][:, 0].tolist()
+        return OnnxOutputContext(model_output=outputs[0][:, 0].tolist())
 
     def _rerank_documents(
         self, query: str, documents: Iterable[str], batch_size: int, **kwargs
     ) -> Iterable[float]:
         if not hasattr(self, "model") or self.model is None:
             self.load_onnx_model()
         for batch in iter_batch(documents, batch_size):
-            yield from self.onnx_embed(query, batch, **kwargs)
+            yield from self.onnx_embed(query, batch, **kwargs).model_output
 
     def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], **kwargs
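
The return-type fix aligns the cross-encoder with the other ONNX models: onnx_embed now returns an OnnxOutputContext, and callers unwrap .model_output, as _rerank_documents does above. A runnable sketch of that contract; the single-field dataclass below is a simplification of the real OnnxOutputContext in fastembed.common.onnx_model:

from dataclasses import dataclass
from typing import Any, Iterable, List

@dataclass
class OnnxOutputContext:  # simplified stand-in for the fastembed dataclass
    model_output: Any

def onnx_embed(scores: List[float]) -> OnnxOutputContext:
    # wrap raw ONNX scores so every model returns the same envelope
    return OnnxOutputContext(model_output=scores)

def rerank(batches: Iterable[List[float]]) -> Iterable[float]:
    for batch in batches:
        # callers unwrap .model_output, mirroring _rerank_documents
        yield from onnx_embed(batch).model_output

print(list(rerank([[0.9, 0.1], [0.4]])))  # [0.9, 0.1, 0.4]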
1 change: 1 addition & 0 deletions fastembed/sparse/utils/tokenizer.py
@@ -5,6 +5,7 @@
 
 
 class SimpleTokenizer:
+    @staticmethod
     def tokenize(text: str) -> List[str]:
         text = re.sub(r"[^\w]", " ", text.lower())
         text = re.sub(r"\s+", " ", text)
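
tokenize was already written without self, so the decorator records what the code assumed: it is a static method, and with @staticmethod it also works when called on an instance. A small sketch of the difference — the classes are hypothetical and the tokenizer body is simplified, since the diff truncates the real one:

import re
from typing import List

class PlainTokenizer:
    def tokenize(text: str) -> List[str]:  # no self and no decorator, as before the fix
        return re.sub(r"[^\w]", " ", text.lower()).split()

class StaticTokenizer:
    @staticmethod
    def tokenize(text: str) -> List[str]:
        return re.sub(r"[^\w]", " ", text.lower()).split()

print(StaticTokenizer.tokenize("Hello, World!"))    # ['hello', 'world']
print(StaticTokenizer().tokenize("Hello, World!"))  # also ['hello', 'world']
print(PlainTokenizer.tokenize("Hello, World!"))     # works, but only as a plain function
# PlainTokenizer().tokenize("Hello, World!")        # TypeError: the instance is bound as `text`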
