diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 859416ec..5d26e55a 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -1,17 +1,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import ( - Any, - Dict, - Generic, - Iterable, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +from typing import Any, Dict, Generic, Iterable, Optional, Sequence, Tuple, Type, TypeVar import numpy as np import onnxruntime as ort diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 1a2fbcd2..52fb79e7 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -68,14 +68,14 @@ def _post_process_onnx_output( return output.model_output.astype(np.float32) def _preprocess_onnx_input( - self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True + self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any ) -> Dict[str, np.ndarray]: marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1) onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1) return onnx_input - def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: + def tokenize(self, documents: List[str], is_doc: bool = True, **kwargs: Any) -> List[Encoding]: return ( self._tokenize_documents(documents=documents) if is_doc @@ -226,7 +226,7 @@ def embed( **kwargs, ) - def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]: + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: if isinstance(query, str): query = [query] diff --git a/fastembed/late_interaction/jina_colbert.py b/fastembed/late_interaction/jina_colbert.py index 9f7c4b32..4d62ca8f 100644 --- a/fastembed/late_interaction/jina_colbert.py +++ b/fastembed/late_interaction/jina_colbert.py @@ -42,7 +42,7 @@ def list_supported_models(cls) -> List[Dict[str, Any]]: return supported_jina_colbert_models def _preprocess_onnx_input( - self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True + self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any ) -> Dict[str, np.ndarray]: onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc) diff --git a/fastembed/parallel_processor.py b/fastembed/parallel_processor.py index 386b0d3e..93245ae8 100644 --- a/fastembed/parallel_processor.py +++ b/fastembed/parallel_processor.py @@ -24,7 +24,7 @@ class QueueSignals(str, Enum): class Worker: @classmethod - def start(cls, **kwargs: Any) -> "Worker": + def start(cls, *args: Any, **kwargs: Any) -> "Worker": raise NotImplementedError() def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]: diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 85f9420c..b1a30e2e 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -4,7 +4,7 @@ import numpy as np from tokenizers import Encoding -from fastembed.common.onnx_model import OnnxModel, OnnxProvider +from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext from fastembed.common.preprocessor_utils import load_tokenizer from fastembed.common.utils import iter_batch @@ -34,7 +34,7 @@ def _load_onnx_model( def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]: return self.tokenizer.encode_batch([(query, doc) for doc in documents]) - def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]: + def onnx_embed(self, query: str, documents: List[str], **kwargs) -> OnnxOutputContext: tokenized_input = self.tokenize(query, documents, **kwargs) inputs = { @@ -51,7 +51,7 @@ def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]: onnx_input = self._preprocess_onnx_input(inputs, **kwargs) outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) - return outputs[0][:, 0].tolist() + return OnnxOutputContext(model_output=outputs[0][:, 0].tolist()) def _rerank_documents( self, query: str, documents: Iterable[str], batch_size: int, **kwargs @@ -59,7 +59,7 @@ def _rerank_documents( if not hasattr(self, "model") or self.model is None: self.load_onnx_model() for batch in iter_batch(documents, batch_size): - yield from self.onnx_embed(query, batch, **kwargs) + yield from self.onnx_embed(query, batch, **kwargs).model_output def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], **kwargs diff --git a/fastembed/sparse/utils/tokenizer.py b/fastembed/sparse/utils/tokenizer.py index 88b6c059..22045759 100644 --- a/fastembed/sparse/utils/tokenizer.py +++ b/fastembed/sparse/utils/tokenizer.py @@ -5,6 +5,7 @@ class SimpleTokenizer: + @staticmethod def tokenize(text: str) -> List[str]: text = re.sub(r"[^\w]", " ", text.lower()) text = re.sub(r"\s+", " ", text)