From fb93e40c9e6fed0fb7bdf53cb453722796059cd0 Mon Sep 17 00:00:00 2001 From: generall Date: Fri, 2 Feb 2024 00:21:14 +0100 Subject: [PATCH] refactoring Co-authored-by: George Panchuk --- README.md | 4 +- fastembed/__init__.py | 1 + fastembed/common/__init__.py | 0 fastembed/common/model_management.py | 211 ++++++++ fastembed/common/models.py | 52 ++ fastembed/common/utils.py | 33 ++ fastembed/embedding.py | 700 +------------------------- fastembed/image/__init__.py | 0 fastembed/sparse/__init__.py | 0 fastembed/text/__init__.py | 0 fastembed/text/e5_onnx_embedding.py | 38 ++ fastembed/text/jina_onnx_embedding.py | 47 ++ fastembed/text/onnx_embedding.py | 207 ++++++++ fastembed/text/onnx_models.py | 133 +++++ fastembed/text/text_embedding.py | 91 ++++ fastembed/text/text_embedding_base.py | 62 +++ tests/test_onnx_embeddings.py | 3 + tests/test_text_onnx_embeddings.py | 81 +++ 18 files changed, 976 insertions(+), 687 deletions(-) create mode 100644 fastembed/__init__.py create mode 100644 fastembed/common/__init__.py create mode 100644 fastembed/common/model_management.py create mode 100644 fastembed/common/models.py create mode 100644 fastembed/common/utils.py create mode 100644 fastembed/image/__init__.py create mode 100644 fastembed/sparse/__init__.py create mode 100644 fastembed/text/__init__.py create mode 100644 fastembed/text/e5_onnx_embedding.py create mode 100644 fastembed/text/jina_onnx_embedding.py create mode 100644 fastembed/text/onnx_embedding.py create mode 100644 fastembed/text/onnx_models.py create mode 100644 fastembed/text/text_embedding.py create mode 100644 fastembed/text/text_embedding_base.py create mode 100644 tests/test_text_onnx_embeddings.py diff --git a/README.md b/README.md index 27e3cfdb..c28aa5cc 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ pip install fastembed ## 📖 Usage ```python -from fastembed.embedding import FlagEmbedding as Embedding +from fastembed import TextEmbedding from typing import List import numpy as np @@ -36,7 +36,7 @@ documents: List[str] = [ "passage: This is an example passage.", "fastembed is supported by and maintained by Qdrant." # You can leave out the prefix but it's recommended ] -embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512) +embedding_model = TextEmbedding(model_name="BAAI/bge-base-en") embeddings: List[np.ndarray] = list(embedding_model.embed(documents)) # Note the list() call - this is a generator ``` diff --git a/fastembed/__init__.py b/fastembed/__init__.py new file mode 100644 index 00000000..ae22e2d6 --- /dev/null +++ b/fastembed/__init__.py @@ -0,0 +1 @@ +from fastembed.text.text_embedding import TextEmbedding diff --git a/fastembed/common/__init__.py b/fastembed/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py new file mode 100644 index 00000000..c93f9aa9 --- /dev/null +++ b/fastembed/common/model_management.py @@ -0,0 +1,211 @@ +import os +import shutil +import tarfile +from pathlib import Path +from typing import List, Optional, Dict, Any + +import requests +from huggingface_hub import snapshot_download +from huggingface_hub.utils import RepositoryNotFoundError +from tqdm import tqdm +from loguru import logger + + +def locate_model_file(model_dir: Path, file_names: List[str]): + """ + Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used + by Optimum and Qdrant + """ + if not model_dir.is_dir(): + raise ValueError(f"Provided model path '{model_dir}' is not a directory.") + + for path in model_dir.rglob("*.onnx"): + for file_name in file_names: + if path.is_file() and path.name == file_name: + return path + + raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}") + + +class ModelManagement: + + @classmethod + def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str: + """ + Downloads a file from Google Cloud Storage. + + Args: + url (str): The URL to download the file from. + output_path (str): The path to save the downloaded file to. + show_progress (bool, optional): Whether to show a progress bar. Defaults to True. + + Returns: + str: The path to the downloaded file. + """ + + if os.path.exists(output_path): + return output_path + response = requests.get(url, stream=True) + + # Handle HTTP errors + if response.status_code == 403: + raise PermissionError( + "Authentication Error: You do not have permission to access this resource. " + "Please check your credentials." + ) + + # Get the total size of the file + total_size_in_bytes = int(response.headers.get("content-length", 0)) + + # Warn if the total size is zero + if total_size_in_bytes == 0: + print(f"Warning: Content-length header is missing or zero in the response from {url}.") + + show_progress = total_size_in_bytes and show_progress + + with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar: + with open(output_path, "wb") as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: # Filter out keep-alive new chunks + progress_bar.update(len(chunk)) + file.write(chunk) + return output_path + + @classmethod + def download_files_from_huggingface(cls, hf_source_repo: str, cache_dir: Optional[str] = None) -> str: + """ + Downloads a model from HuggingFace Hub. + Args: + hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx". + cache_dir (Optional[str]): The path to the cache directory. + Returns: + Path: The path to the model directory. + """ + + return snapshot_download( + repo_id=hf_source_repo, + ignore_patterns=["model.safetensors", "pytorch_model.bin"], + cache_dir=cache_dir, + ) + + @classmethod + def decompress_to_cache(cls, targz_path: str, cache_dir: str): + """ + Decompresses a .tar.gz file to a cache directory. + + Args: + targz_path (str): Path to the .tar.gz file. + cache_dir (str): Path to the cache directory. + + Returns: + cache_dir (str): Path to the cache directory. + """ + # Check if targz_path exists and is a file + if not os.path.isfile(targz_path): + raise ValueError(f"{targz_path} does not exist or is not a file.") + + # Check if targz_path is a .tar.gz file + if not targz_path.endswith(".tar.gz"): + raise ValueError(f"{targz_path} is not a .tar.gz file.") + + try: + # Open the tar.gz file + with tarfile.open(targz_path, "r:gz") as tar: + # Extract all files into the cache directory + tar.extractall(path=cache_dir) + except tarfile.TarError as e: + # If any error occurs while opening or extracting the tar.gz file, + # delete the cache directory (if it was created in this function) + # and raise the error again + if "tmp" in cache_dir: + shutil.rmtree(cache_dir) + raise ValueError(f"An error occurred while decompressing {targz_path}: {e}") + + return cache_dir + + @classmethod + def retrieve_model_gcs( + cls, + model_name: str, + source_url: str, + cache_dir: str + ) -> Path: + fast_model_name = f"fast-{model_name.split('/')[-1]}" + + cache_tmp_dir = Path(cache_dir) / "tmp" + model_tmp_dir = cache_tmp_dir / fast_model_name + model_dir = Path(cache_dir) / fast_model_name + + if model_dir.exists(): + return model_dir + + if model_tmp_dir.exists(): + shutil.rmtree(model_tmp_dir) + + cache_tmp_dir.mkdir(parents=True, exist_ok=True) + + model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz" + + cls.download_file_from_gcs( + source_url, + output_path=str(model_tar_gz), + ) + + cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir)) + assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}" + + model_tar_gz.unlink() + # Rename from tmp to final name is atomic + model_tmp_dir.rename(model_dir) + + return model_dir + + @classmethod + def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: + """ + Downloads a model from HuggingFace Hub or Google Cloud Storage. + + Args: + model (Dict[str, Any]): The model description. + Example: + ``` + { + "model": "BAAI/bge-base-en-v1.5", + "dim": 768, + "description": "Base English model, v1.5", + "size_in_GB": 0.44, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", + "hf": "qdrant/bge-base-en-v1.5-onnx-q", + } + } + ``` + cache_dir (str): The path to the cache directory. + + Returns: + Path: The path to the downloaded model directory. + """ + + hf_source = model.get("sources", {}).get("hf") + gcp_source = model.get("sources", {}).get("gcp") + + if hf_source: + try: + return Path(cls.download_files_from_huggingface( + hf_source, + cache_dir=str(cache_dir) + )) + except (EnvironmentError, RepositoryNotFoundError, ValueError) as e: + logger.error( + f"Could not download model from HuggingFace: {e}" + "Falling back to other sources." + ) + + if gcp_source: + return cls.retrieve_model_gcs( + model["model"], + gcp_source, + str(cache_dir) + ) + + raise ValueError(f"Could not download model {model['model']} from any source.") diff --git a/fastembed/common/models.py b/fastembed/common/models.py new file mode 100644 index 00000000..333e6ae9 --- /dev/null +++ b/fastembed/common/models.py @@ -0,0 +1,52 @@ +import json +from pathlib import Path + +import numpy as np +from tokenizers import Tokenizer, AddedToken + + +def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer: + config_path = model_dir / "config.json" + if not config_path.exists(): + raise ValueError(f"Could not find config.json in {model_dir}") + + tokenizer_path = model_dir / "tokenizer.json" + if not tokenizer_path.exists(): + raise ValueError(f"Could not find tokenizer.json in {model_dir}") + + tokenizer_config_path = model_dir / "tokenizer_config.json" + if not tokenizer_config_path.exists(): + raise ValueError(f"Could not find tokenizer_config.json in {model_dir}") + + tokens_map_path = model_dir / "special_tokens_map.json" + if not tokens_map_path.exists(): + raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") + + with open(str(config_path)) as config_file: + config = json.load(config_file) + + with open(str(tokenizer_config_path)) as tokenizer_config_file: + tokenizer_config = json.load(tokenizer_config_file) + + with open(str(tokens_map_path)) as tokens_map_file: + tokens_map = json.load(tokens_map_file) + + tokenizer = Tokenizer.from_file(str(tokenizer_path)) + tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length)) + tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"]) + + for token in tokens_map.values(): + if isinstance(token, str): + tokenizer.add_special_tokens([token]) + elif isinstance(token, dict): + tokenizer.add_special_tokens([AddedToken(**token)]) + + return tokenizer + + +def normalize(input_array, p=2, dim=1, eps= 1e-12) -> np.ndarray: + # Calculate the Lp norm along the specified dimension + norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True) + norm = np.maximum(norm, eps) # Avoid division by zero + normalized_array = input_array / norm + return normalized_array diff --git a/fastembed/common/utils.py b/fastembed/common/utils.py new file mode 100644 index 00000000..89f881b2 --- /dev/null +++ b/fastembed/common/utils.py @@ -0,0 +1,33 @@ +import os +import tempfile +from itertools import islice +from pathlib import Path +from typing import Union, Iterable, Generator, Optional + + +def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: + """ + >>> list(iter_batch([1,2,3,4,5], 3)) + [[1, 2, 3], [4, 5]] + """ + source_iter = iter(iterable) + while source_iter: + b = list(islice(source_iter, size)) + if len(b) == 0: + break + yield b + + +def define_cache_dir(cache_dir: Optional[str] = None) -> Path: + """ + Define the cache directory for fastembed + """ + if cache_dir is None: + default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") + cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) + else: + cache_dir = Path(cache_dir) + + cache_dir.mkdir(parents=True, exist_ok=True) + + return cache_dir diff --git a/fastembed/embedding.py b/fastembed/embedding.py index 620b5d5e..6e4f3a79 100644 --- a/fastembed/embedding.py +++ b/fastembed/embedding.py @@ -1,694 +1,24 @@ -import functools -import json -import os -import shutil -import tarfile -import tempfile -from abc import ABC, abstractmethod -from itertools import islice -from multiprocessing import get_all_start_methods -from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import Optional -import numpy as np -import onnxruntime as ort -import requests -from tokenizers import AddedToken, Tokenizer -from tqdm import tqdm -from huggingface_hub import snapshot_download -from huggingface_hub.utils import RepositoryNotFoundError from loguru import logger -from fastembed.parallel_processor import ParallelWorkerPool, Worker +from fastembed.text.text_embedding import TextEmbedding +logger.warning( + "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated." + " Use TextEmbedding instead." +) -def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: - """ - >>> list(iter_batch([1,2,3,4,5], 3)) - [[1, 2, 3], [4, 5]] - """ - source_iter = iter(iterable) - while source_iter: - b = list(islice(source_iter, size)) - if len(b) == 0: - break - yield b +DefaultEmbedding = TextEmbedding +FlagEmbedding = TextEmbedding -def locate_model_file(model_dir: Path, file_names: List[str]): - """ - Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used by Optimum and Qdrant - """ - if not model_dir.is_dir(): - raise ValueError(f"Provided model path '{model_dir}' is not a directory.") - - for path in model_dir.rglob("*.onnx"): - for file_name in file_names: - if path.is_file() and path.name == file_name: - return path - - raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}") - - -def normalize(input_array, p=2, dim=1, eps=1e-12): - # Calculate the Lp norm along the specified dimension - norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True) - norm = np.maximum(norm, eps) # Avoid division by zero - normalized_array = input_array / norm - return normalized_array - - -class EmbeddingModel: - @classmethod - def load_tokenizer(cls, model_dir: Path, max_length: int = 512) -> Tokenizer: - config_path = model_dir / "config.json" - if not config_path.exists(): - raise ValueError(f"Could not find config.json in {model_dir}") - - tokenizer_path = model_dir / "tokenizer.json" - if not tokenizer_path.exists(): - raise ValueError(f"Could not find tokenizer.json in {model_dir}") - - tokenizer_config_path = model_dir / "tokenizer_config.json" - if not tokenizer_config_path.exists(): - raise ValueError(f"Could not find tokenizer_config.json in {model_dir}") - - tokens_map_path = model_dir / "special_tokens_map.json" - if not tokens_map_path.exists(): - raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") - - with open(str(config_path)) as config_file: - config = json.load(config_file) - - with open(str(tokenizer_config_path)) as tokenizer_config_file: - tokenizer_config = json.load(tokenizer_config_file) - - with open(str(tokens_map_path)) as tokens_map_file: - tokens_map = json.load(tokens_map_file) - - tokenizer = Tokenizer.from_file(str(tokenizer_path)) - tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length)) - tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"]) - - for token in tokens_map.values(): - if isinstance(token, str): - tokenizer.add_special_tokens([token]) - elif isinstance(token, dict): - tokenizer.add_special_tokens([AddedToken(**token)]) - - return tokenizer - +class JinaEmbedding(TextEmbedding): def __init__( - self, - path: Path, - model_name: str, - max_length: int = 512, - max_threads: int = None, + self, + model_name: str = "jinaai/jina-embeddings-v2-base-en", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs ): - self.path = path - self.model_name = model_name - model_path = locate_model_file(self.path, ["model.onnx", "model_optimized.onnx"]) - - # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers - onnx_providers = ["CPUExecutionProvider"] - - # Hacky support for multilingual model - self.exclude_token_type_ids = False - if model_name == "intfloat/multilingual-e5-large": - self.exclude_token_type_ids = True - - so = ort.SessionOptions() - so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - - if max_threads is not None: - so.intra_op_num_threads = max_threads - so.inter_op_num_threads = max_threads - - self.tokenizer = self.load_tokenizer(self.path, max_length=max_length) - self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so) - - def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]: - encoded = self.tokenizer.encode_batch(documents) - input_ids = np.array([e.ids for e in encoded]) - attention_mask = np.array([e.attention_mask for e in encoded]) - - onnx_input = { - "input_ids": np.array(input_ids, dtype=np.int64), - "attention_mask": np.array(attention_mask, dtype=np.int64), - } - - if not self.exclude_token_type_ids: - onnx_input["token_type_ids"] = np.array( - [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 - ) - - model_output = self.model.run(None, onnx_input) - embeddings = model_output[0] - return embeddings, attention_mask - - -class EmbeddingWorker(Worker): - def __init__( - self, - path: Path, - model_name: str, - max_length: int = 512, - ): - self.model = EmbeddingModel(path=path, model_name=model_name, max_length=max_length, max_threads=1) - - @classmethod - def start(cls, path: Path, model_name: str, max_length: int = 512, **kwargs: Any) -> "EmbeddingWorker": - return cls( - path=path, - model_name=model_name, - max_length=max_length, - ) - - def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]: - for idx, batch in items: - embeddings, attn_mask = self.model.onnx_embed(batch) - yield idx, (embeddings, attn_mask) - - -class Embedding(ABC): - """ - Abstract class for embeddings. - - Inherits: - ABC: Abstract base class - - Raises: - NotImplementedError: Raised when you call an abstract method that has not been implemented. - PermissionError: _description_ - ValueError: Several possible reasons: 1) targz_path does not exist or is not a file, 2) targz_path is not a .tar.gz file, 3) An error occurred while decompressing targz_path, 4) Could not find model_dir in cache_dir, 5) Could not find tokenizer.json in model_dir, 6) Could not find model.onnx in model_dir. - NotImplementedError: _description_ - - Returns: - _type_: _description_ - - Yields: - _type_: _description_ - """ - - # Internal helper decorator to maintain backward compatibility - # by supporting a fallback to download from Google Cloud Storage (GCS) - # if the model couldn't be downloaded from HuggingFace. - def gcs_fallback(hf_download_method: Callable) -> Callable: - @functools.wraps(hf_download_method) - def wrapper(self, *args, **kwargs): - try: - return hf_download_method(self, *args, **kwargs) - except (EnvironmentError, RepositoryNotFoundError, ValueError) as e: - logger.exception( - f"Could not download model from HuggingFace: {e}" - "Falling back to download from Google Cloud Storage" - ) - return self.retrieve_model_gcs(*args, **kwargs) - - return wrapper - - @abstractmethod - def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = None) -> List[np.ndarray]: - raise NotImplementedError - - @classmethod - def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]: - """Lists the supported models. - - Args: - exclude (List[str], optional): Keys to exclude from the result. Defaults to []. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the model information. - """ - models_file_path = Path(__file__).with_name("models.json") - with open(models_file_path, "r") as file: - models = json.load(file) - - models = [{k: v for k, v in model.items() if k not in exclude} for model in models] - - return models - - @classmethod - def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str: - """ - Downloads a file from Google Cloud Storage. - - Args: - url (str): The URL to download the file from. - output_path (str): The path to save the downloaded file to. - show_progress (bool, optional): Whether to show a progress bar. Defaults to True. - - Returns: - str: The path to the downloaded file. - """ - - if os.path.exists(output_path): - return output_path - response = requests.get(url, stream=True) - - # Handle HTTP errors - if response.status_code == 403: - raise PermissionError( - "Authentication Error: You do not have permission to access this resource. Please check your credentials." - ) - - # Get the total size of the file - total_size_in_bytes = int(response.headers.get("content-length", 0)) - - # Warn if the total size is zero - if total_size_in_bytes == 0: - print(f"Warning: Content-length header is missing or zero in the response from {url}.") - - show_progress = total_size_in_bytes and show_progress - - with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar: - with open(output_path, "wb") as file: - for chunk in response.iter_content(chunk_size=1024): - if chunk: # Filter out keep-alive new chunks - progress_bar.update(len(chunk)) - file.write(chunk) - return output_path - - @classmethod - def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str: - """ - Downloads a model from HuggingFace Hub. - Args: - model_name (str): Name of the model to download. - cache_dir (Optional[str]): The path to the cache directory. - Raises: - ValueError: If the model_name is not in the format / e.g. "jinaai/jina-embeddings-v2-small-en". - Returns: - Path: The path to the model directory. - """ - models = cls.list_supported_models(exclude=["compressed_url_sources"]) - - hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]] - - # Check if the HF sources list is empty - # Raise an exception causing a fallback to GCS - if not hf_sources: - raise ValueError(f"No HuggingFace source for {model_name}") - - for index, repo_id in enumerate(hf_sources): - try: - return snapshot_download( - repo_id=repo_id, - ignore_patterns=["model.safetensors", "pytorch_model.bin"], - cache_dir=cache_dir, - ) - except (RepositoryNotFoundError, EnvironmentError) as e: - logger.exception(f"Failed to download model from HF source: {repo_id}: {e} ") - if repo_id == hf_sources[-1]: - raise e - logger.info(f"Trying another source: {hf_sources[index+1]}") - - @classmethod - def decompress_to_cache(cls, targz_path: str, cache_dir: str): - """ - Decompresses a .tar.gz file to a cache directory. - - Args: - targz_path (str): Path to the .tar.gz file. - cache_dir (str): Path to the cache directory. - - Returns: - cache_dir (str): Path to the cache directory. - """ - # Check if targz_path exists and is a file - if not os.path.isfile(targz_path): - raise ValueError(f"{targz_path} does not exist or is not a file.") - - # Check if targz_path is a .tar.gz file - if not targz_path.endswith(".tar.gz"): - raise ValueError(f"{targz_path} is not a .tar.gz file.") - - try: - # Open the tar.gz file - with tarfile.open(targz_path, "r:gz") as tar: - # Extract all files into the cache directory - tar.extractall(path=cache_dir) - except tarfile.TarError as e: - # If any error occurs while opening or extracting the tar.gz file, - # delete the cache directory (if it was created in this function) - # and raise the error again - if "tmp" in cache_dir: - shutil.rmtree(cache_dir) - raise ValueError(f"An error occurred while decompressing {targz_path}: {e}") - - return cache_dir - - def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path: - """ - Retrieves a model from Google Cloud Storage. - - Args: - model_name (str): The name of the model to retrieve. - cache_dir (str): The path to the cache directory. - - Raises: - ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. - - Returns: - Path: The path to the model directory. - """ - fast_model_name = f"fast-{model_name.split('/')[-1]}" - - model_dir = Path(cache_dir) / fast_model_name - if model_dir.exists(): - return model_dir - - model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz" - - models = self.list_supported_models(exclude=["hf_sources"]) - - compressed_url_sources = [ - item for model in models if model["model"] == model_name for item in model["compressed_url_sources"] - ] - - # Check if the GCS sources list is empty after falling back from HF - # A model should always have at least one source - if not compressed_url_sources: - raise ValueError(f"No GCS source for {model_name}") - - for index, source in enumerate(compressed_url_sources): - try: - self.download_file_from_gcs( - source, - output_path=str(model_tar_gz), - ) - except (RuntimeError, PermissionError) as e: - logger.exception(f"Failed to download model from GCS source: {source}: {e} ") - if source == compressed_url_sources[-1]: - raise e - logger.info(f"Trying another source: {compressed_url_sources[index+1]}") - - self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir) - assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}" - - model_tar_gz.unlink() - - return model_dir - - @gcs_fallback - def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path: - """ - Retrieves a model from HuggingFace Hub. - Args: - model_name (str): The name of the model to retrieve. - cache_dir (str): The path to the cache directory. - Returns: - Path: The path to the model directory. - """ - - return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir)) - - @classmethod - def assert_model_name(cls, model_name: str): - assert "/" in model_name, "model_name must be in the format / e.g. BAAI/bge-base-en" - - models = cls.list_supported_models() - model_names = [model["model"] for model in models] - if model_name not in model_names: - raise ValueError( - f"{model_name} is not a supported model.\n" - f"Try one of {', '.join(model_names)}.\n" - f"Use the 'list_supported_models()' method to get the model information." - ) - - def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: - """ - Embeds a list of text passages into a list of embeddings. - - Args: - texts (Iterable[str]): The list of texts to embed. - **kwargs: Additional keyword argument to pass to the embed method. - - Yields: - Iterable[np.ndarray]: The embeddings. - """ - - yield from self.embed((f"passage: {t}" for t in texts), **kwargs) - - def query_embed(self, query: str) -> Iterable[np.ndarray]: - """ - Embeds a query - - Args: - query (str): The query to search for. - - Returns: - Iterable[np.ndarray]: The embeddings. - """ - - # Prepend "query: " to the query - query = f"query: {query}" - # Embed the query - query_embedding = self.embed([query]) - return query_embedding - - -class FlagEmbedding(Embedding): - """ - Implementation of the Flag Embedding model. - - Args: - Embedding (_type_): _description_ - """ - - def __init__( - self, - model_name: str = "BAAI/bge-small-en-v1.5", - max_length: int = 512, - cache_dir: str = None, - threads: int = None, - ): - """ - Args: - model_name (str): The name of the model to use. - max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512. - cache_dir (str, optional): The path to the cache directory. - Can be set using the `FASTEMBED_CACHE_PATH` env variable. - Defaults to `fastembed_cache` in the system's temp directory. - threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. - - Raises: - ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. - """ - - self.assert_model_name(model_name) - - self.model_name = model_name - - if cache_dir is None: - default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") - cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) - cache_dir.mkdir(parents=True, exist_ok=True) - - self._cache_dir = cache_dir - self._model_dir = self.retrieve_model_hf(model_name, cache_dir) - self._max_length = max_length - - self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads) - - def embed( - self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None - ) -> Iterable[np.ndarray]: - """ - Encode a list of documents into list of embeddings. - We use mean pooling with attention so that the model can handle variable-length inputs. - - Args: - documents: Iterator of documents or single document to embed - batch_size: Batch size for encoding -- higher values will use more memory, but be faster - parallel: - If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. - If 0, use all available cores. - If None, don't use data-parallel processing, use default onnxruntime threading instead. - - Returns: - List of embeddings, one per document - """ - is_small = False - - if isinstance(documents, str): - documents = [documents] - is_small = True - - if isinstance(documents, list): - if len(documents) < batch_size: - is_small = True - - if parallel == 0: - parallel = os.cpu_count() - - if parallel is None or is_small: - for batch in iter_batch(documents, batch_size): - embeddings, _ = self.model.onnx_embed(batch) - yield from normalize(embeddings[:, 0]).astype(np.float32) - else: - start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" - params = { - "path": self._model_dir, - "model_name": self.model_name, - "max_length": self._max_length, - } - pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method) - for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): - embeddings, _ = batch - yield from normalize(embeddings[:, 0]).astype(np.float32) - - @classmethod - def list_supported_models( - cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"] - ) -> List[Dict[str, Any]]: - """Lists the supported models. - - Args: - exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"]. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the model information. - """ - # jina models are not supported by this class - return [ - model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai") - ] - - -class DefaultEmbedding(FlagEmbedding): - """ - Implementation of the default Flag Embedding model. - - Args: - FlagEmbedding (_type_): _description_ - """ - - def __init__( - self, - model_name: str = "BAAI/bge-small-en-v1.5", - max_length: int = 512, - cache_dir: Optional[str] = None, - threads: Optional[int] = None, - ): - super().__init__(model_name, max_length=max_length, cache_dir=cache_dir, threads=threads) - - -class OpenAIEmbedding(Embedding): - def __init__(self): - # Initialize your OpenAI model here - # self.model = ... - ... - - def embed(self, texts, batch_size: int = 256, parallel: int = None): - # Use your OpenAI model to embed the texts - # return self.model.embed(texts) - raise NotImplementedError - - -class JinaEmbedding(Embedding): - def __init__( - self, - model_name: str = "jinaai/jina-embeddings-v2-base-en", - max_length: int = 512, - cache_dir: str = None, - threads: int = None, - ): - """ - Args: - model_name (str): The name of the model to use. - max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512. - cache_dir (str, optional): The path to the cache directory. - Can be set using the `FASTEMBED_CACHE_PATH` env variable. - Defaults to `fastembed_cache` in the system's temp directory. - threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. - Raises: - ValueError: If the model_name is not in the format / e.g. jinaai/jina-embeddings-v2-base-en. - """ - self.assert_model_name(model_name) - - self.model_name = model_name - - if cache_dir is None: - default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") - cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) - cache_dir.mkdir(parents=True, exist_ok=True) - - self._cache_dir = cache_dir - self._model_dir = self.retrieve_model_hf(model_name, cache_dir) - self._max_length = max_length - - self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads) - - def embed( - self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None - ) -> Iterable[np.ndarray]: - """ - Encode a list of documents into list of embeddings. - We use mean pooling with attention so that the model can handle variable-length inputs. - Args: - documents: Iterator of documents or single document to embed - batch_size: Batch size for encoding -- higher values will use more memory, but be faster - parallel: - If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. - If 0, use all available cores. - If None, don't use data-parallel processing, use default onnxruntime threading instead. - Returns: - List of embeddings, one per document - """ - is_small = False - - if isinstance(documents, str): - documents = [documents] - is_small = True - - if isinstance(documents, list): - if len(documents) < batch_size: - is_small = True - - if parallel == 0: - parallel = os.cpu_count() - - if parallel is None or is_small: - for batch in iter_batch(documents, batch_size): - embeddings, attn_mask = self.model.onnx_embed(batch) - yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) - else: - start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" - params = { - "path": self._model_dir, - "model_name": self.model_name, - "max_length": self._max_length, - } - pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method) - for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): - embeddings, attn_mask = batch - yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32) - - @classmethod - def list_supported_models( - cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"] - ) -> List[Dict[str, Any]]: - """Lists the supported models. - - Args: - exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"]. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the model information. - """ - # only jina models are supported by this class - return [ - model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai") - ] - - @staticmethod - def mean_pooling(model_output, attention_mask): - token_embeddings = model_output - input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float) - - sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1) - mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None) - - return sum_embeddings / mask_sum + super().__init__(model_name, cache_dir, threads, **kwargs) diff --git a/fastembed/image/__init__.py b/fastembed/image/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastembed/sparse/__init__.py b/fastembed/sparse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastembed/text/__init__.py b/fastembed/text/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastembed/text/e5_onnx_embedding.py b/fastembed/text/e5_onnx_embedding.py new file mode 100644 index 00000000..949e4de3 --- /dev/null +++ b/fastembed/text/e5_onnx_embedding.py @@ -0,0 +1,38 @@ +from typing import Type, List, Dict, Any + +import numpy as np + +from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker, EmbeddingWorker +from fastembed.text.onnx_models import supported_multilingual_e5_models + + +class E5OnnxEmbedding(OnnxTextEmbedding): + + @classmethod + def _get_worker_class(cls) -> Type["EmbeddingWorker"]: + return E5OnnxEmbeddingWorker + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_multilingual_e5_models + + def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Preprocess the onnx input. + """ + onnx_input.pop("token_type_ids", None) + return onnx_input + + +class E5OnnxEmbeddingWorker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + ) -> E5OnnxEmbedding: + return E5OnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1) diff --git a/fastembed/text/jina_onnx_embedding.py b/fastembed/text/jina_onnx_embedding.py new file mode 100644 index 00000000..89164db6 --- /dev/null +++ b/fastembed/text/jina_onnx_embedding.py @@ -0,0 +1,47 @@ +from typing import Type, List, Dict, Any, Tuple + +import numpy as np + +from fastembed.common.models import normalize +from fastembed.text.onnx_embedding import OnnxTextEmbedding, EmbeddingWorker, OnnxTextEmbeddingWorker +from fastembed.text.onnx_models import supported_jina_models + + +class JinaOnnxEmbedding(OnnxTextEmbedding): + + @classmethod + def _get_worker_class(cls) -> Type[EmbeddingWorker]: + return JinaEmbeddingWorker + + @classmethod + def mean_pooling(cls, model_output, attention_mask) -> np.ndarray: + token_embeddings = model_output + input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float) + + sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1) + mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None) + + return sum_embeddings / mask_sum + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_jina_models + + @classmethod + def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + embeddings, attn_mask = output + return normalize(cls.mean_pooling(embeddings, attn_mask)).astype(np.float32) + + +class JinaEmbeddingWorker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + ) -> OnnxTextEmbedding: + return JinaOnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py new file mode 100644 index 00000000..00100b06 --- /dev/null +++ b/fastembed/text/onnx_embedding.py @@ -0,0 +1,207 @@ +import os +from multiprocessing import get_all_start_methods +from typing import List, Dict, Any, Tuple, Union, Iterable, Type + +import numpy as np +import onnxruntime as ort + +from fastembed.common.model_management import locate_model_file +from fastembed.common.models import load_tokenizer, normalize +from fastembed.common.utils import define_cache_dir, iter_batch +from fastembed.parallel_processor import ParallelWorkerPool, Worker +from fastembed.text.onnx_models import supported_flag_models +from fastembed.text.text_embedding_base import TextEmbeddingBase + + +class OnnxTextEmbedding(TextEmbeddingBase): + """Implementation of the Flag Embedding model.""" + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_flag_models + + @classmethod + def _get_model_description(cls, model_name: str) -> Dict[str, Any]: + """ + Gets the model description from the model_name. + + Args: + model_name (str): The name of the model. + + raises: + ValueError: If the model_name is not supported. + + Returns: + Dict[str, Any]: The model description. + """ + for model in cls.list_supported_models(): + if model_name == model["model"]: + return model + + raise ValueError(f"Model {model_name} is not supported in FlagEmbedding.") + + def __init__( + self, + model_name: str = "BAAI/bge-small-en-v1.5", + cache_dir: str = None, + threads: int = None, + **kwargs, + ): + """ + Args: + model_name (str): The name of the model to use. + cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. + + Raises: + ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. + """ + + super().__init__(model_name, cache_dir, threads, **kwargs) + + self.model_name = model_name + self._model_description = self._get_model_description(model_name) + + self._cache_dir = define_cache_dir(cache_dir) + self._model_dir = self.download_model(self._model_description, self._cache_dir) + self._max_length = 512 + + model_path = locate_model_file(self._model_dir, ["model.onnx", "model_optimized.onnx"]) + + # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers + onnx_providers = ["CPUExecutionProvider"] + + so = ort.SessionOptions() + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + if self.threads is not None: + so.intra_op_num_threads = self.threads + so.inter_op_num_threads = self.threads + + self.tokenizer = load_tokenizer(model_dir=self._model_dir, max_length=self._max_length) + self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so) + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: int = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + is_small = False + + if isinstance(documents, str): + documents = [documents] + is_small = True + + if isinstance(documents, list): + if len(documents) < batch_size: + is_small = True + + if parallel == 0: + parallel = os.cpu_count() + + if parallel is None or is_small: + for batch in iter_batch(documents, batch_size): + yield from self._post_process_onnx_output(self.onnx_embed(batch)) + else: + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + params = { + "model_name": self.model_name, + "cache_dir": str(self._cache_dir), + } + pool = ParallelWorkerPool(parallel, self._get_worker_class(), start_method=start_method) + for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): + yield from self._post_process_onnx_output(batch) + + @classmethod + def _get_worker_class(cls) -> Type["EmbeddingWorker"]: + return OnnxTextEmbeddingWorker + + def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Preprocess the onnx input. + """ + return onnx_input + + @classmethod + def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]): + embeddings, _ = output + return normalize(embeddings[:, 0]).astype(np.float32) + + def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]: + encoded = self.tokenizer.encode_batch(documents) + input_ids = np.array([e.ids for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) + + onnx_input = { + "input_ids": np.array(input_ids, dtype=np.int64), + "attention_mask": np.array(attention_mask, dtype=np.int64), + "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64) + } + + onnx_input = self._preprocess_onnx_input(onnx_input) + + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0] + return embeddings, attention_mask + + +class EmbeddingWorker(Worker): + + def init_embedding( + self, + model_name: str, + cache_dir: str, + ) -> OnnxTextEmbedding: + raise NotImplementedError() + + def __init__( + self, + model_name: str, + cache_dir: str, + ): + self.model = self.init_embedding(model_name, cache_dir) + + @classmethod + def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker": + return cls( + model_name=model_name, + cache_dir=cache_dir, + ) + + def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]: + for idx, batch in items: + embeddings, attn_mask = self.model.onnx_embed(batch) + yield idx, (embeddings, attn_mask) + + +class OnnxTextEmbeddingWorker(EmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + ) -> OnnxTextEmbedding: + return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1) diff --git a/fastembed/text/onnx_models.py b/fastembed/text/onnx_models.py new file mode 100644 index 00000000..b54433ec --- /dev/null +++ b/fastembed/text/onnx_models.py @@ -0,0 +1,133 @@ +supported_flag_models = [ + { + "model": "BAAI/bge-base-en", + "dim": 768, + "description": "Base English model", + "size_in_GB": 0.5, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", + }, + }, + { + "model": "BAAI/bge-base-en-v1.5", + "dim": 768, + "description": "Base English model, v1.5", + "size_in_GB": 0.44, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", + "hf": "qdrant/bge-base-en-v1.5-onnx-q", + } + }, + { + "model": "BAAI/bge-large-en-v1.5-quantized", + "dim": 1024, + "description": "Large English model, v1.5", + "size_in_GB": 1.34, + "sources": { + "hf": "qdrant/bge-large-en-v1.5-onnx-q", + } + }, + { + "model": "BAAI/bge-large-en-v1.5", + "dim": 1024, + "description": "Large English model, v1.5", + "size_in_GB": 1.34, + "sources": { + "hf": "qdrant/bge-large-en-v1.5-onnx", + } + }, + { + "model": "BAAI/bge-small-en", + "dim": 384, + "description": "Fast English model", + "size_in_GB": 0.2, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", + } + }, + # { + # "model": "BAAI/bge-small-en", + # "dim": 384, + # "description": "Fast English model", + # "size_in_GB": 0.2, + # "hf_sources": [], + # "compressed_url_sources": [ + # "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz", + # "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz" + # ] + # }, + { + "model": "BAAI/bge-small-en-v1.5", + "dim": 384, + "description": "Fast and Default English model", + "size_in_GB": 0.13, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz", + "hf": "qdrant/bge-small-en-v1.5-onnx-q", + } + }, + { + "model": "BAAI/bge-small-zh-v1.5", + "dim": 512, + "description": "Fast and recommended Chinese model", + "size_in_GB": 0.1, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", + } + }, + { # todo: it is not a flag embedding + "model": "sentence-transformers/all-MiniLM-L6-v2", + "dim": 384, + "description": "Sentence Transformer model, MiniLM-L6-v2", + "size_in_GB": 0.09, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", + "hf": "qdrant/all-MiniLM-L6-v2-onnx", + } + }, + # { + # "model": "sentence-transformers/all-MiniLM-L6-v2", + # "dim": 384, + # "description": "Sentence Transformer model, MiniLM-L6-v2", + # "size_in_GB": 0.09, + # "hf_sources": [ + # "qdrant/all-MiniLM-L6-v2-onnx" + # ], + # "compressed_url_sources": [ + # "https://storage.googleapis.com/qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz", + # "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz" + # ] + # } +] + +supported_multilingual_e5_models = [ + { + "model": "intfloat/multilingual-e5-large", + "dim": 1024, + "description": "Multilingual model, e5-large. Recommend using this model for non-English languages", + "size_in_GB": 2.24, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", + "hf": "qdrant/multilingual-e5-large-onnx", + } + } +] + +supported_jina_models = [ + { + "model": "jinaai/jina-embeddings-v2-base-en", + "dim": 768, + "description": "English embedding model supporting 8192 sequence length", + "size_in_GB": 0.55, + "sources": { + "hf": "xenova/jina-embeddings-v2-base-en" + } + }, + { + "model": "jinaai/jina-embeddings-v2-small-en", + "dim": 512, + "description": "English embedding model supporting 8192 sequence length", + "size_in_GB": 0.13, + "sources": {"hf": "xenova/jina-embeddings-v2-small-en"} + } +] diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py new file mode 100644 index 00000000..d2f9f74b --- /dev/null +++ b/fastembed/text/text_embedding.py @@ -0,0 +1,91 @@ +from typing import Optional, Union, Iterable, List, Dict, Any, Type + +import numpy as np + +from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding +from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding +from fastembed.text.onnx_embedding import OnnxTextEmbedding +from fastembed.text.text_embedding_base import TextEmbeddingBase + + +class TextEmbedding(TextEmbeddingBase): + EMBEDDINGS_REGISTRY: List[Type[TextEmbeddingBase]] = [ + OnnxTextEmbedding, + E5OnnxEmbedding, + JinaOnnxEmbedding, + ] + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """ + Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + + Example: + ``` + [ + { + "model": "intfloat/multilingual-e5-large", + "dim": 1024, + "description": "Multilingual model, e5-large. Recommend using this model for non-English languages", + "size_in_GB": 2.24, + "sources": { + "gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", + "hf": "qdrant/multilingual-e5-large-onnx", + } + } + ] + ``` + """ + result = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding.list_supported_models()) + return result + + def __init__( + self, + model_name: str = "BAAI/bge-small-en-v1.5", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + + self.model = None + for embedding in self.EMBEDDINGS_REGISTRY: + supported_models = embedding.list_supported_models() + if any(model_name == model["model"] for model in supported_models): + self.model = embedding(model_name, cache_dir, threads, **kwargs) + break + + if self.model is None: + raise ValueError( + f"Model {model_name} is not supported in TextEmbedding." + "Please check the supported models using `TextEmbedding.list_supported_models()`" + ) + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: int = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed(documents, batch_size, parallel, **kwargs) diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py new file mode 100644 index 00000000..8e2de732 --- /dev/null +++ b/fastembed/text/text_embedding_base.py @@ -0,0 +1,62 @@ +from typing import Optional, Union, Iterable, List, Dict, Any + +import numpy as np + +from fastembed.common.model_management import ModelManagement + + +class TextEmbeddingBase(ModelManagement): + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + raise NotImplementedError() + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: int = None, + **kwargs, + ) -> Iterable[np.ndarray]: + raise NotImplementedError() + + def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds a list of text passages into a list of embeddings. + + Args: + texts (Iterable[str]): The list of texts to embed. + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + yield from self.embed(texts, **kwargs) + + def query_embed(self, query: str, **kwargs) -> np.ndarray: + """ + Embeds a query + + Args: + query (str): The query to search for. + + Returns: + np.ndarray: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + query_embedding = list(self.embed([query], **kwargs))[0] + return query_embedding + diff --git a/tests/test_onnx_embeddings.py b/tests/test_onnx_embeddings.py index 743dd3e5..6c1185d3 100644 --- a/tests/test_onnx_embeddings.py +++ b/tests/test_onnx_embeddings.py @@ -27,6 +27,9 @@ def test_embedding(embedding_class): if is_ubuntu_ci == "false" and model_desc["size_in_GB"] > 1: continue + if model_desc["model"] not in CANONICAL_VECTOR_VALUES: + continue + dim = model_desc["dim"] model = embedding_class(model_name=model_desc["model"]) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py new file mode 100644 index 00000000..e44dcf23 --- /dev/null +++ b/tests/test_text_onnx_embeddings.py @@ -0,0 +1,81 @@ +import os + +import numpy as np +import pytest + +from fastembed.text.text_embedding import TextEmbedding + +CANONICAL_VECTOR_VALUES = { + "BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]), + "BAAI/bge-small-en-v1.5": np.array([0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434]), + "BAAI/bge-small-zh-v1.5": np.array([-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762]), + "BAAI/bge-base-en": np.array([0.0115, 0.0372, 0.0295, 0.0121, 0.0346]), + "BAAI/bge-base-en-v1.5": np.array([0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045]), + "BAAI/bge-large-en-v1.5": np.array([0.03434538, 0.03316108, 0.02191251, -0.03713358, -0.01577825]), + "BAAI/bge-large-en-v1.5-quantized": np.array([0.03434538, 0.03316108, 0.02191251, -0.03713358, -0.01577825]), + "sentence-transformers/all-MiniLM-L6-v2": np.array([0.0259, 0.0058, 0.0114, 0.0380, -0.0233]), + "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]), + "jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]), + "jinaai/jina-embeddings-v2-base-en": np.array([-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]), +} + + +def test_embedding(): + is_ubuntu_ci = os.getenv("IS_UBUNTU_CI") + + for model_desc in TextEmbedding.list_supported_models(): + if is_ubuntu_ci == "false" and model_desc["size_in_GB"] > 1: + continue + + dim = model_desc["dim"] + model = TextEmbedding(model_name=model_desc["model"]) + + docs = ["hello world", "flag embedding"] + embeddings = list(model.embed(docs)) + embeddings = np.stack(embeddings, axis=0) + assert embeddings.shape == (2, dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] + assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3), model_desc["model"] + + +@pytest.mark.parametrize( + "n_dims,model_name", + [ + (384, "BAAI/bge-small-en-v1.5"), + (768, "jinaai/jina-embeddings-v2-base-en") + ] +) +def test_batch_embedding(n_dims, model_name): + model = TextEmbedding(model_name=model_name) + + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (200, n_dims) + + +@pytest.mark.parametrize( + "n_dims,model_name", + [ + (384, "BAAI/bge-small-en-v1.5"), + (768, "jinaai/jina-embeddings-v2-base-en") + ] +) +def test_parallel_processing(n_dims, model_name): + model = TextEmbedding(model_name=model_name) + + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) + + assert embeddings.shape == (200, n_dims) + assert np.allclose(embeddings, embeddings_2, atol=1e-3) + assert np.allclose(embeddings, embeddings_3, atol=1e-3)