feat: HuggingFace download support for FlagEmbedding (#94)
* feat: HF support for FlagEmbedding

* chore: update docstring embedding.py

* refactor: GCS URLs models.json

* chore: toLower() models.json

* chore: update tqdm declarative

* chore: exclude keys list_supported_models

* chore: review changes
Anush008 authored Jan 23, 2024
1 parent f87330f commit ede507e
Showing 5 changed files with 303 additions and 119 deletions.
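In outline: model retrieval now goes through HuggingFace Hub first and falls back to the GCS tarballs only when the Hub download fails. A minimal usage sketch of the resulting behavior (class and method names are taken from the diff below; the default model choice is the library's, not spelled out here):

from fastembed.embedding import DefaultEmbedding

# With this commit, the constructor resolves model files via
# retrieve_model_hf(), which falls back to retrieve_model_gcs()
# if the HuggingFace download fails.
model = DefaultEmbedding()
vectors = list(model.embed(["FastEmbed is maintained by Qdrant."]))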
251 changes: 134 additions & 117 deletions fastembed/embedding.py
@@ -1,3 +1,4 @@
import functools
import json
import os
import shutil
@@ -7,13 +8,16 @@
from itertools import islice
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnxruntime as ort
import requests
from tokenizers import AddedToken, Tokenizer
from tqdm import tqdm
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from loguru import logger

from fastembed.parallel_processor import ParallelWorkerPool, Worker

@@ -179,71 +183,44 @@ class Embedding(ABC):
_type_: _description_
"""

# Internal helper decorator to maintain backward compatibility
# by supporting a fallback to download from Google Cloud Storage (GCS)
# if the model couldn't be downloaded from HuggingFace.
def gcs_fallback(hf_download_method: Callable) -> Callable:
@functools.wraps(hf_download_method)
def wrapper(self, *args, **kwargs):
try:
return hf_download_method(self, *args, **kwargs)
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
logger.exception(
f"Could not download model from HuggingFace: {e}"
"Falling back to download from Google Cloud Storage"
)
return self.retrieve_model_gcs(*args, **kwargs)

return wrapper
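
The same fallback pattern in isolation, for readers unfamiliar with decorator-based reroutes — a standalone sketch, not fastembed code (with_fallback, Downloader, and fetch are illustrative names):

import functools

def with_fallback(primary):
    @functools.wraps(primary)  # keeps the wrapped method's name and docstring
    def wrapper(self, *args, **kwargs):
        try:
            return primary(self, *args, **kwargs)
        except OSError:
            return self.fallback(*args, **kwargs)  # reroute to the backup source
    return wrapper

class Downloader:
    def fallback(self, item):
        return f"{item} (from backup)"

    @with_fallback
    def fetch(self, item):
        raise OSError("primary source unavailable")

print(Downloader().fetch("model"))  # -> "model (from backup)"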

@abstractmethod
def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: Optional[int] = None) -> List[np.ndarray]:
raise NotImplementedError

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
"""Lists the supported models.
Args:
exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return [
{
"model": "BAAI/bge-small-en",
"dim": 384,
"description": "Fast English model",
"size_in_GB": 0.2,
},
{
"model": "BAAI/bge-small-en-v1.5",
"dim": 384,
"description": "Fast and Default English model",
"size_in_GB": 0.13,
},
{
"model": "BAAI/bge-small-zh-v1.5",
"dim": 512,
"description": "Fast and recommended Chinese model",
"size_in_GB": 0.1,
},
{
"model": "BAAI/bge-base-en",
"dim": 768,
"description": "Base English model",
"size_in_GB": 0.5,
},
{
"model": "BAAI/bge-base-en-v1.5",
"dim": 768,
"description": "Base English model, v1.5",
"size_in_GB": 0.44,
},
{
"model": "sentence-transformers/all-MiniLM-L6-v2",
"dim": 384,
"description": "Sentence Transformer model, MiniLM-L6-v2",
"size_in_GB": 0.09,
},
{
"model": "intfloat/multilingual-e5-large",
"dim": 1024,
"description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
"size_in_GB": 2.24,
},
{
"model": "jinaai/jina-embeddings-v2-base-en",
"dim": 768,
"description": " English embedding model supporting 8192 sequence length",
"size_in_GB": 0.55,
},
{
"model": "jinaai/jina-embeddings-v2-small-en",
"dim": 512,
"description": " English embedding model supporting 8192 sequence length",
"size_in_GB": 0.13,
},
]
models_file_path = Path(__file__).with_name("models.json")
with open(models_file_path, "r") as file:
models = json.load(file)

models = [{k: v for k, v in model.items() if k not in exclude} for model in models]

return models
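
A usage sketch of the new exclude parameter; the printed entry mirrors one of the hardcoded records removed above, and the exact models.json contents may differ:

from fastembed.embedding import Embedding

# Drop the source-URL keys so only user-facing metadata remains.
models = Embedding.list_supported_models(exclude=["hf_sources", "compressed_url_sources"])
print(models[0])
# e.g. {"model": "BAAI/bge-small-en", "dim": 384, "description": "Fast English model", "size_in_GB": 0.2}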

@classmethod
def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
@@ -276,48 +253,49 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
if total_size_in_bytes == 0:
print(f"Warning: Content-length header is missing or zero in the response from {url}.")

# Initialize the progress bar
progress_bar = (
tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
if total_size_in_bytes and show_progress
else None
)
show_progress = total_size_in_bytes and show_progress

# Attempt to download the file
try:
with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
with open(output_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024): # Adjust chunk size to your preference
for chunk in response.iter_content(chunk_size=1024):
if chunk: # Filter out keep-alive new chunks
if progress_bar is not None:
progress_bar.update(len(chunk))
progress_bar.update(len(chunk))
file.write(chunk)
except Exception as e:
print(f"An error occurred while trying to download the file: {str(e)}")
return
finally:
if progress_bar is not None:
progress_bar.close()
return output_path
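
A usage sketch of the streaming helper. The URL shown follows the qdrant-fastembed bucket pattern from the removed lines below; with this commit the actual URLs come from each model's compressed_url_sources in models.json:

from fastembed.embedding import Embedding

# Streams the archive to disk in 1 KiB chunks; the tqdm bar is disabled
# when the server omits Content-Length or show_progress=False.
path = Embedding.download_file_from_gcs(
    "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz",
    output_path="fast-bge-small-en.tar.gz",
)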

@classmethod
def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[str] = None) -> str:
"""
Downloads a model from HuggingFace Hub.
Args:
repod_id (str): The HF hub id (name) of the model to retrieve.
model_name (str): Name of the model to download.
cache_dir (Optional[str]): The path to the cache directory.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
Returns:
Path: The path to the model directory.
"""
from huggingface_hub import snapshot_download

return snapshot_download(
repo_id=repod_id,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)
models = cls.list_supported_models(exclude=["compressed_url_sources"])

hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]

# Check if the HF sources list is empty
# Raise an exception causing a fallback to GCS
if not hf_sources:
raise ValueError(f"No HuggingFace source for {model_name}")

for index, repo_id in enumerate(hf_sources):
try:
return snapshot_download(
repo_id=repo_id,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)
except (RepositoryNotFoundError, EnvironmentError) as e:
logger.exception(f"Failed to download model from HF source: {repo_id}: {e} ")
if repo_id == hf_sources[-1]:
raise e
logger.info(f"Trying another source: {hf_sources[index+1]}")
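
A usage sketch of the Hub path; the method walks the model's hf_sources entries in order and returns the snapshot directory of the first repo id that downloads:

from fastembed.embedding import Embedding

# Raises ValueError when the model lists no HF source, and re-raises the
# last download error once every candidate repo has failed; when called
# via retrieve_model_hf(), either failure triggers the GCS fallback.
model_dir = Embedding.download_files_from_huggingface(
    "BAAI/bge-small-en", cache_dir="models_cache"
)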

@classmethod
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
@@ -368,28 +346,36 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
Returns:
Path: The path to the model directory.
"""

assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

fast_model_name = f"fast-{model_name.split('/')[-1]}"

model_dir = Path(cache_dir) / fast_model_name
if model_dir.exists():
return model_dir

model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
try:
self.download_file_from_gcs(
f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
output_path=str(model_tar_gz),
)
except PermissionError:
simple_model_name = model_name.replace("/", "-")
print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
self.download_file_from_gcs(
f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
output_path=str(model_tar_gz),
)

models = self.list_supported_models(exclude=["hf_sources"])

compressed_url_sources = [
item for model in models if model["model"] == model_name for item in model["compressed_url_sources"]
]

# Check if the GCS sources list is empty after falling back from HF
# A model should always have at least one source
if not compressed_url_sources:
raise ValueError(f"No GCS source for {model_name}")

for index, source in enumerate(compressed_url_sources):
try:
self.download_file_from_gcs(
source,
output_path=str(model_tar_gz),
)
except (RuntimeError, PermissionError) as e:
logger.exception(f"Failed to download model from GCS source: {source}: {e} ")
if source == compressed_url_sources[-1]:
raise e
logger.info(f"Trying another source: {compressed_url_sources[index+1]}")

self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
@@ -398,23 +384,31 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:

return model_dir
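
A sketch of the direct GCS path (embedding_model stands for an already-constructed subclass instance such as FlagEmbedding):

# The first call downloads and unpacks the tarball listed under
# compressed_url_sources; repeat calls return the cached
# fast-bge-small-en directory immediately.
model_dir = embedding_model.retrieve_model_gcs("BAAI/bge-small-en", cache_dir="models_cache")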

@gcs_fallback
def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
"""
Retrieves a model from HuggingFace Hub.
Args:
model_name (str): The name of the model to retrieve.
cache_dir (str): The path to the cache directory.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
Returns:
Path: The path to the model directory.
"""

assert (
"/" in model_name
), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
return Path(self.download_files_from_huggingface(model_name=model_name, cache_dir=cache_dir))
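
With the @gcs_fallback decorator applied, one call covers both sources (sketch; embedding_model again stands for an instantiated subclass):

# Tries HuggingFace Hub first; on EnvironmentError, RepositoryNotFoundError,
# or ValueError the decorator logs the failure and retries via retrieve_model_gcs().
model_dir = embedding_model.retrieve_model_hf("BAAI/bge-small-en", cache_dir="models_cache")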

@classmethod
def assert_model_name(cls, model_name: str):
assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
models = cls.list_supported_models()
model_names = [model["model"] for model in models]
if model_name not in model_names:
raise ValueError(
f"{model_name} is not a supported model.\n"
f"Try one of {', '.join(model_names)}.\n"
f"Use the 'list_supported_models()' method to get the model information."
)
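
A sketch of the validation both constructors now run before downloading anything:

from fastembed.embedding import FlagEmbedding

try:
    FlagEmbedding.assert_model_name("not-a-real/model")
except ValueError as err:
    print(err)  # "not-a-real/model is not a supported model. Try one of ..."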

def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
"""
@@ -475,6 +469,9 @@ def __init__(
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""

self.assert_model_name(model_name)

self.model_name = model_name

if cache_dir is None:
@@ -483,7 +480,7 @@
cache_dir.mkdir(parents=True, exist_ok=True)

self._cache_dir = cache_dir
self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
self._max_length = max_length

self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length, max_threads=threads)
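
End to end, the constructor validates the name, then retrieves via HuggingFace with the GCS fallback; a usage sketch with a model from the supported list:

from fastembed.embedding import FlagEmbedding

model = FlagEmbedding("BAAI/bge-small-en-v1.5")
# embed() lazily yields one L2-normalized float32 vector per input text.
vectors = list(model.embed(["FastEmbed is fast."]))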
@@ -536,12 +533,21 @@ def embed(
yield from normalize(embeddings[:, 0]).astype(np.float32)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
def list_supported_models(
cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
) -> List[Dict[str, Any]]:
"""Lists the supported models.
Args:
exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
# jina models are not supported by this class
return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
return [
model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
]
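
The jinaai filter in action (sketch):

from fastembed.embedding import FlagEmbedding

# Jina models are served by the Jina-specific class, so they are hidden here.
assert all(not m["model"].startswith("jinaai") for m in FlagEmbedding.list_supported_models())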


class DefaultEmbedding(FlagEmbedding):
@@ -591,8 +597,10 @@ def __init__(
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
ValueError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
"""
self.assert_model_name(model_name)

self.model_name = model_name

if cache_dir is None:
@@ -652,12 +660,21 @@ def embed(
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Lists the supported models.
def list_supported_models(
cls, exclude: List[str] = ["compressed_url_sources", "hf_sources"]
) -> List[Dict[str, Any]]:
"""Lists the supported models.
Args:
exclude (List[str], optional): Keys to exclude from the result. Defaults to ["compressed_url_sources", "hf_sources"].
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
# only jina models are supported by this class
return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
return [
model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
]
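
And the complementary filter (sketch; the class name JinaEmbedding is inferred from context, since its definition line falls outside the shown hunks):

from fastembed.embedding import JinaEmbedding

assert all(m["model"].startswith("jinaai") for m in JinaEmbedding.list_supported_models())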

@staticmethod
def mean_pooling(model_output, attention_mask):