diff --git a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py index 0c104432..e1f5fa4b 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py @@ -1,26 +1,15 @@ -from typing import Iterable, Any, Sequence, Optional, Type - +from typing import Any, Iterable, Optional, Sequence, Type import numpy as np - -from fastembed.common import OnnxProvider -from fastembed.rerank.cross_encoder.onnx_text_model import ( - OnnxCrossEncoderModel, - TextRerankerWorker, -) -from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase -from fastembed.common.utils import define_cache_dir from loguru import logger from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir from fastembed.rerank.cross_encoder.onnx_text_model import ( - OnnxCrossEncoderModel, - TextRerankerWorker, -) -from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase - + OnnxCrossEncoderModel, TextRerankerWorker) +from fastembed.rerank.cross_encoder.text_cross_encoder_base import \ + TextCrossEncoderBase supported_onnx_models = [ { @@ -199,7 +188,13 @@ def rerank_pairs( batch_size: int = 64, **kwargs: Any, ) -> Iterable[float]: - yield from self._rerank_pairs(pairs=pairs, batch_size=batch_size, **kwargs) + yield from self._rerank_pairs( + model_name=self._model_dir, + cache_dir=self.cache_dir, + pairs=pairs, + batch_size=batch_size, + **kwargs, + ) @classmethod def _get_worker_class(cls) -> Type[TextRerankerWorker]: diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 0ee33b58..cf076cd0 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -1,18 +1,13 @@ import os from multiprocessing import get_all_start_methods from pathlib import Path -from typing import Sequence, Optional, Iterable, Any, Type +from typing import Any, Iterable, Optional, Sequence, Type import numpy as np from tokenizers import Encoding -from fastembed.common.onnx_model import ( - EmbeddingWorker, - OnnxModel, - OnnxOutputContext, - OnnxProvider, - T, -) +from fastembed.common.onnx_model import (EmbeddingWorker, OnnxModel, + OnnxOutputContext, OnnxProvider, T) from fastembed.common.preprocessor_utils import load_tokenizer from fastembed.common.utils import iter_batch from fastembed.parallel_processor import ParallelWorkerPool @@ -61,7 +56,9 @@ def _build_onnx_input(self, tokenized_input): ) return inputs - def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext: + def onnx_embed( + self, query: str, documents: list[str], **kwargs: Any + ) -> OnnxOutputContext: pairs = [(query, doc) for doc in documents] return self.onnx_embed_pairs(pairs, **kwargs) @@ -80,7 +77,9 @@ def _rerank_documents( if not hasattr(self, "model") or self.model is None: self.load_onnx_model() for batch in iter_batch(documents, batch_size): - yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs)) + yield from self._post_process_onnx_output( + self.onnx_embed(query, batch, **kwargs) + ) def _rerank_pairs( self, @@ -108,7 +107,9 @@ def _rerank_pairs( if not hasattr(self, "model") or self.model is None: self.load_onnx_model() for batch in iter_batch(pairs, batch_size): - yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs)) + yield from self._post_process_onnx_output( + self.onnx_embed_pairs(batch, **kwargs) + ) else: if parallel == 0: parallel = os.cpu_count() diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index b27bedc7..f77416ff 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -1,8 +1,10 @@ from typing import Any, Iterable, Optional, Sequence, Type from fastembed.common import OnnxProvider -from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder -from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase +from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import \ + OnnxTextCrossEncoder +from fastembed.rerank.cross_encoder.text_cross_encoder_base import \ + TextCrossEncoderBase class TextCrossEncoder(TextCrossEncoderBase): diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py index a34b8cd8..81e76b1c 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Any +from typing import Any, Iterable, Optional from fastembed.common.model_management import ModelManagement