diff --git a/fastembed/embedding.py b/fastembed/embedding.py
index f642c65e..ec40d0ec 100644
--- a/fastembed/embedding.py
+++ b/fastembed/embedding.py
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 import shutil
 import tarfile
@@ -461,8 +462,8 @@ def __init__(
         Args:
             model_name (str): The name of the model to use.
             max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
-            cache_dir (str, optional): The path to the cache directory.
-                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+            cache_dir (str, optional): The path to the cache directory. \
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.

@@ -484,7 +485,7 @@ def __init__(
             max_threads=threads)

     def embed(
-        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
+        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -497,6 +498,7 @@ def embed(
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                 If 0, use all available cores.
                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.

         Returns:
             List of embeddings, one per document
@@ -513,22 +515,26 @@ def embed(

         if parallel == 0:
             parallel = os.cpu_count()
-
-        if parallel is None or is_small:
-            for batch in iter_batch(documents, batch_size):
-                embeddings, _ = self.model.onnx_embed(batch)
-                yield from normalize(embeddings[:, 0]).astype(np.float32)
-        else:
-            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            params = {
-                "path": self._model_dir,
-                "model_name": self.model_name,
-                "max_length": self._max_length,
-            }
-            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
-            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                embeddings, _ = batch
-                yield from normalize(embeddings[:, 0]).astype(np.float32)
+
+        with tqdm(total=len(documents) if hasattr(documents, "__len__") else None, disable=not show_progress) as progress_bar:
+            batch_iterable = iter_batch(documents, batch_size)
+            if parallel is None or is_small:
+                for batch in batch_iterable:
+                    embeddings, _ = self.model.onnx_embed(batch)
+                    yield from normalize(embeddings[:, 0]).astype(np.float32)
+                    progress_bar.update(len(embeddings))
+            else:
+                start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+                params = {
+                    "path": self._model_dir,
+                    "model_name": self.model_name,
+                    "max_length": self._max_length,
+                }
+                pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
+                for batch in pool.ordered_map(batch_iterable, **params):
+                    embeddings, _ = batch
+                    yield from normalize(embeddings[:, 0]).astype(np.float32)
+                    progress_bar.update(len(embeddings))

     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
@@ -581,8 +587,8 @@ def __init__(
         Args:
             model_name (str): The name of the model to use.
             max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
-            cache_dir (str, optional): The path to the cache directory.
-                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+            cache_dir (str, optional): The path to the cache directory. \
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
@@ -603,7 +609,7 @@ def __init__(
             max_threads=threads)

     def embed(
-        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
+        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -615,6 +621,7 @@ def embed(
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                 If 0, use all available cores.
                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
             Returns:
                 List of embeddings, one per document
         """
@@ -631,21 +638,25 @@ def embed(
         if parallel == 0:
             parallel = os.cpu_count()

-        if parallel is None or is_small:
-            for batch in iter_batch(documents, batch_size):
-                embeddings, attn_mask = self.model.onnx_embed(batch)
-                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
-        else:
-            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            params = {
-                "path": self._model_dir,
-                "model_name": self.model_name,
-                "max_length": self._max_length,
-            }
-            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
-            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                embeddings, attn_mask = batch
-                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        with tqdm(total=len(documents) if hasattr(documents, "__len__") else None, disable=not show_progress) as progress_bar:
+            batch_iterable = iter_batch(documents, batch_size)
+            if parallel is None or is_small:
+                for batch in batch_iterable:
+                    embeddings, attn_mask = self.model.onnx_embed(batch)
+                    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+                    progress_bar.update(len(embeddings))
+            else:
+                start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+                params = {
+                    "path": self._model_dir,
+                    "model_name": self.model_name,
+                    "max_length": self._max_length,
+                }
+                pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
+                for batch in pool.ordered_map(batch_iterable, **params):
+                    embeddings, attn_mask = batch
+                    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+                    progress_bar.update(len(embeddings))

     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
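For reviewers, a minimal usage sketch of the new flag (the `DefaultEmbedding` class and the sample corpus below are illustrative assumptions, not part of this diff):

```python
# Hypothetical usage of the new show_progress parameter; DefaultEmbedding is
# assumed to be the existing public wrapper exported by fastembed.embedding.
from fastembed.embedding import DefaultEmbedding

embedder = DefaultEmbedding()
corpus = ["fastembed is light and fast", "tqdm shows encoding progress"] * 500

# show_progress defaults to True, so a tqdm bar appears during encoding;
# pass False to keep logs clean, e.g. in CI.
vectors = list(embedder.embed(corpus, batch_size=256, show_progress=False))
```

Note that `embed` yields lazily: the bar advances one batch at a time as the caller consumes embeddings, and `total` falls back to `None` (count-only display) when `documents` has no `len()`, such as a generator.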