diff --git a/fastembed/embedding.py b/fastembed/embedding.py
index e539de95..52cf4207 100644
--- a/fastembed/embedding.py
+++ b/fastembed/embedding.py
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 import shutil
@@ -7,13 +8,16 @@
 from itertools import islice
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import onnxruntime as ort
 import requests
 from tokenizers import AddedToken, Tokenizer
 from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import RepositoryNotFoundError
+from loguru import logger

 from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -179,71 +183,44 @@ class Embedding(ABC):
             _type_: _description_
         """

+    # Internal helper decorator to maintain backward compatibility
+    # by supporting a fallback to download from Google Cloud Storage (GCS)
+    # if the model couldn't be downloaded from HuggingFace.
+    def gcs_fallback(hf_download_method: Callable) -> Callable:
+        @functools.wraps(hf_download_method)
+        def wrapper(self, *args, **kwargs):
+            try:
+                return hf_download_method(self, *args, **kwargs)
+            except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                logger.exception(
+                    f"Could not download model from HuggingFace: {e}. "
+                    "Falling back to download from Google Cloud Storage."
+                )
+                return self.retrieve_model_gcs(*args, **kwargs)
+
+        return wrapper
+
     @abstractmethod
     def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]:
         raise NotImplementedError

     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
-        """
-        Lists the supported models.
+    def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Args:
+            exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return [
-            {
-                "model": "BAAI/bge-small-en",
-                "dim": 384,
-                "description": "Fast English model",
-                "size_in_GB": 0.2,
-            },
-            {
-                "model": "BAAI/bge-small-en-v1.5",
-                "dim": 384,
-                "description": "Fast and Default English model",
-                "size_in_GB": 0.13,
-            },
-            {
-                "model": "BAAI/bge-small-zh-v1.5",
-                "dim": 512,
-                "description": "Fast and recommended Chinese model",
-                "size_in_GB": 0.1,
-            },
-            {
-                "model": "BAAI/bge-base-en",
-                "dim": 768,
-                "description": "Base English model",
-                "size_in_GB": 0.5,
-            },
-            {
-                "model": "BAAI/bge-base-en-v1.5",
-                "dim": 768,
-                "description": "Base English model, v1.5",
-                "size_in_GB": 0.44,
-            },
-            {
-                "model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dim": 384,
-                "description": "Sentence Transformer model, MiniLM-L6-v2",
-                "size_in_GB": 0.09,
-            },
-            {
-                "model": "intfloat/multilingual-e5-large",
-                "dim": 1024,
-                "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
-                "size_in_GB": 2.24,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-base-en",
-                "dim": 768,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.55,
-            },
-            {
-                "model": "jinaai/jina-embeddings-v2-small-en",
-                "dim": 512,
-                "description": " English embedding model supporting 8192 sequence length",
-                "size_in_GB": 0.13,
-            },
-        ]
+        models_file_path = Path(__file__).with_name("models.json")
+        with open(models_file_path, "r") as file:
+            models = json.load(file)
+
+        models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
+
+        return models

     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
@@ -276,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         if total_size_in_bytes == 0:
             print(f"Warning: Content-length header is missing or zero in the response from {url}.")

-        # Initialize the progress bar
-        progress_bar = (
-            tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-            if total_size_in_bytes and show_progress
-            else None
-        )
+        show_progress = total_size_in_bytes and show_progress

-        # Attempt to download the file
-        try:
+        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
             with open(output_path, "wb") as file:
-                for chunk in response.iter_content(chunk_size=1024):  # Adjust chunk size to your preference
+                for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
-                        if progress_bar is not None:
-                            progress_bar.update(len(chunk))
+                        progress_bar.update(len(chunk))
                     file.write(chunk)
-        except Exception as e:
-            print(f"An error occurred while trying to download the file: {str(e)}")
-            return
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()

         return output_path

     @classmethod
-    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
         """
         Downloads a model from HuggingFace Hub.

         Args:
-            repod_id (str): The HF hub id (name) of the model to retrieve.
+            model_name (str): Name of the model to download.
             cache_dir (Optional[str]): The path to the cache directory.

         Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".

         Returns:
            Path: The path to the model directory.
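+
+        Example (illustrative; the snapshot path depends on huggingface_hub's
+        cache layout, and the repo id is resolved from hf_sources in models.json):
+            >>> Embedding.download_files_from_huggingface("BAAI/bge-small-en-v1.5")
+            '~/.cache/huggingface/hub/models--qdrant--bge-small-en-v1.5-onnx-q/snapshots/...'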
""" - from huggingface_hub import snapshot_download - - return snapshot_download( - repo_id=repod_id, - ignore_patterns=["model.safetensors", "pytorch_model.bin"], - cache_dir=cache_dir, - ) + models = cls.list_supported_models(exclude=["compressed_url_sources"]) + + hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]] + + # Check if the HF sources list is empty + # Raise an exception causing a fallback to GCS + if not hf_sources: + raise ValueError(f"No HuggingFace source for {model_name}") + + for index, repo_id in enumerate(hf_sources): + try: + return snapshot_download( + repo_id=repo_id, + ignore_patterns=["model.safetensors", "pytorch_model.bin"], + cache_dir=cache_dir, + ) + except (RepositoryNotFoundError, EnvironmentError) as e: + logger.exception(f"Failed to download model from HF source: {repo_id}: {e} ") + if repo_id == hf_sources[-1]: + raise e + logger.info(f"Trying another source: {hf_sources[index+1]}") @classmethod def decompress_to_cache(cls, targz_path: str, cache_dir: str): @@ -368,9 +346,6 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path: Returns: Path: The path to the model directory. """ - - assert "/" in model_name, "model_name must be in the format / e.g. BAAI/bge-base-en" - fast_model_name = f"fast-{model_name.split('/')[-1]}" model_dir = Path(cache_dir) / fast_model_name @@ -378,18 +353,29 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path: return model_dir model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz" - try: - self.download_file_from_gcs( - f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz", - output_path=str(model_tar_gz), - ) - except PermissionError: - simple_model_name = model_name.replace("/", "-") - print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz") - self.download_file_from_gcs( - f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz", - output_path=str(model_tar_gz), - ) + + models = self.list_supported_models(exclude=["hf_sources"]) + + compressed_url_sources = [ + item for model in models if model["model"] == model_name for item in model["compressed_url_sources"] + ] + + # Check if the GCS sources list is empty after falling back from HF + # A model should always have at least one source + if not compressed_url_sources: + raise ValueError(f"No GCS source for {model_name}") + + for index, source in enumerate(compressed_url_sources): + try: + self.download_file_from_gcs( + source, + output_path=str(model_tar_gz), + ) + except (RuntimeError, PermissionError) as e: + logger.exception(f"Failed to download model from GCS source: {source}: {e} ") + if source == compressed_url_sources[-1]: + raise e + logger.info(f"Trying another source: {compressed_url_sources[index+1]}") self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir) assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}" @@ -398,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path: return model_dir + @gcs_fallback def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path: """ Retrieves a model from HuggingFace Hub. Args: model_name (str): The name of the model to retrieve. cache_dir (str): The path to the cache directory. - Raises: - ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. Returns: Path: The path to the model directory. 
""" - assert ( - "/" in model_name - ), "model_name must be in the format / e.g. jinaai/jina-embeddings-v2-small-en" + return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir)) + + @classmethod + def assert_model_name(cls, model_name: str): + assert "/" in model_name, "model_name must be in the format / e.g. BAAI/bge-base-en" - return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir)) + models = cls.list_supported_models() + model_names = [model["model"] for model in models] + if model_name not in model_names: + raise ValueError( + f"{model_name} is not a supported model.\n" + f"Try one of {', '.join(model_names)}.\n" + f"Use the 'list_supported_models()' method to get the model information." + ) def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: """ @@ -475,6 +469,9 @@ def __init__( Raises: ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. """ + + self.assert_model_name(model_name) + self.model_name = model_name if cache_dir is None: @@ -483,7 +480,7 @@ def __init__( cache_dir.mkdir(parents=True, exist_ok=True) self._cache_dir = cache_dir - self._model_dir = self.retrieve_model_gcs(model_name, cache_dir) + self._model_dir = self.retrieve_model_hf(model_name, cache_dir) self._max_length = max_length self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads) @@ -536,12 +533,21 @@ def embed( yield from normalize(embeddings[:, 0]).astype(np.float32) @classmethod - def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]: - """ - Lists the supported models. + def list_supported_models( + cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"] + ) -> List[Dict[str, Any]]: + """Lists the supported models. + + Args: + exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"]. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. """ # jina models are not supported by this class - return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")] + return [ + model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai") + ] class DefaultEmbedding(FlagEmbedding): @@ -591,8 +597,10 @@ def __init__( Defaults to `fastembed_cache` in the system's temp directory. threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. Raises: - ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. + ValueError: If the model_name is not in the format / e.g. jinaai/jina-embeddings-v2-base-en. """ + self.assert_model_name(model_name) + self.model_name = model_name if cache_dir is None: @@ -652,12 +660,21 @@ def embed( yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) @classmethod - def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]: - """ - Lists the supported models. + def list_supported_models( + cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"] + ) -> List[Dict[str, Any]]: + """Lists the supported models. + + Args: + exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"]. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. 
""" # only jina models are supported by this class - return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")] + return [ + model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai") + ] @staticmethod def mean_pooling(model_output, attention_mask): diff --git a/fastembed/models.json b/fastembed/models.json new file mode 100644 index 00000000..1c4e7b5d --- /dev/null +++ b/fastembed/models.json @@ -0,0 +1,113 @@ +[ + { + "model": "BAAI/bge-base-en", + "dim": 768, + "description": "Base English model", + "size_in_GB": 0.5, + "hf_sources": [], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz" + ] + }, + { + "model": "BAAI/bge-base-en-v1.5", + "dim": 768, + "description": "Base English model, v1.5", + "size_in_GB": 0.44, + "hf_sources": [ + "qdrant/bge-base-en-v1.5-onnx-q" + ], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz" + ] + }, + { + "model": "BAAI/bge-large-en-v1.5", + "dim": 1024, + "description": "Large English model, v1.5", + "size_in_GB": 1.34, + "hf_sources": [ + "qdrant/bge-large-en-v1.5-onnx", + "qdrant/bge-large-en-v1.5-onnx-q" + ], + "compressed_url_sources": [] + }, + { + "model": "BAAI/bge-small-en", + "dim": 384, + "description": "Fast English model", + "size_in_GB": 0.2, + "hf_sources": [], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz", + "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz" + ] + }, + { + "model": "BAAI/bge-small-en-v1.5", + "dim": 384, + "description": "Fast and Default English model", + "size_in_GB": 0.13, + "hf_sources": [ + "qdrant/bge-small-en-v1.5-onnx-q" + ], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz" + ] + }, + { + "model": "BAAI/bge-small-zh-v1.5", + "dim": 512, + "description": "Fast and recommended Chinese model", + "size_in_GB": 0.1, + "hf_sources": [], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz" + ] + }, + { + "model": "intfloat/multilingual-e5-large", + "dim": 1024, + "description": "Multilingual model, e5-large. 
Recommend using this model for non-English languages", + "size_in_GB": 2.24, + "hf_sources": [ + "qdrant/multilingual-e5-large-onnx" + ], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/intfloat-multilingual-e5-large.tar.gz" + ] + }, + { + "model": "jinaai/jina-embeddings-v2-base-en", + "dim": 768, + "description": " English embedding model supporting 8192 sequence length", + "size_in_GB": 0.55, + "hf_sources": [ + "jinaai/jina-embeddings-v2-base-en" + ], + "compressed_url_sources": [] + }, + { + "model": "jinaai/jina-embeddings-v2-small-en", + "dim": 512, + "description": " English embedding model supporting 8192 sequence length", + "size_in_GB": 0.13, + "hf_sources": [ + "jinaai/jina-embeddings-v2-small-en" + ], + "compressed_url_sources": [] + }, + { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "dim": 384, + "description": "Sentence Transformer model, MiniLM-L6-v2", + "size_in_GB": 0.09, + "hf_sources": [ + "qdrant/all-MiniLM-L6-v2-onnx" + ], + "compressed_url_sources": [ + "https://storage.googleapis.com/qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz", + "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz" + ] + } +] \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 103a84dc..465f2234 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "anyio" @@ -1272,6 +1272,24 @@ docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pyd openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"] test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + [[package]] name = "markdown" version = "3.5.2" @@ -1317,6 +1335,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = 
"MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2298,6 +2326,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2305,8 +2334,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2323,6 +2359,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2330,6 +2367,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3244,6 +3282,20 @@ docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [[package]] name = "zipp" version = "3.17.0" @@ -3262,4 +3314,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.0,<3.12" -content-hash = "fc73a467844d8a9d92d44b20aa01a1e5187cb1cf7cb1a7e623b91f2b9f4bcb8f" +content-hash = "796aff6e2e0ea96f8751858afb7df5e7f04c0f4b751fd9028ab7b1eaa51c45a9" diff --git a/pyproject.toml b/pyproject.toml index 770f7a4f..b257e04c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ tqdm = "^4.65" requests = "^2.31" tokenizers = "^0.15.0" huggingface-hub = "0.19.4" +loguru = "^0.7.2" [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" diff --git a/tests/test_onnx_embeddings.py b/tests/test_onnx_embeddings.py index c9af9e10..743dd3e5 100644 --- a/tests/test_onnx_embeddings.py +++ b/tests/test_onnx_embeddings.py @@ -11,6 +11,7 @@ "BAAI/bge-small-zh-v1.5": np.array([-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762]), "BAAI/bge-base-en": np.array([0.0115, 0.0372, 0.0295, 0.0121, 0.0346]), "BAAI/bge-base-en-v1.5": np.array([0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045]), + "BAAI/bge-large-en-v1.5": np.array([0.03434538, 0.03316108, 0.02191251, -0.03713358, -0.01577825]), "sentence-transformers/all-MiniLM-L6-v2": np.array([0.0259, 0.0058, 0.0114, 0.0380, -0.0233]), "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]), "jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]),