Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
Co-authored-by: George Panchuk <[email protected]>
  • Loading branch information
generall and joein committed Feb 1, 2024
1 parent 2e3e550 commit fb93e40
Show file tree
Hide file tree
Showing 18 changed files with 976 additions and 687 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pip install fastembed
## 📖 Usage

```python
from fastembed.embedding import FlagEmbedding as Embedding
from fastembed import TextEmbedding
from typing import List
import numpy as np

Expand All @@ -36,7 +36,7 @@ documents: List[str] = [
"passage: This is an example passage.",
"fastembed is supported by and maintained by Qdrant." # You can leave out the prefix but it's recommended
]
embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
embedding_model = TextEmbedding(model_name="BAAI/bge-base-en")
embeddings: List[np.ndarray] = list(embedding_model.embed(documents)) # Note the list() call - this is a generator
```

Expand Down
1 change: 1 addition & 0 deletions fastembed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from fastembed.text.text_embedding import TextEmbedding
Empty file added fastembed/common/__init__.py
Empty file.
211 changes: 211 additions & 0 deletions fastembed/common/model_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import os
import shutil
import tarfile
from pathlib import Path
from typing import List, Optional, Dict, Any

import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from tqdm import tqdm
from loguru import logger


def locate_model_file(model_dir: Path, file_names: List[str]):
"""
Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used
by Optimum and Qdrant
"""
if not model_dir.is_dir():
raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

for path in model_dir.rglob("*.onnx"):
for file_name in file_names:
if path.is_file() and path.name == file_name:
return path

raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")


class ModelManagement:

@classmethod
def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
"""
Downloads a file from Google Cloud Storage.
Args:
url (str): The URL to download the file from.
output_path (str): The path to save the downloaded file to.
show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
Returns:
str: The path to the downloaded file.
"""

if os.path.exists(output_path):
return output_path
response = requests.get(url, stream=True)

# Handle HTTP errors
if response.status_code == 403:
raise PermissionError(
"Authentication Error: You do not have permission to access this resource. "
"Please check your credentials."
)

# Get the total size of the file
total_size_in_bytes = int(response.headers.get("content-length", 0))

# Warn if the total size is zero
if total_size_in_bytes == 0:
print(f"Warning: Content-length header is missing or zero in the response from {url}.")

show_progress = total_size_in_bytes and show_progress

