Skip to content

Commit

Permalink
Added initial setup for multiprocessed encode
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Dec 26, 2024
1 parent 453527f commit 8248d18
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 27 deletions.
36 changes: 27 additions & 9 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
logger = getLogger(__name__)


from joblib import delayed
from tqdm.auto import tqdm

from model2vec.utils import ProgressParallel


class StaticModel:
def __init__(
self,
Expand Down Expand Up @@ -224,6 +230,7 @@ def encode(
show_progress_bar: bool = False,
max_length: int | None = 512,
batch_size: int = 1024,
use_multiprocessing: bool = True,
**kwargs: Any,
) -> np.ndarray:
"""
Expand All @@ -237,6 +244,7 @@ def encode(
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param use_multiprocessing: Whether to use multiprocessing.
:param **kwargs: Any additional arguments. These are ignored.
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
"""
Expand All @@ -245,19 +253,29 @@ def encode(
sentences = [sentences]
was_single = True

out_arrays: list[np.ndarray] = []
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
# Prepare all batches
sentence_batches = list(self._batch(sentences, batch_size))
total_batches = math.ceil(len(sentences) / batch_size)

out_array = np.concatenate(out_arrays, axis=0)
if use_multiprocessing:
# Use joblib for multiprocessing if requested
results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
)
out_array = np.concatenate(results, axis=0)
else:
# Don't use multiprocessing
out_arrays: list[np.ndarray] = []
for batch in tqdm(
sentence_batches,
total=total_batches,
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
out_array = np.concatenate(out_arrays, axis=0)

if was_single:
return out_array[0]

return out_array

def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
Expand Down
37 changes: 36 additions & 1 deletion model2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,50 @@
from importlib import import_module
from importlib.metadata import metadata
from pathlib import Path
from typing import Iterator, Protocol, cast
from typing import Any, Iterator, Protocol, cast

import numpy as np
import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm

logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            it updates the total dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the parallel jobs inside a tqdm context so progress can be reported."""
        # Bind the bar to the instance via the `as` clause so that print_progress
        # (invoked by joblib from within super().__call__) can update it.
        # NOTE: the original had a redundant `self._pbar = self._pbar` here; removed.
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # If no fixed total was given, we dynamically set the total
            # to the number of tasks dispatched so far.
            self._pbar.total = self.n_dispatched_tasks
        # Move the bar to the number of completed tasks.
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()


class SafeOpenProtocol(Protocol):
"""Protocol to fix safetensors safe open."""

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [

dependencies = [
"jinja2",
"joblib",
"numpy",
"rich",
"safetensors",
Expand Down
36 changes: 19 additions & 17 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8248d18

Please sign in to comment.