From d0005d4bfd6baf51669856f4eaaf129da7a2d27e Mon Sep 17 00:00:00 2001
From: generall
Date: Fri, 2 Feb 2024 10:03:10 +0100
Subject: [PATCH] ruff

---
 fastembed/__init__.py                 |  1 -
 fastembed/common/model_management.py  | 24 ++----
 fastembed/common/models.py            |  2 +-
 fastembed/embedding.py                | 15 +++----
 fastembed/text/e5_onnx_embedding.py   |  9 ++---
 fastembed/text/jina_onnx_embedding.py | 15 +++----
 fastembed/text/onnx_embedding.py      | 55 +++++++++++++--------------
 fastembed/text/text_embedding.py      | 20 +++++-----
 fastembed/text/text_embedding_base.py | 19 +++------
 tests/test_text_onnx_embeddings.py    | 12 +-----
 10 files changed, 66 insertions(+), 106 deletions(-)

diff --git a/fastembed/__init__.py b/fastembed/__init__.py
index ae22e2d6..e69de29b 100644
--- a/fastembed/__init__.py
+++ b/fastembed/__init__.py
@@ -1 +0,0 @@
-from fastembed.text.text_embedding import TextEmbedding
diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py
index 2a7b6e59..e614a44f 100644
--- a/fastembed/common/model_management.py
+++ b/fastembed/common/model_management.py
@@ -28,7 +28,6 @@ def locate_model_file(model_dir: Path, file_names: List[str]):
 
 
 class ModelManagement:
-
     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
         """
@@ -124,12 +123,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         return cache_dir
 
     @classmethod
-    def retrieve_model_gcs(
-            cls,
-            model_name: str,
-            source_url: str,
-            cache_dir: str
-    ) -> Path:
+    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
         fast_model_name = f"fast-{model_name.split('/')[-1]}"
 
         cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -191,21 +185,11 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
 
         if hf_source:
             try:
-                return Path(cls.download_files_from_huggingface(
-                    hf_source,
-                    cache_dir=str(cache_dir)
-                ))
+                return Path(cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)))
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
-                logger.error(
-                    f"Could not download model from HuggingFace: {e}"
-                    "Falling back to other sources."
-                )
+                logger.error(f"Could not download model from HuggingFace: {e} Falling back to other sources.")
 
         if url_source:
-            return cls.retrieve_model_gcs(
-                model["model"],
-                url_source,
-                str(cache_dir)
-            )
+            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
 
         raise ValueError(f"Could not download model {model['model']} from any source.")
diff --git a/fastembed/common/models.py b/fastembed/common/models.py
index 333e6ae9..74cbdab8 100644
--- a/fastembed/common/models.py
+++ b/fastembed/common/models.py
@@ -44,7 +44,7 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
     return tokenizer
 
 
-def normalize(input_array, p=2, dim=1, eps= 1e-12) -> np.ndarray:
+def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
     # Calculate the Lp norm along the specified dimension
     norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
     norm = np.maximum(norm, eps)  # Avoid division by zero
diff --git a/fastembed/embedding.py b/fastembed/embedding.py
index 6e4f3a79..e7c207e6 100644
--- a/fastembed/embedding.py
+++ b/fastembed/embedding.py
@@ -4,10 +4,7 @@
 
 from fastembed.text.text_embedding import TextEmbedding
 
-logger.warning(
-    "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated."
-    " Use TextEmbedding instead."
-)
+logger.warning("DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated." " Use TextEmbedding instead.")
" Use TextEmbedding instead.") DefaultEmbedding = TextEmbedding FlagEmbedding = TextEmbedding @@ -15,10 +12,10 @@ class JinaEmbedding(TextEmbedding): def __init__( - self, - model_name: str = "jinaai/jina-embeddings-v2-base-en", - cache_dir: Optional[str] = None, - threads: Optional[int] = None, - **kwargs + self, + model_name: str = "jinaai/jina-embeddings-v2-base-en", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, ): super().__init__(model_name, cache_dir, threads, **kwargs) diff --git a/fastembed/text/e5_onnx_embedding.py b/fastembed/text/e5_onnx_embedding.py index b32c151e..0d4b8809 100644 --- a/fastembed/text/e5_onnx_embedding.py +++ b/fastembed/text/e5_onnx_embedding.py @@ -13,13 +13,12 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", "hf": "qdrant/multilingual-e5-large-onnx", - } + }, } ] class E5OnnxEmbedding(OnnxTextEmbedding): - @classmethod def _get_worker_class(cls) -> Type["EmbeddingWorker"]: return E5OnnxEmbeddingWorker @@ -43,8 +42,8 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, class E5OnnxEmbeddingWorker(OnnxTextEmbeddingWorker): def init_embedding( - self, - model_name: str, - cache_dir: str, + self, + model_name: str, + cache_dir: str, ) -> E5OnnxEmbedding: return E5OnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1) diff --git a/fastembed/text/jina_onnx_embedding.py b/fastembed/text/jina_onnx_embedding.py index db3e75ce..bf9b204c 100644 --- a/fastembed/text/jina_onnx_embedding.py +++ b/fastembed/text/jina_onnx_embedding.py @@ -11,22 +11,19 @@ "dim": 768, "description": "English embedding model supporting 8192 sequence length", "size_in_GB": 0.55, - "sources": { - "hf": "xenova/jina-embeddings-v2-base-en" - } + "sources": {"hf": "xenova/jina-embeddings-v2-base-en"}, }, { "model": "jinaai/jina-embeddings-v2-small-en", "dim": 512, "description": "English embedding model supporting 8192 sequence length", "size_in_GB": 0.13, - "sources": {"hf": "xenova/jina-embeddings-v2-small-en"} - } + "sources": {"hf": "xenova/jina-embeddings-v2-small-en"}, + }, ] class JinaOnnxEmbedding(OnnxTextEmbedding): - @classmethod def _get_worker_class(cls) -> Type[EmbeddingWorker]: return JinaEmbeddingWorker @@ -58,8 +55,8 @@ def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> np. 
 
 class JinaEmbeddingWorker(OnnxTextEmbeddingWorker):
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         return JinaOnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 8abfd3df..56132948 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -29,7 +29,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
             "hf": "qdrant/bge-base-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-large-en-v1.5-quantized",
@@ -38,7 +38,7 @@
         "size_in_GB": 1.34,
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-large-en-v1.5",
@@ -47,7 +47,7 @@
         "size_in_GB": 1.34,
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx",
-        }
+        },
     },
     {
         "model": "BAAI/bge-small-en",
@@ -56,7 +56,7 @@
         "size_in_GB": 0.2,
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
-        }
+        },
     },
     # {
     #     "model": "BAAI/bge-small-en",
@@ -77,7 +77,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz",
             "hf": "qdrant/bge-small-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-small-zh-v1.5",
@@ -86,7 +86,7 @@
         "size_in_GB": 0.1,
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
-        }
+        },
     },
     { # todo: it is not a flag embedding
         "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -96,7 +96,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
             "hf": "qdrant/all-MiniLM-L6-v2-onnx",
-        }
+        },
     },
     # {
     #     "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -147,11 +147,11 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
         raise ValueError(f"Model {model_name} is not supported in FlagEmbedding.")
 
     def __init__(
-            self,
-            model_name: str = "BAAI/bge-small-en-v1.5",
-            cache_dir: str = None,
-            threads: int = None,
-            **kwargs,
+        self,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: str = None,
+        threads: int = None,
+        **kwargs,
     ):
         """
         Args:
@@ -190,11 +190,11 @@ def __init__(
         self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so)
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -260,7 +260,7 @@ def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         onnx_input = {
             "input_ids": np.array(input_ids, dtype=np.int64),
             "attention_mask": np.array(attention_mask, dtype=np.int64),
-            "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64)
+            "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
         }
 
         onnx_input = self._preprocess_onnx_input(onnx_input)
@@ -271,18 +271,17 @@ def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
 
 
 class EmbeddingWorker(Worker):
-
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         raise NotImplementedError()
 
     def __init__(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ):
         self.model = self.init_embedding(model_name, cache_dir)
 
@@ -301,8 +300,8 @@ def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]
 
 class OnnxTextEmbeddingWorker(EmbeddingWorker):
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index d2f9f74b..6107044b 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -45,11 +45,11 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         return result
 
     def __init__(
-            self,
-            model_name: str = "BAAI/bge-small-en-v1.5",
-            cache_dir: Optional[str] = None,
-            threads: Optional[int] = None,
-            **kwargs
+        self,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
@@ -67,11 +67,11 @@ def __init__(
         )
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py
index 8e2de732..669abb5a 100644
--- a/fastembed/text/text_embedding_base.py
+++ b/fastembed/text/text_embedding_base.py
@@ -10,23 +10,17 @@ class TextEmbeddingBase(ModelManagement):
     def list_supported_models(cls) -> List[Dict[str, Any]]:
         raise NotImplementedError()
 
-    def __init__(
-            self,
-            model_name: str,
-            cache_dir: Optional[str] = None,
-            threads: Optional[int] = None,
-            **kwargs
-    ):
+    def __init__(self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, **kwargs):
         self.model_name = model_name
         self.cache_dir = cache_dir
         self.threads = threads
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         raise NotImplementedError()
 
@@ -59,4 +53,3 @@ def query_embed(self, query: str, **kwargs) -> np.ndarray:
         # This is model-specific, so that different models can have specialized implementations
         query_embedding = list(self.embed([query], **kwargs))[0]
         return query_embedding
-
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index e44dcf23..4dc74f62 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -40,11 +40,7 @@ def test_embedding():
 
 
 @pytest.mark.parametrize(
-    "n_dims,model_name",
-    [
-        (384, "BAAI/bge-small-en-v1.5"),
-        (768, "jinaai/jina-embeddings-v2-base-en")
-    ]
+    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")]
 )
 def test_batch_embedding(n_dims, model_name):
     model = TextEmbedding(model_name=model_name)
@@ -57,11 +53,7 @@ def test_batch_embedding(n_dims, model_name):
 
 
 @pytest.mark.parametrize(
-    "n_dims,model_name",
-    [
-        (384, "BAAI/bge-small-en-v1.5"),
-        (768, "jinaai/jina-embeddings-v2-base-en")
-    ]
+    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")]
 )
 def test_parallel_processing(n_dims, model_name):
    model = TextEmbedding(model_name=model_name)
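
Note on migrating off the deprecated aliases: this patch keeps DefaultEmbedding, FlagEmbedding, and JinaEmbedding as deprecated aliases of TextEmbedding (see fastembed/embedding.py above) and empties fastembed/__init__.py, so the class is now imported from its full module path. Below is a minimal, untested usage sketch against the post-patch API; the sample documents are illustrative, and the shape assertion assumes the default 384-dim "BAAI/bge-small-en-v1.5" model from the model list above.

    import numpy as np

    # fastembed/__init__.py no longer re-exports TextEmbedding after this patch,
    # so import it from the full module path.
    from fastembed.text.text_embedding import TextEmbedding

    # Replaces the deprecated DefaultEmbedding / FlagEmbedding aliases.
    model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")

    documents = ["fastembed is light and fast", "it embeds text on CPU"]

    # embed() lazily yields one np.ndarray per document; batch_size and parallel
    # match the signature reformatted in this patch.
    embeddings = list(model.embed(documents, batch_size=256))
    assert embeddings[0].shape == (384,)

    # query_embed() (inherited from TextEmbeddingBase) embeds a single query string.
    query_vector = model.query_embed("How light is fastembed?")
    assert isinstance(query_vector, np.ndarray)

The old aliases remain importable from fastembed.embedding for backward compatibility, but importing that module now emits the deprecation warning shown above.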