with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
with open(output_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk: # Filter out keep-alive new chunks
progress_bar.update(len(chunk))
file.write(chunk)
return output_path

@classmethod
def download_files_from_huggingface(cls, hf_source_repo: str, cache_dir: Optional[str] = None) -> str:
"""
Downloads a model from HuggingFace Hub.
Args:
hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
cache_dir (Optional[str]): The path to the cache directory.
Returns:
Path: The path to the model directory.
"""

return snapshot_download(
repo_id=hf_source_repo,
ignore_patterns=["model.safetensors", "pytorch_model.bin"],
cache_dir=cache_dir,
)

@classmethod
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
"""
Decompresses a .tar.gz file to a cache directory.
Args:
targz_path (str): Path to the .tar.gz file.
cache_dir (str): Path to the cache directory.
Returns:
cache_dir (str): Path to the cache directory.
"""
# Check if targz_path exists and is a file
if not os.path.isfile(targz_path):
raise ValueError(f"{targz_path} does not exist or is not a file.")

# Check if targz_path is a .tar.gz file
if not targz_path.endswith(".tar.gz"):
raise ValueError(f"{targz_path} is not a .tar.gz file.")

try:
# Open the tar.gz file
with tarfile.open(targz_path, "r:gz") as tar:
# Extract all files into the cache directory
tar.extractall(path=cache_dir)
except tarfile.TarError as e:
# If any error occurs while opening or extracting the tar.gz file,
# delete the cache directory (if it was created in this function)
# and raise the error again
if "tmp" in cache_dir:
shutil.rmtree(cache_dir)
raise ValueError(f"An error occurred while decompressing {targz_path}: {e}")

return cache_dir

@classmethod
def retrieve_model_gcs(
cls,
model_name: str,
source_url: str,
cache_dir: str
) -> Path:
fast_model_name = f"fast-{model_name.split('/')[-1]}"

cache_tmp_dir = Path(cache_dir) / "tmp"
model_tmp_dir = cache_tmp_dir / fast_model_name
model_dir = Path(cache_dir) / fast_model_name

if model_dir.exists():
return model_dir

if model_tmp_dir.exists():
shutil.rmtree(model_tmp_dir)

cache_tmp_dir.mkdir(parents=True, exist_ok=True)

model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"

cls.download_file_from_gcs(
source_url,
output_path=str(model_tar_gz),
)

cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"

model_tar_gz.unlink()
# Rename from tmp to final name is atomic
model_tmp_dir.rename(model_dir)

return model_dir

@classmethod
def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
"""
Downloads a model from HuggingFace Hub or Google Cloud Storage.
Args:
model (Dict[str, Any]): The model description.
Example:
```
{
"model": "BAAI/bge-base-en-v1.5",
"dim": 768,
"description": "Base English model, v1.5",
"size_in_GB": 0.44,
"sources": {
"gcp": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
"hf": "qdrant/bge-base-en-v1.5-onnx-q",
}
}
```
cache_dir (str): The path to the cache directory.
Returns:
Path: The path to the downloaded model directory.
"""

hf_source = model.get("sources", {}).get("hf")
gcp_source = model.get("sources", {}).get("gcp")

if hf_source:
try:
return Path(cls.download_files_from_huggingface(
hf_source,
cache_dir=str(cache_dir)
))
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
logger.error(
f"Could not download model from HuggingFace: {e}"
"Falling back to other sources."
)

if gcp_source:
return cls.retrieve_model_gcs(
model["model"],
gcp_source,
str(cache_dir)
)

raise ValueError(f"Could not download model {model['model']} from any source.")
52 changes: 52 additions & 0 deletions fastembed/common/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import json
from pathlib import Path

import numpy as np
from tokenizers import Tokenizer, AddedToken


def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
config_path = model_dir / "config.json"
if not config_path.exists():
raise ValueError(f"Could not find config.json in {model_dir}")

tokenizer_path = model_dir / "tokenizer.json"
if not tokenizer_path.exists():
raise ValueError(f"Could not find tokenizer.json in {model_dir}")

tokenizer_config_path = model_dir / "tokenizer_config.json"
if not tokenizer_config_path.exists():
raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")

tokens_map_path = model_dir / "special_tokens_map.json"
if not tokens_map_path.exists():
raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

with open(str(config_path)) as config_file:
config = json.load(config_file)

with open(str(tokenizer_config_path)) as tokenizer_config_file:
tokenizer_config = json.load(tokenizer_config_file)

with open(str(tokens_map_path)) as tokens_map_file:
tokens_map = json.load(tokens_map_file)

tokenizer = Tokenizer.from_file(str(tokenizer_path))
tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"])

for token in tokens_map.values():
if isinstance(token, str):
tokenizer.add_special_tokens([token])
elif isinstance(token, dict):
tokenizer.add_special_tokens([AddedToken(**token)])

return tokenizer


def normalize(input_array, p=2, dim=1, eps= 1e-12) -> np.ndarray:
# Calculate the Lp norm along the specified dimension
norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
norm = np.maximum(norm, eps) # Avoid division by zero
normalized_array = input_array / norm
return normalized_array
33 changes: 33 additions & 0 deletions fastembed/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import tempfile
from itertools import islice
from pathlib import Path
from typing import Union, Iterable, Generator, Optional


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
while source_iter:
b = list(islice(source_iter, size))
if len(b) == 0:
break
yield b


def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
"""
Define the cache directory for fastembed
"""
if cache_dir is None:
default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
else:
cache_dir = Path(cache_dir)

cache_dir.mkdir(parents=True, exist_ok=True)

return cache_dir
Loading

0 comments on commit fb93e40

Please sign in to comment.