
Commit

review suggestions (#398)
joein committed Nov 13, 2024
1 parent 4ad2944 commit 39491ab
Showing 4 changed files with 10 additions and 22 deletions.
16 changes: 2 additions & 14 deletions fastembed/common/onnx_model.py
@@ -1,19 +1,7 @@
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (
-    Any,
-    Dict,
-    Generic,
-    Iterable,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-    List,
-)
+from typing import Any, Dict, Generic, Iterable, Optional, Sequence, Tuple, Type, TypeVar

 import numpy as np
 import onnxruntime as ort
@@ -108,7 +96,7 @@ def _load_onnx_model(
     def load_onnx_model(self) -> None:
         raise NotImplementedError("Subclasses must implement this method")

-    def onnx_embed(self, *args, **kwargs) -> Union[OnnxOutputContext, List[float]]:
+    def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext:
         raise NotImplementedError("Subclasses must implement this method")


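The base-class change above narrows onnx_embed to a single return type. As a rough illustration of what a conforming subclass looks like after this commit, here is a minimal sketch; the DummyOnnxEmbedder class and its placeholder scores are invented for this example and are not part of fastembed:

from typing import Any

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext


class DummyOnnxEmbedder:
    # Stand-in for an OnnxModel subclass; only the return contract matters here.
    def onnx_embed(self, *args: Any, **kwargs: Any) -> OnnxOutputContext:
        placeholder_output = np.zeros((2, 3), dtype=np.float32)  # fake model output
        return OnnxOutputContext(model_output=placeholder_output)


ctx = DummyOnnxEmbedder().onnx_embed()
print(ctx.model_output.shape)  # (2, 3) -- callers read results from the context object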
4 changes: 2 additions & 2 deletions fastembed/late_interaction/colbert.py
@@ -169,8 +169,8 @@ def __init__(
         self._model_dir = self.download_model(
             self.model_description, self.cache_dir, local_files_only=self._local_files_only
         )
-        self.mask_token_id = int()
-        self.pad_token_id = int()
+        self.mask_token_id = None
+        self.pad_token_id = None
         self.skip_list = set()

         if not self.lazy_load:
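For context on why None is the better placeholder here: int() simply evaluates to 0, which is indistinguishable from a real token id (id 0 is often a special token), whereas None makes the "not yet loaded" state explicit. A tiny illustrative snippet, not taken from the repository:

mask_token_id = int()           # old initializer: this is just 0
print(mask_token_id == 0)       # True -- looks like a valid token id

mask_token_id = None            # new initializer: clearly "not set yet"
print(mask_token_id is None)    # True -- unset state is unambiguous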
8 changes: 4 additions & 4 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -4,7 +4,7 @@
 import numpy as np
 from tokenizers import Encoding

-from fastembed.common.onnx_model import OnnxModel, OnnxProvider
+from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext
 from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.utils import iter_batch

@@ -34,7 +34,7 @@ def _load_onnx_model(
     def tokenize(self, query: str, documents: List[str], **kwargs) -> List[Encoding]:
         return self.tokenizer.encode_batch([(query, doc) for doc in documents])

-    def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]:
+    def onnx_embed(self, query: str, documents: List[str], **kwargs) -> OnnxOutputContext:
         tokenized_input = self.tokenize(query, documents, **kwargs)

         inputs = {
@@ -51,15 +51,15 @@ def onnx_embed(self, query: str, documents: List[str], **kwargs) -> List[float]:

         onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
         outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
-        return outputs[0][:, 0].tolist()
+        return OnnxOutputContext(model_output=outputs[0][:, 0].tolist())

     def _rerank_documents(
         self, query: str, documents: Iterable[str], batch_size: int, **kwargs
     ) -> Iterable[float]:
         if not hasattr(self, "model") or self.model is None:
             self.load_onnx_model()
         for batch in iter_batch(documents, batch_size):
-            yield from self.onnx_embed(query, batch, **kwargs)
+            yield from self.onnx_embed(query, batch, **kwargs).model_output

     def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], **kwargs
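With this change the cross-encoder follows the same contract as the embedding models: onnx_embed wraps its raw scores in OnnxOutputContext, and _rerank_documents flattens them back out via .model_output. A reduced sketch of that wrap/unwrap pattern; fake_onnx_embed and rerank_batches are illustrative helpers, not fastembed APIs:

from typing import Iterable, List

from fastembed.common.onnx_model import OnnxOutputContext


def fake_onnx_embed(batch_scores: List[float]) -> OnnxOutputContext:
    # Mimics the new return style: scores are packed into the context object.
    return OnnxOutputContext(model_output=batch_scores)


def rerank_batches(batches: List[List[float]]) -> Iterable[float]:
    # Mimics _rerank_documents: each batch is unwrapped via .model_output.
    for batch in batches:
        yield from fake_onnx_embed(batch).model_output


print(list(rerank_batches([[0.9, 0.1], [0.5]])))  # [0.9, 0.1, 0.5]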
4 changes: 2 additions & 2 deletions fastembed/sparse/utils/tokenizer.py
@@ -5,8 +5,8 @@


 class SimpleTokenizer:
-    @classmethod
-    def tokenize(cls, text: str) -> List[str]:
+    @staticmethod
+    def tokenize(text: str) -> List[str]:
         text = re.sub(r"[^\w]", " ", text.lower())
         text = re.sub(r"\s+", " ", text)

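Since tokenize never used cls, @staticmethod is the more accurate decorator and existing call sites are unaffected. Below is a condensed, self-contained stand-in for illustration only; the final whitespace split is an assumption, because the tail of the method is truncated in the hunk above:

import re
from typing import List


class SimpleTokenizerSketch:
    # Illustrative stand-in, not fastembed's SimpleTokenizer.
    @staticmethod
    def tokenize(text: str) -> List[str]:
        text = re.sub(r"[^\w]", " ", text.lower())
        text = re.sub(r"\s+", " ", text)
        return text.strip().split()  # assumed final step (not shown in the hunk)


# Both call styles still work; there is just no implicit cls argument anymore.
print(SimpleTokenizerSketch.tokenize("Hello, FastEmbed!"))    # ['hello', 'fastembed']
print(SimpleTokenizerSketch().tokenize("Hello, FastEmbed!"))  # ['hello', 'fastembed']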
