Skip to content

Commit

Permalink
Test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
I8dNLo committed Dec 11, 2024
1 parent 42b8f4e commit fdb4681
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 30 deletions.
27 changes: 11 additions & 16 deletions fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,15 @@
from typing import Iterable, Any, Sequence, Optional, Type

from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np

from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_model import (
OnnxCrossEncoderModel,
TextRerankerWorker,
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.utils import define_cache_dir
from loguru import logger

from fastembed.common import OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.rerank.cross_encoder.onnx_text_model import (
OnnxCrossEncoderModel,
TextRerankerWorker,
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase

OnnxCrossEncoderModel, TextRerankerWorker)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import \
TextCrossEncoderBase

supported_onnx_models = [
{
Expand Down Expand Up @@ -199,7 +188,13 @@ def rerank_pairs(
batch_size: int = 64,
**kwargs: Any,
) -> Iterable[float]:
yield from self._rerank_pairs(pairs=pairs, batch_size=batch_size, **kwargs)
yield from self._rerank_pairs(
model_name=self._model_dir,
cache_dir=self.cache_dir,
pairs=pairs,
batch_size=batch_size,
**kwargs,
)

@classmethod
def _get_worker_class(cls) -> Type[TextRerankerWorker]:
Expand Down
23 changes: 12 additions & 11 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Sequence, Optional, Iterable, Any, Type
from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np
from tokenizers import Encoding

from fastembed.common.onnx_model import (
EmbeddingWorker,
OnnxModel,
OnnxOutputContext,
OnnxProvider,
T,
)
from fastembed.common.onnx_model import (EmbeddingWorker, OnnxModel,
OnnxOutputContext, OnnxProvider, T)
from fastembed.common.preprocessor_utils import load_tokenizer
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool
Expand Down Expand Up @@ -61,7 +56,9 @@ def _build_onnx_input(self, tokenized_input):
)
return inputs

def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
def onnx_embed(
self, query: str, documents: list[str], **kwargs: Any
) -> OnnxOutputContext:
pairs = [(query, doc) for doc in documents]
return self.onnx_embed_pairs(pairs, **kwargs)

Expand All @@ -80,7 +77,9 @@ def _rerank_documents(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
yield from self._post_process_onnx_output(
self.onnx_embed(query, batch, **kwargs)
)

def _rerank_pairs(
self,
Expand Down Expand Up @@ -108,7 +107,9 @@ def _rerank_pairs(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(pairs, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
yield from self._post_process_onnx_output(
self.onnx_embed_pairs(batch, **kwargs)
)
else:
if parallel == 0:
parallel = os.cpu_count()
Expand Down
6 changes: 4 additions & 2 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Iterable, Optional, Sequence, Type

from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import \
OnnxTextCrossEncoder
from fastembed.rerank.cross_encoder.text_cross_encoder_base import \
TextCrossEncoderBase


class TextCrossEncoder(TextCrossEncoderBase):
Expand Down
2 changes: 1 addition & 1 deletion fastembed/rerank/cross_encoder/text_cross_encoder_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Any
from typing import Any, Iterable, Optional

from fastembed.common.model_management import ModelManagement

Expand Down

0 comments on commit fdb4681

Please sign in to comment.