Skip to content

Commit

Permalink
feat: Added multiprocessing threshold parameter (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled authored Dec 27, 2024
1 parent ecf022f commit 8074182
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

logger = getLogger(__name__)

_MULTIPROCESSING_THRESHOLD = 10_000 # Minimum number of sentences to use multiprocessing


class StaticModel:
def __init__(
Expand Down Expand Up @@ -175,6 +173,7 @@ def encode_as_sequence(
batch_size: int = 1024,
show_progress_bar: bool = False,
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
) -> list[np.ndarray] | np.ndarray:
"""
Encode a list of sentences as a list of numpy arrays of tokens.
Expand All @@ -191,7 +190,8 @@ def encode_as_sequence(
:param batch_size: The batch size to use.
:param show_progress_bar: Whether to show the progress bar.
:param use_multiprocessing: Whether to use multiprocessing.
By default, this is enabled for inputs > 10k sentences and disabled otherwise.
By default, this is enabled for inputs with more than multiprocessing_threshold sentences and disabled otherwise.
:param multiprocessing_threshold: The minimum number of sentences above which multiprocessing is used.
:return: The encoded sentences with an embedding per token.
"""
was_single = False
Expand All @@ -204,7 +204,7 @@ def encode_as_sequence(
total_batches = math.ceil(len(sentences) / batch_size)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
# Disable parallelism for tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand Down Expand Up @@ -246,6 +246,7 @@ def encode(
max_length: int | None = 512,
batch_size: int = 1024,
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
**kwargs: Any,
) -> np.ndarray:
"""
Expand All @@ -260,7 +261,8 @@ def encode(
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param use_multiprocessing: Whether to use multiprocessing.
By default, this is enabled for inputs > 10k sentences and disabled otherwise.
By default, this is enabled for inputs with more than multiprocessing_threshold sentences and disabled otherwise.
:param multiprocessing_threshold: The minimum number of sentences above which multiprocessing is used.
:param **kwargs: Any additional arguments. These are ignored.
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
"""
Expand All @@ -276,7 +278,7 @@ def encode(
ids = self.tokenize(sentences=sentences, max_length=max_length)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
# Disable parallelism for tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand Down

0 comments on commit 8074182

Please sign in to comment.