Skip to content

Commit

Permalink
HF sources for all models
Browse files — browse the repository at this point in the history
  • Loading branch information
I8dNLo committed Dec 24, 2024
1 parent 4200834 commit d44ae27
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union

import numpy as np

Expand All @@ -16,6 +16,7 @@
"license": "mit",
"size_in_GB": 0.42,
"sources": {
"hf": "Qdrant/fast-bge-base-en",
"url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
},
"model_file": "model_optimized.onnx",
Expand Down Expand Up @@ -50,6 +51,7 @@
"license": "mit",
"size_in_GB": 0.13,
"sources": {
"hf": "Qdrant/bge-small-en",
"url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
},
"model_file": "model_optimized.onnx",
Expand All @@ -72,6 +74,7 @@
"license": "mit",
"size_in_GB": 0.09,
"sources": {
"hf": "Qdrant/bge-small-zh-v1.5",
"url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
},
"model_file": "model_optimized.onnx",
Expand Down Expand Up @@ -165,15 +168,16 @@
"model_file": "onnx/model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"model": "akshayballal/colpali-v1.2-merged",
"dim": 128,
"description": "",
"license": "mit",
"size_in_GB": 6.08,
"sources": {
"hf": "jinaai/jina-clip-v1",
"hf": "akshayballal/colpali-v1.2-merged-onnx",
},
"model_file": "onnx/text_model.onnx",
"additional_files": ["model.onnx_data"],
"model_file": "model.onnx",
},
]

Expand All @@ -182,12 +186,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
"""Implementation of the Flag Embedding model."""

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
def list_supported_models(cls) -> List[Dict[str, Any]]:
"""
Lists the supported models.
Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_onnx_models

Expand All @@ -198,7 +202,7 @@ def __init__(
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
device_ids: Optional[List[int]] = None,
lazy_load: bool = False,
device_id: Optional[int] = None,
**kwargs,
Expand All @@ -214,7 +218,7 @@ def __init__(
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
Defaults to False.
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
Expand Down Expand Up @@ -287,22 +291,16 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
return OnnxTextEmbeddingWorker

def _preprocess_onnx_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
self, onnx_input: Dict[str, np.ndarray], **kwargs
) -> Dict[str, np.ndarray]:
"""
Preprocess the onnx input.
"""
return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
embeddings = output.model_output
if embeddings.ndim == 3: # (batch_size, seq_len, embedding_dim)
processed_embeddings = embeddings[:, 0]
elif embeddings.ndim == 2: # (batch_size, embedding_dim)
processed_embeddings = embeddings
else:
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
return normalize(processed_embeddings).astype(np.float32)
return normalize(embeddings[:, 0]).astype(np.float32)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down

0 comments on commit d44ae27

Please sign in to comment.