Skip to content

Commit

Permalink
wip: start reviewing (#420)
Browse files Browse the repository at this point in the history
Co-authored-by: Dmitrii Ogn <[email protected]>
  • Loading branch information
joein and I8dNLo authored Dec 11, 2024
1 parent 13acf23 commit 42b8f4e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
12 changes: 11 additions & 1 deletion fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from typing import Any, Iterable, Optional, Sequence, Type
from typing import Iterable, Any, Sequence, Optional, Type


import numpy as np

from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_model import (
OnnxCrossEncoderModel,
TextRerankerWorker,
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.utils import define_cache_dir
from loguru import logger

from fastembed.common import OnnxProvider
Expand All @@ -12,6 +21,7 @@
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase


supported_onnx_models = [
{
"model": "Xenova/ms-marco-MiniLM-L-6-v2",
Expand Down
31 changes: 14 additions & 17 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence
from typing import Sequence, Optional, Iterable, Any, Type

import numpy as np
from tokenizers import Encoding
Expand All @@ -18,9 +18,13 @@
from fastembed.parallel_processor import ParallelWorkerPool


class OnnxCrossEncoderModel(OnnxModel):
class OnnxCrossEncoderModel(OnnxModel[float]):
ONNX_OUTPUT_NAMES: Optional[list[str]] = None

@classmethod
def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
    """Return the worker class used for parallel (multi-process) reranking.

    Raises:
        NotImplementedError: always; concrete subclasses must override this
            to supply their own ``TextRerankerWorker`` implementation.
    """
    raise NotImplementedError("Subclasses must implement this method")

def _load_onnx_model(
self,
model_dir: Path,
Expand All @@ -40,10 +44,8 @@ def _load_onnx_model(
)
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)

def tokenize(
self, pairs: Iterable[tuple[str, str]], **kwargs: Any
) -> list[Encoding]:
return self.tokenizer.encode_batch([pair for pair in pairs], **kwargs)
def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
    """Tokenize (query, document) pairs with the model's tokenizer.

    Args:
        pairs: List of ``(query, document)`` string pairs.
        **kwargs: Forwarded verbatim to ``Tokenizer.encode_batch``.

    Returns:
        One ``Encoding`` per input pair, in input order.
    """
    encoded = self.tokenizer.encode_batch(pairs, **kwargs)
    return encoded

def _build_onnx_input(self, tokenized_input):
inputs = {
Expand All @@ -59,10 +61,8 @@ def _build_onnx_input(self, tokenized_input):
)
return inputs

def onnx_embed(
self, query: str, documents: list[str], **kwargs: Any
) -> OnnxOutputContext:
pairs = ((query, doc) for doc in documents)
def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
    """Score *documents* against a single *query*.

    Pairs the query with every document and delegates the actual inference
    to ``onnx_embed_pairs``; keyword arguments are forwarded unchanged.
    """
    query_doc_pairs = []
    for doc in documents:
        query_doc_pairs.append((query, doc))
    return self.onnx_embed_pairs(query_doc_pairs, **kwargs)

def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
Expand All @@ -72,15 +72,15 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
relevant_output = outputs[0]
scores = relevant_output[:, 0]
return OnnxOutputContext(model_output=scores.tolist())
return OnnxOutputContext(model_output=scores)

def _rerank_documents(
    self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
) -> Iterable[float]:
    """Lazily yield one relevance score per document for *query*.

    Loads the ONNX model on first use, batches *documents*, runs each batch
    through ``onnx_embed`` and post-processes the raw output into floats.

    Args:
        query: The query to score the documents against.
        documents: Documents to rerank; consumed lazily, one batch at a time.
        batch_size: Number of documents per inference batch.
        **kwargs: Forwarded to ``onnx_embed``.

    Yields:
        A float relevance score per input document, in input order.
    """
    # Lazy initialization: the model may not have been loaded yet.
    if not hasattr(self, "model") or self.model is None:
        self.load_onnx_model()
    for batch in iter_batch(documents, batch_size):
        # NOTE: the diff residue kept both the old direct `.model_output`
        # yield and its replacement; only the post-processed path is kept,
        # so each document is scored exactly once.
        yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))

def _rerank_pairs(
self,
Expand All @@ -96,9 +96,6 @@ def _rerank_pairs(
) -> Iterable[float]:
is_small = False

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

if isinstance(pairs, tuple):
pairs = [pairs]
is_small = True
Expand All @@ -111,7 +108,7 @@ def _rerank_pairs(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(pairs, batch_size):
yield from self.onnx_embed_pairs(batch, **kwargs).model_output
yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
else:
if parallel == 0:
parallel = os.cpu_count()
Expand All @@ -138,7 +135,7 @@ def _rerank_pairs(
self.onnx_embed_pairs(batch, **kwargs)
)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
    """Convert a raw ONNX output context into an iterable of float scores.

    For this cross-encoder the per-pair scores are stored directly in
    ``model_output`` (see ``onnx_embed_pairs``), so no further
    transformation is needed here; subclasses may override.
    """
    return output.model_output

def _preprocess_onnx_input(
Expand Down
4 changes: 2 additions & 2 deletions fastembed/rerank/cross_encoder/text_cross_encoder_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Iterable, Optional, Any

from fastembed.common.model_management import ModelManagement

Expand All @@ -23,7 +23,7 @@ def rerank(
batch_size: int = 64,
**kwargs,
) -> Iterable[float]:
"""Reranks a list of documents given a query.
"""Rerank a list of documents given a query.
Args:
query (str): The query to rerank the documents.
Expand Down

0 comments on commit 42b8f4e

Please sign in to comment.