From 5e5f6c664702070d82bd36fe4b46e694855e2edb Mon Sep 17 00:00:00 2001 From: George Date: Wed, 13 Nov 2024 09:43:20 +0100 Subject: [PATCH] review suggestions (#398) --- fastembed/common/onnx_model.py | 16 ++-------------- fastembed/late_interaction/colbert.py | 4 ++-- .../rerank/cross_encoder/onnx_text_model.py | 8 ++++---- fastembed/sparse/utils/tokenizer.py | 4 ++-- 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 443cfa3f..5d26e55a 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -1,19 +1,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import ( - Any, - Dict, - Generic, - Iterable, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - List, -) +from typing import Any, Dict, Generic, Iterable, Optional, Sequence, Tuple, Type, TypeVar import numpy as np import onnxruntime as ort @@ -108,7 +96,7 @@ def _load_onnx_model( def load_onnx_model(self) -> None: raise NotImplementedError("Subclasses must implement this method") - def onnx_embed(self, *args, **kwargs) -> Union[OnnxOutputContext, List[float]]: + def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext: raise NotImplementedError("Subclasses must implement this method") diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 92d74c57..89b87805 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -175,8 +175,8 @@ def __init__( self._model_dir = self.download_model( self.model_description, self.cache_dir, local_files_only=self._local_files_only ) - self.mask_token_id = int() - self.pad_token_id = int() + self.mask_token_id = None + self.pad_token_id = None self.skip_list = set() if not self.lazy_load: diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 85f9420c..b1a30e2e 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -4,7 +4,7 @@ import numpy as np from tokenizers import Encoding -from fastembed.common.onnx_model import OnnxModel, OnnxProvider +from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext from fastembed.common.preprocessor_utils import load_tokenizer from fastembed.common.utils import iter_batch @@ -34,7 +34,7 @@ def _load_onnx_model( def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]: return self.tokenizer.encode_batch([(query, doc) for doc in documents]) - def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]: + def onnx_embed(self, query: str, documents: List[str], **kwargs) -> OnnxOutputContext: tokenized_input = self.tokenize(query, documents, **kwargs) inputs = { @@ -51,7 +51,7 @@ def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]: onnx_input = self._preprocess_onnx_input(inputs, **kwargs) outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) - return outputs[0][:, 0].tolist() + return OnnxOutputContext(model_output=outputs[0][:, 0].tolist()) def _rerank_documents( self, query: str, documents: Iterable[str], batch_size: int, **kwargs @@ -59,7 +59,7 @@ def _rerank_documents( if not hasattr(self, "model") or self.model is None: self.load_onnx_model() for batch in iter_batch(documents, batch_size): - yield from self.onnx_embed(query, batch, **kwargs) + yield from self.onnx_embed(query, batch, **kwargs).model_output def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], **kwargs diff --git a/fastembed/sparse/utils/tokenizer.py b/fastembed/sparse/utils/tokenizer.py index 9e57d2fd..22045759 100644 --- a/fastembed/sparse/utils/tokenizer.py +++ b/fastembed/sparse/utils/tokenizer.py @@ -5,8 +5,8 @@ class SimpleTokenizer: - @classmethod - def tokenize(cls, text: str) -> List[str]: + @staticmethod + def tokenize(text: str) -> List[str]: text = re.sub(r"[^\w]", " ", text.lower()) text = re.sub(r"\s+", " ", text)