From b5d2abcedfbdf0f0e0cc9cc805c5be8becbad75a Mon Sep 17 00:00:00 2001
From: George Panchuk
Date: Mon, 16 Dec 2024 16:38:28 +0100
Subject: [PATCH] nit: fix post process, update docstring, update tokenize,
 remove redundant imports

---
 .../cross_encoder/onnx_text_cross_encoder.py  |  5 +---
 .../rerank/cross_encoder/onnx_text_model.py   | 25 ++++++-------------
 .../cross_encoder/text_cross_encoder.py       | 11 +++-----
 tests/test_text_cross_encoder.py              | 19 +++++---------
 4 files changed, 19 insertions(+), 41 deletions(-)

diff --git a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
index 43788bc0..3b716719 100644
--- a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
+++ b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
@@ -1,6 +1,5 @@
 from typing import Any, Iterable, Optional, Sequence, Type
 
-import numpy as np
 from loguru import logger
 
 from fastembed.common import OnnxProvider
@@ -206,9 +205,7 @@ def rerank_pairs(
     def _get_worker_class(cls) -> Type[TextRerankerWorker]:
         return TextCrossEncoderWorker
 
-    def _post_process_onnx_output(
-        self, output: OnnxOutputContext
-    ) -> Iterable[float]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
         return (float(elem) for elem in output.model_output)
 
 
diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py
index 4a303d09..f3a02b64 100644
--- a/fastembed/rerank/cross_encoder/onnx_text_model.py
+++ b/fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -11,7 +11,6 @@
     OnnxModel,
     OnnxOutputContext,
     OnnxProvider,
-    T,
 )
 from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.utils import iter_batch
@@ -44,8 +43,8 @@ def _load_onnx_model(
         )
         self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
 
-    def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
-        return self.tokenizer.encode_batch(pairs, **kwargs)
+    def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
+        return self.tokenizer.encode_batch(pairs)
 
     def _build_onnx_input(self, tokenized_input):
         input_names = {node.name for node in self.model.get_inputs()}
@@ -57,14 +56,12 @@ def _build_onnx_input(self, tokenized_input):
                 [enc.type_ids for enc in tokenized_input], dtype=np.int64
             )
         if "attention_mask" in input_names:
-            inputs['attention_mask'] = np.array(
+            inputs["attention_mask"] = np.array(
                 [enc.attention_mask for enc in tokenized_input], dtype=np.int64
             )
         return inputs
 
-    def onnx_embed(
-        self, query: str, documents: list[str], **kwargs: Any
-    ) -> OnnxOutputContext:
+    def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
         pairs = [(query, doc) for doc in documents]
         return self.onnx_embed_pairs(pairs, **kwargs)
 
@@ -83,9 +80,7 @@ def _rerank_documents(
         if not hasattr(self, "model") or self.model is None:
             self.load_onnx_model()
         for batch in iter_batch(documents, batch_size):
-            yield from self._post_process_onnx_output(
-                self.onnx_embed(query, batch, **kwargs)
-            )
+            yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
 
     def _rerank_pairs(
         self,
@@ -113,16 +108,12 @@
             if not hasattr(self, "model") or self.model is None:
                 self.load_onnx_model()
             for batch in iter_batch(pairs, batch_size):
-                yield from self._post_process_onnx_output(
-                    self.onnx_embed_pairs(batch, **kwargs)
-                )
+                yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
         else:
             if parallel == 0:
                 parallel = os.cpu_count()
 
-            start_method = (
-                "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            )
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
             params = {
                 "model_name": model_name,
                 "cache_dir": cache_dir,
@@ -141,7 +132,7 @@ def _rerank_pairs(
             yield from self._post_process_onnx_output(batch)
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
-        return output.model_output
+        raise NotImplementedError("Subclasses must implement this method")
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, np.ndarray], **kwargs: Any
diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py
index a90dc11f..e6d0b2ef 100644
--- a/fastembed/rerank/cross_encoder/text_cross_encoder.py
+++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py
@@ -53,10 +53,7 @@ def __init__(
 
         for CROSS_ENCODER_TYPE in self.CROSS_ENCODER_REGISTRY:
             supported_models = CROSS_ENCODER_TYPE.list_supported_models()
-            if any(
-                model_name.lower() == model["model"].lower()
-                for model in supported_models
-            ):
+            if any(model_name.lower() == model["model"].lower() for model in supported_models):
                 self.model = CROSS_ENCODER_TYPE(
                     model_name=model_name,
                     cache_dir=cache_dir,
@@ -112,11 +109,11 @@
             Higher scores indicate a stronger match between the query and the document.
 
         Example:
-            >>> encoder = TextCrossEncoder("some-model")
+            >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
             >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
             >>> scores = list(encoder.rerank_pairs(pairs))
-            >>> print(scores)
-            [0.92, 0.87]
+            >>> print(list(map(lambda x: round(x, 2), scores)))
+            [-1.24, -10.6]
         """
         yield from self.model.rerank_pairs(
             pairs, batch_size=batch_size, parallel=parallel, **kwargs
diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py
index 3b3deba5..15ab3731 100644
--- a/tests/test_text_cross_encoder.py
+++ b/tests/test_text_cross_encoder.py
@@ -24,10 +24,7 @@
 
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for model_name in CANONICAL_SCORE_VALUES.keys()
-    ],
+    [model_name for model_name in CANONICAL_SCORE_VALUES],
 )
 def test_rerank(model_name):
     is_ci = os.getenv("CI")
@@ -51,12 +48,10 @@
     if is_ci:
         delete_model_cache(model.model._model_dir)
 
+
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for provider, model_name in SELECTED_MODELS.items()
-    ],
+    [model_name for model_name in SELECTED_MODELS.values()],
 )
 def test_batch_rerank(model_name):
     is_ci = os.getenv("CI")
@@ -99,12 +94,10 @@ def test_lazy_load(model_name):
     if is_ci:
         delete_model_cache(model.model._model_dir)
 
+
 @pytest.mark.parametrize(
     "model_name",
-    [
-        model_name
-        for provider, model_name in SELECTED_MODELS.items()
-    ],
+    [model_name for model_name in SELECTED_MODELS.values()],
 )
 def test_rerank_pairs_parallel(model_name):
     is_ci = os.getenv("CI")
@@ -120,7 +113,7 @@
     ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Scores (Sequential): {scores_sequential}"
     canonical_scores = CANONICAL_SCORE_VALUES[model_name]
     assert np.allclose(
-        scores_parallel[:len(canonical_scores)], canonical_scores, atol=1e-3
+        scores_parallel[: len(canonical_scores)], canonical_scores, atol=1e-3
     ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}"
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
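
A minimal usage sketch of the rerank_pairs API touched by this patch, assuming the
package-level TextCrossEncoder export and the model name from the updated docstring
example above. Note that scores are the raw cross-encoder logits, so negative
values are expected:

    from fastembed.rerank.cross_encoder import TextCrossEncoder

    # Model name taken from the updated docstring example above.
    encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")

    # rerank_pairs() lazily scores arbitrary (query, document) pairs;
    # scores are raw logits, so values such as -1.24 are expected.
    pairs = [
        ("What is AI?", "Artificial intelligence is ..."),
        ("What is ML?", "Machine learning is ..."),
    ]
    scores = list(encoder.rerank_pairs(pairs))
    print([round(s, 2) for s in scores])

Since the base _post_process_onnx_output in onnx_text_model.py now raises
NotImplementedError instead of passing model_output through, any subclass must
override it, as the ONNX text cross encoder does by converting each output
element to float.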