feat: embedding progress bar (#71)
* feat: embedding progress

* refactor: with auto __close__

* refactor: with __exit__ tqdm
Anush008 authored Dec 12, 2023
1 parent e274dd0 commit 2c7fee3
Showing 1 changed file with 48 additions and 37 deletions.
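
Note: this commit adds a `show_progress` flag (default `True`) to both `embed` methods and wraps the batch loop in a `tqdm` context manager (the "`__exit__` tqdm" refactor in the commit messages), so the bar is closed even if the generator is abandoned early. A minimal usage sketch, assuming the `DefaultEmbedding` class exported by `fastembed.embedding` around this point in history (hypothetical illustration, not part of the diff):

```python
from fastembed.embedding import DefaultEmbedding  # assumed export at this commit

model = DefaultEmbedding()
documents = ["hello world", "fastembed makes embeddings fast"] * 500

# embed() is a generator: the tqdm bar advances one batch at a time as
# the caller consumes results, and closes when iteration completes.
vectors = list(model.embed(documents, batch_size=256))

# Pass show_progress=False to suppress the bar, e.g. in logs or CI.
quiet = list(model.embed(documents, show_progress=False))
```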
85 changes: 48 additions & 37 deletions fastembed/embedding.py
@@ -1,4 +1,5 @@
import json
import math
import os
import shutil
import tarfile
@@ -461,8 +462,8 @@ def __init__(
Args:
model_name (str): The name of the model to use.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
cache_dir (str, optional): The path to the cache directory. \
Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
@@ -484,7 +485,7 @@ def __init__(
max_threads=threads)

def embed(
self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
) -> Iterable[np.ndarray]:
"""
Encode a list of documents into list of embeddings.
@@ -497,6 +498,7 @@ def embed(
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
Returns:
List of embeddings, one per document
@@ -513,22 +515,26 @@ def embed(

if parallel == 0:
parallel = os.cpu_count()

if parallel is None or is_small:
for batch in iter_batch(documents, batch_size):
embeddings, _ = self.model.onnx_embed(batch)
yield from normalize(embeddings[:, 0]).astype(np.float32)
else:
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
params = {
"path": self._model_dir,
"model_name": self.model_name,
"max_length": self._max_length,
}
pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
embeddings, _ = batch
yield from normalize(embeddings[:, 0]).astype(np.float32)

with tqdm(total=len(documents), disable=not show_progress) as progress_bar:
batch_iterable = iter_batch(documents, batch_size)
if parallel is None or is_small:
for batch in batch_iterable:
embeddings, _ = self.model.onnx_embed(batch)
yield from normalize(embeddings[:, 0]).astype(np.float32)
progress_bar.update(len(embeddings))
else:
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
params = {
"path": self._model_dir,
"model_name": self.model_name,
"max_length": self._max_length,
}
pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
for batch in pool.ordered_map(batch_iterable, **params):
embeddings, _ = batch
yield from normalize(embeddings[:, 0]).astype(np.float32)
progress_bar.update(len(embeddings))

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
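
The `with tqdm(...)` wrapper above guarantees the bar is cleaned up whether the generator is exhausted, closed early, or unwound by an exception. A standalone sketch of the same pattern, with a simplified stand-in for fastembed's `iter_batch` helper (illustrative only, not the library's code):

```python
from typing import Iterable, List

from tqdm import tqdm

def iter_batch(items: List[str], size: int) -> Iterable[List[str]]:
    # Simplified stand-in for fastembed's iter_batch helper.
    for i in range(0, len(items), size):
        yield items[i:i + size]

def encode(documents: List[str], batch_size: int = 256, show_progress: bool = True):
    # disable=not show_progress turns the bar into a no-op without branching
    # the loop; __exit__ closes it even if the generator is closed early.
    with tqdm(total=len(documents), disable=not show_progress) as progress_bar:
        for batch in iter_batch(documents, batch_size):
            yield from batch  # placeholder for per-batch embedding + pooling
            progress_bar.update(len(batch))
```

Because `embed` is lazy, the bar only advances as downstream code pulls embeddings; materializing with `list(...)` drives it to completion. The two classes in this file differ only in pooling: this one takes the CLS vector (`embeddings[:, 0]`), while the variant below mean-pools over the attention mask.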
@@ -581,8 +587,8 @@ def __init__(
Args:
model_name (str): The name of the model to use.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
cache_dir (str, optional): The path to the cache directory. \
Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
Raises:
@@ -603,7 +609,7 @@ def __init__(
max_threads=threads)

def embed(
self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
) -> Iterable[np.ndarray]:
"""
Encode a list of documents into list of embeddings.
@@ -615,6 +621,7 @@ def embed(
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
Returns:
List of embeddings, one per document
"""
@@ -631,21 +638,25 @@ def embed(
if parallel == 0:
parallel = os.cpu_count()

if parallel is None or is_small:
for batch in iter_batch(documents, batch_size):
embeddings, attn_mask = self.model.onnx_embed(batch)
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
else:
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
params = {
"path": self._model_dir,
"model_name": self.model_name,
"max_length": self._max_length,
}
pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
embeddings, attn_mask = batch
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
with tqdm(total=len(documents), disable=not show_progress) as progress_bar:
batch_iterable = iter_batch(documents, batch_size)
if parallel is None or is_small:
for batch in batch_iterable:
embeddings, attn_mask = self.model.onnx_embed(batch)
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
progress_bar.update(len(embeddings))
else:
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
params = {
"path": self._model_dir,
"model_name": self.model_name,
"max_length": self._max_length,
}
pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
for batch in pool.ordered_map(batch_iterable, **params):
embeddings, attn_mask = batch
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
progress_bar.update(len(embeddings))

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
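
For reference, this second class pools with `mean_pooling` over the attention mask rather than taking the CLS token. A minimal NumPy sketch of what such a helper conventionally computes (fastembed's actual implementation may differ):

```python
import numpy as np

def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over real (non-padding) positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len),
    with 1 for real tokens and 0 for padding.
    """
    mask = attention_mask[..., None].astype(np.float32)   # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)        # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)        # avoid divide-by-zero
    return summed / counts
```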
