Commit d0005d4: ruff

generall committed Feb 2, 2024
1 parent bc2dbaf commit d0005d4
Showing 10 changed files with 66 additions and 106 deletions.
1 change: 0 additions & 1 deletion fastembed/__init__.py
@@ -1 +0,0 @@
-from fastembed.text.text_embedding import TextEmbedding
24 changes: 4 additions & 20 deletions fastembed/common/model_management.py
@@ -28,7 +28,6 @@ def locate_model_file(model_dir: Path, file_names: List[str]):
 
 
 class ModelManagement:
-
     @classmethod
     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
         """
@@ -124,12 +123,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         return cache_dir
 
     @classmethod
-    def retrieve_model_gcs(
-        cls,
-        model_name: str,
-        source_url: str,
-        cache_dir: str
-    ) -> Path:
+    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
         fast_model_name = f"fast-{model_name.split('/')[-1]}"
 
         cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -191,21 +185,11 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
 
         if hf_source:
             try:
-                return Path(cls.download_files_from_huggingface(
-                    hf_source,
-                    cache_dir=str(cache_dir)
-                ))
+                return Path(cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)))
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
-                logger.error(
-                    f"Could not download model from HuggingFace: {e}"
-                    "Falling back to other sources."
-                )
+                logger.error(f"Could not download model from HuggingFace: {e}" "Falling back to other sources.")
 
         if url_source:
-            return cls.retrieve_model_gcs(
-                model["model"],
-                url_source,
-                str(cache_dir)
-            )
+            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
 
         raise ValueError(f"Could not download model {model['model']} from any source.")
2 changes: 1 addition & 1 deletion fastembed/common/models.py
@@ -44,7 +44,7 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
     return tokenizer
 
 
-def normalize(input_array, p=2, dim=1, eps= 1e-12) -> np.ndarray:
+def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
     # Calculate the Lp norm along the specified dimension
     norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
     norm = np.maximum(norm, eps)  # Avoid division by zero
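The normalize helper touched by this hunk performs Lp normalization with an epsilon clamp so zero vectors do not trigger a division by zero. A runnable sketch of the same idea (the final division is implied by the truncated hunk rather than shown in it):

import numpy as np


def normalize(input_array: np.ndarray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> np.ndarray:
    # Lp norm along `dim`; clamp to `eps` to avoid dividing by zero.
    norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
    norm = np.maximum(norm, eps)
    return input_array / norm


vectors = np.array([[3.0, 4.0], [0.0, 0.0]])
print(normalize(vectors))  # [[0.6, 0.8], [0.0, 0.0]] -- zero row stays zero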
15 changes: 6 additions & 9 deletions fastembed/embedding.py
@@ -4,21 +4,18 @@
 
 from fastembed.text.text_embedding import TextEmbedding
 
-logger.warning(
-    "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated."
-    " Use TextEmbedding instead."
-)
+logger.warning("DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated." " Use TextEmbedding instead.")
 
 DefaultEmbedding = TextEmbedding
 FlagEmbedding = TextEmbedding
 
 
 class JinaEmbedding(TextEmbedding):
     def __init__(
-            self,
-            model_name: str = "jinaai/jina-embeddings-v2-base-en",
-            cache_dir: Optional[str] = None,
-            threads: Optional[int] = None,
-            **kwargs
+        self,
+        model_name: str = "jinaai/jina-embeddings-v2-base-en",
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
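Since this commit also drops the TextEmbedding re-export from fastembed/__init__.py (first hunk above), code that used the deprecated aliases migrates like this (a sketch based on the aliases shown in this file):

# Before: deprecated aliases that now just point at TextEmbedding
# and trigger the warning above at import time.
# from fastembed.embedding import DefaultEmbedding, FlagEmbedding

# After: import TextEmbedding from its own module.
from fastembed.text.text_embedding import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")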
9 changes: 4 additions & 5 deletions fastembed/text/e5_onnx_embedding.py
@@ -13,13 +13,12 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
             "hf": "qdrant/multilingual-e5-large-onnx",
-        }
+        },
     }
 ]
 
 
 class E5OnnxEmbedding(OnnxTextEmbedding):
-
     @classmethod
     def _get_worker_class(cls) -> Type["EmbeddingWorker"]:
         return E5OnnxEmbeddingWorker
@@ -43,8 +42,8 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
 
 class E5OnnxEmbeddingWorker(OnnxTextEmbeddingWorker):
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> E5OnnxEmbedding:
         return E5OnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
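The worker classes in this file and below all follow the same shape: each parallel worker process constructs its own copy of the model pinned to a single thread, so process-level parallelism does not oversubscribe the CPU. A condensed sketch of the pattern (simplified from the classes in this diff; the subclass name is illustrative):

from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker


class MyEmbeddingWorker(OnnxTextEmbeddingWorker):
    def init_embedding(self, model_name: str, cache_dir: str) -> OnnxTextEmbedding:
        # One single-threaded model instance per worker process.
        return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)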
15 changes: 6 additions & 9 deletions fastembed/text/jina_onnx_embedding.py
@@ -11,22 +11,19 @@
         "dim": 768,
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.55,
-        "sources": {
-            "hf": "xenova/jina-embeddings-v2-base-en"
-        }
+        "sources": {"hf": "xenova/jina-embeddings-v2-base-en"},
     },
     {
         "model": "jinaai/jina-embeddings-v2-small-en",
         "dim": 512,
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.13,
-        "sources": {"hf": "xenova/jina-embeddings-v2-small-en"}
-    }
+        "sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
+    },
 ]
 
 
 class JinaOnnxEmbedding(OnnxTextEmbedding):
-
     @classmethod
     def _get_worker_class(cls) -> Type[EmbeddingWorker]:
         return JinaEmbeddingWorker
@@ -58,8 +55,8 @@ def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
 
 class JinaEmbeddingWorker(OnnxTextEmbeddingWorker):
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         return JinaOnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
55 changes: 27 additions & 28 deletions fastembed/text/onnx_embedding.py
@@ -29,7 +29,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
             "hf": "qdrant/bge-base-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-large-en-v1.5-quantized",
@@ -38,7 +38,7 @@
         "size_in_GB": 1.34,
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-large-en-v1.5",
@@ -47,7 +47,7 @@
         "size_in_GB": 1.34,
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx",
-        }
+        },
     },
     {
         "model": "BAAI/bge-small-en",
@@ -56,7 +56,7 @@
         "size_in_GB": 0.2,
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
-        }
+        },
     },
     # {
     #     "model": "BAAI/bge-small-en",
@@ -77,7 +77,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz",
             "hf": "qdrant/bge-small-en-v1.5-onnx-q",
-        }
+        },
     },
     {
         "model": "BAAI/bge-small-zh-v1.5",
@@ -86,7 +86,7 @@
         "size_in_GB": 0.1,
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
-        }
+        },
     },
     { # todo: it is not a flag embedding
         "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -96,7 +96,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
             "hf": "qdrant/all-MiniLM-L6-v2-onnx",
-        }
+        },
     },
     # {
     #     "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -147,11 +147,11 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
         raise ValueError(f"Model {model_name} is not supported in FlagEmbedding.")
 
     def __init__(
-            self,
-            model_name: str = "BAAI/bge-small-en-v1.5",
-            cache_dir: str = None,
-            threads: int = None,
-            **kwargs,
+        self,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: str = None,
+        threads: int = None,
+        **kwargs,
     ):
         """
         Args:
@@ -190,11 +190,11 @@ def __init__(
         self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so)
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -260,7 +260,7 @@ def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         onnx_input = {
             "input_ids": np.array(input_ids, dtype=np.int64),
             "attention_mask": np.array(attention_mask, dtype=np.int64),
-            "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64)
+            "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
         }
 
         onnx_input = self._preprocess_onnx_input(onnx_input)
@@ -271,18 +271,17 @@
 
 
 class EmbeddingWorker(Worker):
-
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         raise NotImplementedError()
 
     def __init__(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ):
         self.model = self.init_embedding(model_name, cache_dir)
 
@@ -301,8 +300,8 @@ def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
 
 class OnnxTextEmbeddingWorker(EmbeddingWorker):
     def init_embedding(
-            self,
-            model_name: str,
-            cache_dir: str,
+        self,
+        model_name: str,
+        cache_dir: str,
     ) -> OnnxTextEmbedding:
         return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1)
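For reference, the embed signature reformatted above is a generator API: it yields one np.ndarray per document, and parallel switches to multi-process workers. Typical usage (a sketch; the 384-dim shape matches the bge-small entry in the tests below):

from fastembed.text.text_embedding import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
docs = ["fastembed is fast", "onnxruntime does the heavy lifting"]

# embed() yields one vector per document; batch_size and parallel
# control batching and the number of worker processes.
embeddings = list(model.embed(docs, batch_size=256, parallel=None))
print(len(embeddings), embeddings[0].shape)  # 2 (384,)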
20 changes: 10 additions & 10 deletions fastembed/text/text_embedding.py
@@ -45,11 +45,11 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         return result
 
     def __init__(
-            self,
-            model_name: str = "BAAI/bge-small-en-v1.5",
-            cache_dir: Optional[str] = None,
-            threads: Optional[int] = None,
-            **kwargs
+        self,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
@@ -67,11 +67,11 @@ def __init__(
         )
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
19 changes: 6 additions & 13 deletions fastembed/text/text_embedding_base.py
@@ -10,23 +10,17 @@ class TextEmbeddingBase(ModelManagement):
     def list_supported_models(cls) -> List[Dict[str, Any]]:
         raise NotImplementedError()
 
-    def __init__(
-        self,
-        model_name: str,
-        cache_dir: Optional[str] = None,
-        threads: Optional[int] = None,
-        **kwargs
-    ):
+    def __init__(self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, **kwargs):
         self.model_name = model_name
         self.cache_dir = cache_dir
         self.threads = threads
 
     def embed(
-            self,
-            documents: Union[str, Iterable[str]],
-            batch_size: int = 256,
-            parallel: int = None,
-            **kwargs,
+        self,
+        documents: Union[str, Iterable[str]],
+        batch_size: int = 256,
+        parallel: int = None,
+        **kwargs,
     ) -> Iterable[np.ndarray]:
         raise NotImplementedError()
 
@@ -59,4 +53,3 @@ def query_embed(self, query: str, **kwargs) -> np.ndarray:
         # This is model-specific, so that different models can have specialized implementations
         query_embedding = list(self.embed([query], **kwargs))[0]
         return query_embedding
-
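query_embed above simply embeds the query like any other document; the hook exists so subclasses can specialize it. A hypothetical override (the "query: " prefix convention is an illustration borrowed from E5-style models, not this base class's behavior):

import numpy as np

from fastembed.text.text_embedding import TextEmbedding


class PrefixedQueryEmbedding(TextEmbedding):
    def query_embed(self, query: str, **kwargs) -> np.ndarray:
        # Hypothetical specialization: prepend the prefix some models expect.
        return list(self.embed([f"query: {query}"], **kwargs))[0]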
12 changes: 2 additions & 10 deletions tests/test_text_onnx_embeddings.py
@@ -40,11 +40,7 @@ def test_embedding():
 
 
 @pytest.mark.parametrize(
-    "n_dims,model_name",
-    [
-        (384, "BAAI/bge-small-en-v1.5"),
-        (768, "jinaai/jina-embeddings-v2-base-en")
-    ]
+    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")]
 )
 def test_batch_embedding(n_dims, model_name):
     model = TextEmbedding(model_name=model_name)
@@ -57,11 +53,7 @@ def test_batch_embedding(n_dims, model_name):
 
 
 @pytest.mark.parametrize(
-    "n_dims,model_name",
-    [
-        (384, "BAAI/bge-small-en-v1.5"),
-        (768, "jinaai/jina-embeddings-v2-base-en")
-    ]
+    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")]
 )
 def test_parallel_processing(n_dims, model_name):
     model = TextEmbedding(model_name=model_name)
