Refactoring of internal structure #105

Merged · 6 commits · Feb 2, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -26,7 +26,7 @@ pip install fastembed
## 📖 Usage

```python
-from fastembed.embedding import FlagEmbedding as Embedding
+from fastembed import TextEmbedding
from typing import List
import numpy as np

@@ -36,7 +36,7 @@ documents: List[str] = [
"passage: This is an example passage.",
"fastembed is supported by and maintained by Qdrant." # You can leave out the prefix but it's recommended
]
-embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
+embedding_model = TextEmbedding(model_name="BAAI/bge-base-en")
embeddings: List[np.ndarray] = list(embedding_model.embed(documents)) # Note the list() call - this is a generator
```
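This is the user-facing side of the refactor: `FlagEmbedding` is replaced by the new top-level `TextEmbedding` class, and `max_length` is no longer a constructor argument; truncation is configured when the tokenizer is loaded (see `load_tokenizer` in `fastembed/common/models.py` below).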

Empty file added fastembed/__init__.py
Empty file added fastembed/common/__init__.py
196 changes: 196 additions & 0 deletions fastembed/common/model_management.py
@@ -0,0 +1,196 @@
import os
import shutil
import tarfile
from pathlib import Path
from typing import List, Optional, Dict, Any

import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from tqdm import tqdm
from loguru import logger


def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
"""
    Find the model file in either the Transformers.js-style `onnx` subdirectory layout or the flat
    layout with the model weights at the top level, as used by Optimum and Qdrant.
"""
if not model_dir.is_dir():
raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

for file_name in file_names:
file_paths = [path for path in model_dir.rglob(file_name) if path.is_file()]

if file_paths:
return file_paths[0]

raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")


class ModelManagement:
@classmethod
def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
"""
Downloads a file from Google Cloud Storage.
Args:
url (str): The URL to download the file from.
output_path (str): The path to save the downloaded file to.
show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
Returns:
str: The path to the downloaded file.
"""

if os.path.exists(output_path):
return output_path
response = requests.get(url, stream=True)

# Handle HTTP errors
if response.status_code == 403:
raise PermissionError(
"Authentication Error: You do not have permission to access this resource. "
"Please check your credentials."
)

# Get the total size of the file
total_size_in_bytes = int(response.headers.get("content-length", 0))

# Warn if the total size is zero
if total_size_in_bytes == 0:
print(f"Warning: Content-length header is missing or zero in the response from {url}.")

show_progress = total_size_in_bytes and show_progress

with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
with open(output_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk: # Filter out keep-alive new chunks
progress_bar.update(len(chunk))
file.write(chunk)
return output_path
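
A usage sketch; the URL is one of the GCS sources listed in `download_model` below, and the output path is arbitrary:

```python
# Streams the tarball in 1 KiB chunks with a tqdm progress bar;
# returns immediately if output_path already exists.
local_tar = ModelManagement.download_file_from_gcs(
    "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
    output_path="/tmp/fast-bge-base-en-v1.5.tar.gz",
)
```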

@classmethod
def download_files_from_huggingface(cls, hf_source_repo: str, cache_dir: Optional[str] = None) -> str:
"""
Downloads a model from HuggingFace Hub.
Args:
hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
cache_dir (Optional[str]): The path to the cache directory.
Returns:
            str: The path to the model directory.
"""

return snapshot_download(
repo_id=hf_source_repo,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)
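
The `ignore_patterns` skip the PyTorch and safetensors weights, since only the ONNX artifacts are needed. A minimal call, using the repo named in the docstring (the default HF cache is used when `cache_dir` is omitted):

```python
model_dir = ModelManagement.download_files_from_huggingface("qdrant/all-MiniLM-L6-v2-onnx")
```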

@classmethod
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
"""
Decompresses a .tar.gz file to a cache directory.
Args:
targz_path (str): Path to the .tar.gz file.
cache_dir (str): Path to the cache directory.
Returns:
cache_dir (str): Path to the cache directory.
"""
# Check if targz_path exists and is a file
if not os.path.isfile(targz_path):
raise ValueError(f"{targz_path} does not exist or is not a file.")

# Check if targz_path is a .tar.gz file
if not targz_path.endswith(".tar.gz"):
raise ValueError(f"{targz_path} is not a .tar.gz file.")

try:
# Open the tar.gz file
with tarfile.open(targz_path, "r:gz") as tar:
# Extract all files into the cache directory
tar.extractall(path=cache_dir)
except tarfile.TarError as e:
# If any error occurs while opening or extracting the tar.gz file,
# delete the cache directory (if it was created in this function)
# and raise the error again
if "tmp" in cache_dir:
shutil.rmtree(cache_dir)
raise ValueError(f"An error occurred while decompressing {targz_path}: {e}")

return cache_dir

@classmethod
def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
fast_model_name = f"fast-{model_name.split('/')[-1]}"

cache_tmp_dir = Path(cache_dir) / "tmp"
model_tmp_dir = cache_tmp_dir / fast_model_name
model_dir = Path(cache_dir) / fast_model_name

if model_dir.exists():
return model_dir

if model_tmp_dir.exists():
shutil.rmtree(model_tmp_dir)

cache_tmp_dir.mkdir(parents=True, exist_ok=True)

model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"

cls.download_file_from_gcs(
source_url,
output_path=str(model_tar_gz),
)

cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"

model_tar_gz.unlink()
# Rename from tmp to final name is atomic
model_tmp_dir.rename(model_dir)

return model_dir
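
Design note: the archive is extracted into a `tmp` subdirectory first and only renamed into its final location at the end; because the rename is atomic, an interrupted or concurrent download cannot leave a half-extracted model directory in the cache.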

@classmethod
def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
"""
Downloads a model from HuggingFace Hub or Google Cloud Storage.
Args:
model (Dict[str, Any]): The model description.
Example:
```
{
"model": "BAAI/bge-base-en-v1.5",
"dim": 768,
"description": "Base English model, v1.5",
"size_in_GB": 0.44,
"sources": {
"url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
"hf": "qdrant/bge-base-en-v1.5-onnx-q",
}
}
```
            cache_dir (Path): The path to the cache directory.
Returns:
Path: The path to the downloaded model directory.
"""

hf_source = model.get("sources", {}).get("hf")
url_source = model.get("sources", {}).get("url")

if hf_source:
try:
return Path(cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)))
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
logger.error(f"Could not download model from HuggingFace: {e}" "Falling back to other sources.")

if url_source:
return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))

raise ValueError(f"Could not download model {model['model']} from any source.")
52 changes: 52 additions & 0 deletions fastembed/common/models.py
@@ -0,0 +1,52 @@
import json
from pathlib import Path

import numpy as np
from tokenizers import Tokenizer, AddedToken


def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
config_path = model_dir / "config.json"
if not config_path.exists():
raise ValueError(f"Could not find config.json in {model_dir}")

tokenizer_path = model_dir / "tokenizer.json"
if not tokenizer_path.exists():
raise ValueError(f"Could not find tokenizer.json in {model_dir}")

tokenizer_config_path = model_dir / "tokenizer_config.json"
if not tokenizer_config_path.exists():
raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")

tokens_map_path = model_dir / "special_tokens_map.json"
if not tokens_map_path.exists():
raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

with open(str(config_path)) as config_file:
config = json.load(config_file)

with open(str(tokenizer_config_path)) as tokenizer_config_file:
tokenizer_config = json.load(tokenizer_config_file)

with open(str(tokens_map_path)) as tokens_map_file:
tokens_map = json.load(tokens_map_file)

tokenizer = Tokenizer.from_file(str(tokenizer_path))
tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"])

for token in tokens_map.values():
if isinstance(token, str):
tokenizer.add_special_tokens([token])
elif isinstance(token, dict):
tokenizer.add_special_tokens([AddedToken(**token)])

return tokenizer
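
A sketch of the loaded tokenizer in use (the model path is hypothetical); truncation and padding are already enabled by `load_tokenizer`, so `encode_batch` returns padded, length-capped encodings:

```python
from pathlib import Path

tokenizer = load_tokenizer(Path("models/fast-bge-base-en"))
encodings = tokenizer.encode_batch(["passage: This is an example passage."])
print(encodings[0].ids, encodings[0].attention_mask)
```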


def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
# Calculate the Lp norm along the specified dimension
norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
norm = np.maximum(norm, eps) # Avoid division by zero
normalized_array = input_array / norm
return normalized_array
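
A quick check of `normalize`: each row of the output has unit L2 norm (values hypothetical):

```python
import numpy as np

embeddings = np.array([[3.0, 4.0], [1.0, 0.0]])
print(np.linalg.norm(normalize(embeddings), axis=1))  # [1. 1.]
```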
33 changes: 33 additions & 0 deletions fastembed/common/utils.py
@@ -0,0 +1,33 @@
import os
import tempfile
from itertools import islice
from pathlib import Path
from typing import Union, Iterable, Generator, Optional


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
    while True:
b = list(islice(source_iter, size))
if len(b) == 0:
break
yield b
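
Because `iter_batch` consumes its input through `iter()`, it also works on one-shot generators, such as the one returned by `embed()`:

```python
squares = (i * i for i in range(5))
print(list(iter_batch(squares, 2)))  # [[0, 1], [4, 9], [16]]
```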


def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
"""
Define the cache directory for fastembed
"""
if cache_dir is None:
default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
else:
cache_dir = Path(cache_dir)

cache_dir.mkdir(parents=True, exist_ok=True)

return cache_dir
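
Resolution order: the explicit argument wins, then the `FASTEMBED_CACHE_PATH` environment variable, then `fastembed_cache` under the system temp directory; the directory is created if missing. A sketch (the env value is hypothetical):

```python
import os

os.environ["FASTEMBED_CACHE_PATH"] = "/tmp/my_fastembed_cache"
print(define_cache_dir())            # /tmp/my_fastembed_cache
print(define_cache_dir("./models"))  # explicit argument wins
```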