diff --git a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
index f9adc65c..0c104432 100644
--- a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
+++ b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
@@ -1,6 +1,15 @@
-from typing import Any, Iterable, Optional, Sequence, Type
+from typing import Iterable, Any, Sequence, Optional, Type
+
 import numpy as np
+
+from fastembed.common import OnnxProvider
+from fastembed.rerank.cross_encoder.onnx_text_model import (
+    OnnxCrossEncoderModel,
+    TextRerankerWorker,
+)
+from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
+from fastembed.common.utils import define_cache_dir
 from loguru import logger
 
 from fastembed.common import OnnxProvider
@@ -12,6 +21,7 @@
 )
 from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
 
+
 supported_onnx_models = [
     {
         "model": "Xenova/ms-marco-MiniLM-L-6-v2",
diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py
index 633c284d..0ee33b58 100644
--- a/fastembed/rerank/cross_encoder/onnx_text_model.py
+++ b/fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -1,7 +1,7 @@
 import os
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Iterable, Optional, Sequence
+from typing import Sequence, Optional, Iterable, Any, Type
 
 import numpy as np
 from tokenizers import Encoding
@@ -18,9 +18,13 @@
 from fastembed.parallel_processor import ParallelWorkerPool
 
 
-class OnnxCrossEncoderModel(OnnxModel):
+class OnnxCrossEncoderModel(OnnxModel[float]):
     ONNX_OUTPUT_NAMES: Optional[list[str]] = None
 
+    @classmethod
+    def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
+        raise NotImplementedError("Subclasses must implement this method")
+
     def _load_onnx_model(
         self,
         model_dir: Path,
@@ -40,10 +44,8 @@ def _load_onnx_model(
         )
         self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
 
-    def tokenize(
-        self, pairs: Iterable[tuple[str, str]], **kwargs: Any
-    ) -> list[Encoding]:
-        return self.tokenizer.encode_batch([pair for pair in pairs], **kwargs)
+    def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
+        return self.tokenizer.encode_batch(pairs, **kwargs)
 
     def _build_onnx_input(self, tokenized_input):
         inputs = {
@@ -59,10 +61,8 @@ def _build_onnx_input(self, tokenized_input):
         )
         return inputs
 
-    def onnx_embed(
-        self, query: str, documents: list[str], **kwargs: Any
-    ) -> OnnxOutputContext:
-        pairs = ((query, doc) for doc in documents)
+    def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
+        pairs = [(query, doc) for doc in documents]
         return self.onnx_embed_pairs(pairs, **kwargs)
 
     def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
@@ -72,7 +72,7 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
         outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
         relevant_output = outputs[0]
         scores = relevant_output[:, 0]
-        return OnnxOutputContext(model_output=scores.tolist())
+        return OnnxOutputContext(model_output=scores)
 
     def _rerank_documents(
         self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
@@ -80,7 +80,7 @@ def _rerank_documents(
         if not hasattr(self, "model") or self.model is None:
             self.load_onnx_model()
         for batch in iter_batch(documents, batch_size):
-            yield from self.onnx_embed(query, batch, **kwargs).model_output
+            yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
 
     def _rerank_pairs(
         self,
@@ -96,9 +96,6 @@ def _rerank_pairs(
     ) -> Iterable[float]:
         is_small = False
 
-        if not hasattr(self, "model") or self.model is None:
-            self.load_onnx_model()
-
         if isinstance(pairs, tuple):
             pairs = [pairs]
             is_small = True
@@ -111,7 +108,7 @@ def _rerank_pairs(
             if not hasattr(self, "model") or self.model is None:
                 self.load_onnx_model()
             for batch in iter_batch(pairs, batch_size):
-                yield from self.onnx_embed_pairs(batch, **kwargs).model_output
+                yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
         else:
             if parallel == 0:
                 parallel = os.cpu_count()
@@ -138,7 +135,7 @@ def _rerank_pairs(
                 self.onnx_embed_pairs(batch, **kwargs)
             )
 
-    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
         return output.model_output
 
     def _preprocess_onnx_input(
diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
index 9bf09e46..a34b8cd8 100644
--- a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
+++ b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterable, Optional
+from typing import Iterable, Optional, Any
 
 from fastembed.common.model_management import ModelManagement
 
@@ -23,7 +23,7 @@ def rerank(
         batch_size: int = 64,
         **kwargs,
     ) -> Iterable[float]:
-        """Reranks a list of documents given a query.
+        """Rerank a list of documents given a query.
 
         Args:
             query (str): The query to rerank the documents.
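Reviewer note: below is a minimal sketch of how the two hooks introduced by this patch fit together once a concrete model wires them up. It is an illustration, not code from this patch: MiniLMCrossEncoder is a hypothetical subclass, returning the base TextRerankerWorker merely keeps the sketch self-contained, and the OnnxOutputContext import path is assumed from the surrounding code.

# Hypothetical subclass; only _get_worker_class and _post_process_onnx_output
# are hooks defined by this patch, the "MiniLM*" names are illustrative.
from typing import Iterable, Type

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.rerank.cross_encoder.onnx_text_model import (
    OnnxCrossEncoderModel,
    TextRerankerWorker,
)


class MiniLMCrossEncoder(OnnxCrossEncoderModel):
    @classmethod
    def _get_worker_class(cls) -> Type[TextRerankerWorker]:
        # The base class now raises NotImplementedError here, so each concrete
        # model must advertise the worker that ParallelWorkerPool should spawn.
        return TextRerankerWorker

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
        # model_output stays a raw numpy score array (the .tolist() call was
        # dropped from onnx_embed_pairs); converting to float here covers both
        # the serial and parallel paths, which now both route through this hook.
        return (float(score) for score in output.model_output)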