nit: fix post process, update docstring, update tokenize, remove redundant imports
joein committed Dec 16, 2024
1 parent 5dfec6d commit b5d2abc
Showing 4 changed files with 19 additions and 41 deletions.
5 changes: 1 addition & 4 deletions fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type
 
-import numpy as np
 from loguru import logger
 
 from fastembed.common import OnnxProvider
@@ -206,9 +205,7 @@ def rerank_pairs(
     def _get_worker_class(cls) -> Type[TextRerankerWorker]:
         return TextCrossEncoderWorker
 
-    def _post_process_onnx_output(
-        self, output: OnnxOutputContext
-    ) -> Iterable[float]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
         return (float(elem) for elem in output.model_output)
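
Note that the reformat does not change behavior: the method returns a generator, so raw logits are converted to Python floats lazily and the result can be consumed only once. A standalone sketch of that behavior, with a plain list standing in for output.model_output:

    model_output = [-1.24, -10.6]  # stand-in for output.model_output
    scores = (float(elem) for elem in model_output)
    print(list(scores))  # [-1.24, -10.6]
    print(list(scores))  # [] -- the generator is exhausted after one pass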


25 changes: 8 additions & 17 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -11,7 +11,6 @@
     OnnxModel,
     OnnxOutputContext,
     OnnxProvider,
-    T,
 )
 from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.utils import iter_batch
@@ -44,8 +43,8 @@ def _load_onnx_model(
         )
         self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
 
-    def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
-        return self.tokenizer.encode_batch(pairs, **kwargs)
+    def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
+        return self.tokenizer.encode_batch(pairs)
 
     def _build_onnx_input(self, tokenized_input):
         input_names = {node.name for node in self.model.get_inputs()}
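
The updated tokenize accepts **_ purely for interface compatibility and no longer forwards keyword arguments to encode_batch. A sketch of pair encoding with the tokenizers library, assuming the model repository ships a tokenizer.json loadable via Tokenizer.from_pretrained:

    from tokenizers import Tokenizer

    # Assumption: this HF repo provides a tokenizer.json the library can load.
    tokenizer = Tokenizer.from_pretrained("Xenova/ms-marco-MiniLM-L-6-v2")
    # encode_batch treats each (query, document) tuple as one sequence pair.
    encodings = tokenizer.encode_batch([("What is AI?", "Artificial intelligence is ...")])
    print(encodings[0].tokens[:6])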
@@ -57,14 +56,12 @@ def _build_onnx_input(self, tokenized_input):
                 [enc.type_ids for enc in tokenized_input], dtype=np.int64
             )
         if "attention_mask" in input_names:
-            inputs['attention_mask'] = np.array(
+            inputs["attention_mask"] = np.array(
                 [enc.attention_mask for enc in tokenized_input], dtype=np.int64
             )
         return inputs
 
-    def onnx_embed(
-        self, query: str, documents: list[str], **kwargs: Any
-    ) -> OnnxOutputContext:
+    def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
         pairs = [(query, doc) for doc in documents]
         return self.onnx_embed_pairs(pairs, **kwargs)
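
_build_onnx_input feeds only the tensors the ONNX graph declares, so one code path serves models with or without token_type_ids and attention_mask. A self-contained sketch of the filtering pattern; declared_inputs stands in for {node.name for node in self.model.get_inputs()}:

    import numpy as np

    declared_inputs = {"input_ids", "attention_mask"}  # stand-in for the graph's inputs
    ids = [[101, 2054, 2003, 102]]   # toy token ids for one pair
    mask = [[1, 1, 1, 1]]            # toy attention mask

    onnx_input = {"input_ids": np.array(ids, dtype=np.int64)}
    if "attention_mask" in declared_inputs:
        onnx_input["attention_mask"] = np.array(mask, dtype=np.int64)
    print({name: arr.shape for name, arr in onnx_input.items()})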

@@ -83,9 +80,7 @@ def _rerank_documents(
         if not hasattr(self, "model") or self.model is None:
             self.load_onnx_model()
         for batch in iter_batch(documents, batch_size):
-            yield from self._post_process_onnx_output(
-                self.onnx_embed(query, batch, **kwargs)
-            )
+            yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
 
     def _rerank_pairs(
         self,
@@ -113,16 +108,12 @@ def _rerank_pairs(
             if not hasattr(self, "model") or self.model is None:
                 self.load_onnx_model()
             for batch in iter_batch(pairs, batch_size):
-                yield from self._post_process_onnx_output(
-                    self.onnx_embed_pairs(batch, **kwargs)
-                )
+                yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
         else:
             if parallel == 0:
                 parallel = os.cpu_count()
 
-            start_method = (
-                "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            )
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
             params = {
                 "model_name": model_name,
                 "cache_dir": cache_dir,
@@ -141,7 +132,7 @@ def _rerank_pairs(
                 yield from self._post_process_onnx_output(batch)
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
-        return output.model_output
+        raise NotImplementedError("Subclasses must implement this method")
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, np.ndarray], **kwargs: Any
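
With this change the base class no longer passes raw model output through; each concrete model must define its own post-processing, as the cross-encoder in the first file does with its generator of floats. A minimal sketch of the resulting template-method pattern, with hypothetical class names:

    from typing import Any, Iterable

    class BaseTextModel:
        # Base behavior after this commit: refuse to guess, force an override.
        def _post_process_onnx_output(self, output: Any) -> Iterable[float]:
            raise NotImplementedError("Subclasses must implement this method")

    class CrossEncoderModel(BaseTextModel):
        # Concrete reranker: convert each raw logit to a plain Python float.
        def _post_process_onnx_output(self, output: Any) -> Iterable[float]:
            return (float(elem) for elem in output)

    print(list(CrossEncoderModel()._post_process_onnx_output([-1.24, -10.6])))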
11 changes: 4 additions & 7 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
@@ -53,10 +53,7 @@ def __init__(
 
         for CROSS_ENCODER_TYPE in self.CROSS_ENCODER_REGISTRY:
            supported_models = CROSS_ENCODER_TYPE.list_supported_models()
-            if any(
-                model_name.lower() == model["model"].lower()
-                for model in supported_models
-            ):
+            if any(model_name.lower() == model["model"].lower() for model in supported_models):
                 self.model = CROSS_ENCODER_TYPE(
                     model_name=model_name,
                     cache_dir=cache_dir,
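
The collapsed any(...) preserves the registry lookup: the first registered implementation whose supported-model list contains the requested name, compared case-insensitively, is instantiated. A standalone sketch with an illustrative registry entry:

    supported_models = [{"model": "Xenova/ms-marco-MiniLM-L-6-v2"}]  # illustrative entry
    model_name = "xenova/ms-marco-minilm-l-6-v2"

    # Case-insensitive comparison, as in the collapsed condition above.
    if any(model_name.lower() == m["model"].lower() for m in supported_models):
        print("model is supported")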
@@ -112,11 +109,11 @@ def rerank_pairs(
             Higher scores indicate a stronger match between the query and the document.
 
         Example:
-            >>> encoder = TextCrossEncoder("some-model")
+            >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
             >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
             >>> scores = list(encoder.rerank_pairs(pairs))
-            >>> print(scores)
-            [0.92, 0.87]
+            >>> print(list(map(lambda x: round(x, 2), scores)))
+            [-1.24, -10.6]
         """
         yield from self.model.rerank_pairs(
             pairs, batch_size=batch_size, parallel=parallel, **kwargs
         )
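The docstring now names a real model and shows genuine output: cross-encoder scores are unnormalized logits, so negative values are expected and only their relative order matters for ranking. A usage sketch, assuming the model weights can be downloaded on first use:

    from fastembed.rerank.cross_encoder import TextCrossEncoder

    encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
    pairs = [
        ("What is AI?", "Artificial intelligence is ..."),
        ("What is ML?", "Machine learning is ..."),
    ]
    # rerank_pairs yields lazily; materialize the generator to inspect scores.
    scores = list(encoder.rerank_pairs(pairs))
    print([round(s, 2) for s in scores])  # e.g. [-1.24, -10.6]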
19 changes: 6 additions & 13 deletions tests/test_text_cross_encoder.py
@@ -24,10 +24,7 @@
 
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for model_name in CANONICAL_SCORE_VALUES.keys()
-    ],
+    [model_name for model_name in CANONICAL_SCORE_VALUES],
 )
 def test_rerank(model_name):
     is_ci = os.getenv("CI")
@@ -51,12 +48,10 @@ def test_rerank(model_name):
     if is_ci:
         delete_model_cache(model.model._model_dir)
 
 
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for provider, model_name in SELECTED_MODELS.items()
-    ],
+    [model_name for model_name in SELECTED_MODELS.values()],
 )
 def test_batch_rerank(model_name):
     is_ci = os.getenv("CI")
@@ -99,12 +94,10 @@ def test_lazy_load(model_name):
     if is_ci:
         delete_model_cache(model.model._model_dir)
 
 
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for provider, model_name in SELECTED_MODELS.items()
-    ],
+    [model_name for model_name in SELECTED_MODELS.values()],
 )
 def test_rerank_pairs_parallel(model_name):
     is_ci = os.getenv("CI")
@@ -120,7 +113,7 @@ def test_rerank_pairs_parallel(model_name):
     ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Scores (Sequential): {scores_sequential}"
     canonical_scores = CANONICAL_SCORE_VALUES[model_name]
     assert np.allclose(
-        scores_parallel[:len(canonical_scores)], canonical_scores, atol=1e-3
+        scores_parallel[: len(canonical_scores)], canonical_scores, atol=1e-3
     ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}"
     if is_ci:
         delete_model_cache(model.model._model_dir)
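
Both test parametrizations were simplified because iterating a dict already yields its keys and .values() yields the model names directly, so neither .keys() nor .items() with a throwaway provider variable is needed. A quick illustration with an invented SELECTED_MODELS mapping:

    import pytest

    SELECTED_MODELS = {"onnx": "Xenova/ms-marco-MiniLM-L-6-v2"}  # invented for illustration

    # dict iteration yields keys; .values() yields the mapped model names directly.
    @pytest.mark.parametrize("model_name", [model_name for model_name in SELECTED_MODELS.values()])
    def test_model_name_is_string(model_name):
        assert isinstance(model_name, str)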
