
Commit

Resolved merge conflict
Pringled committed Dec 28, 2024
2 parents 5212c0a + 8074182 commit c49b7a0
Showing 13 changed files with 1,215 additions and 638 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -10,6 +10,7 @@ install:

install-no-pre-commit:
uv pip install ".[dev,distill]"
uv pip install "torch<2.5.0"

install-base:
uv sync --extra dev
264 changes: 133 additions & 131 deletions README.md

Large diffs are not rendered by default.

52 changes: 41 additions & 11 deletions model2vec/distill/distillation.py
@@ -1,11 +1,13 @@
from __future__ import annotations

import logging
import re
from typing import Literal, Union

import numpy as np
from huggingface_hub import model_info
from sklearn.decomposition import PCA
from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

@@ -39,6 +41,7 @@ def distill_from_model(
pca_dims: PCADimType = 256,
apply_zipf: bool = True,
use_subword: bool = True,
token_remove_pattern: str | None = r"\[unused\d+\]",
) -> StaticModel:
"""
Distill a StaticModel from a sentence transformer.
@@ -58,8 +61,12 @@ def distill_from_model(
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param apply_zipf: Whether to apply Zipf weighting to the embeddings.
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
:raises: ValueError if the vocabulary contains duplicate tokens.
:raises: ValueError if the regex can't be compiled.
:raises: ValueError if the vocabulary is empty after token removal.
:return: A StaticModel
"""
@@ -81,17 +88,7 @@ def distill_from_model(
if use_subword:
# Create the subword embeddings.
tokens, embeddings = create_output_embeddings_from_model_name(model=model, tokenizer=tokenizer, device=device)

# Remove any unused tokens from the tokenizer and embeddings.
wrong_tokens = [x for x in tokens if x.startswith("[unused")]
vocab = tokenizer.get_vocab()
# Get the ids of the unused token.
wrong_token_ids = [vocab[token] for token in wrong_tokens]
# Remove the unused tokens from the tokenizer.
new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
# Remove the embeddings of the unused tokens.
embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings)
else:
# We need to keep the unk token in the tokenizer.
unk_token = tokenizer.backend_tokenizer.model.unk_token
@@ -136,6 +133,8 @@ def distill_from_model(
model_name = getattr(model, "name_or_path", "")

config = {
"model_type": "model2vec",
"architectures": ["StaticModel"],
"tokenizer_name": model_name,
"apply_pca": pca_dims,
"apply_zipf": apply_zipf,
Expand All @@ -155,6 +154,37 @@ def distill_from_model(
)


def _remove_tokens_and_embeddings(
tokenizer: PreTrainedTokenizerFast, token_remove_pattern: str | None, tokens: list[str], embeddings: np.ndarray
) -> tuple[Tokenizer, np.ndarray]:
if not token_remove_pattern:
return tokenizer.backend_tokenizer, embeddings

try:
token_regex = re.compile(token_remove_pattern)
except re.error as e:
raise ValueError(f"Invalid regex pattern: {token_remove_pattern}") from e
# Remove any unused tokens from the tokenizer and embeddings.
wrong_tokens = [x for x in tokens if token_regex.match(x)]
vocab = tokenizer.get_vocab()
# Get the ids of the unused tokens.
wrong_token_ids = [vocab[token] for token in wrong_tokens]

if len(wrong_token_ids) == len(vocab):
raise ValueError(
"All tokens in the vocabulary are unused tokens. This will result in an empty tokenizer. "
"Please provide a valid token removal pattern. The pattern is now: {token_remove_pattern}"
)

# Remove the unused tokens from the tokenizer.
new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
# Remove the embeddings of the unused tokens.
embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")

return new_tokenizer, embeddings
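
The helper above makes the previously hardcoded `[unused...]` cleanup configurable. A minimal usage sketch, assuming a BERT-style checkpoint (the model name is illustrative and not part of this commit; the pattern value is the documented default):

```python
from transformers import AutoModel, AutoTokenizer

from model2vec.distill.distillation import distill_from_model

model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokens matching the pattern are dropped from both the tokenizer and the
# embedding matrix before the static model is built.
static_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    token_remove_pattern=r"\[unused\d+\]",
)
```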


def distill(
model_name: str,
vocabulary: list[str] | None = None,
6 changes: 4 additions & 2 deletions model2vec/distill/tokenizer.py
@@ -68,10 +68,12 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenizer:
tokenizer_data["model"]["vocab"] = reindexed

elif model_type == "Unigram":
raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
logger.warning("Removing tokens from a unigram tokenizer is not supported.")
return tokenizer

elif model_type == "BPE":
raise ValueError("Removing tokens from a BPE tokenizer is not supported.")
logger.warning("Removing tokens from a BPE tokenizer is not supported.")
return tokenizer

else:
raise ValueError(f"Unknown model type {model_type}")
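
A small sketch of the behavior change, assuming a BPE tokenizer (the pretrained name is illustrative): for Unigram and BPE models, remove_tokens now logs a warning and returns the tokenizer unchanged instead of raising a ValueError.

```python
from tokenizers import Tokenizer

from model2vec.distill.tokenizer import remove_tokens

bpe_tokenizer = Tokenizer.from_pretrained("gpt2")  # GPT-2 ships a BPE model
result = remove_tokens(bpe_tokenizer, tokens_to_remove=["hello"])
assert result is bpe_tokenizer  # unchanged; a warning was logged instead
```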
79 changes: 60 additions & 19 deletions model2vec/model.py
@@ -1,20 +1,21 @@
from __future__ import annotations

import math
import os
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Iterator, Union

import numpy as np
from joblib import delayed
from tokenizers import Encoding, Tokenizer
from tqdm import tqdm

from model2vec.utils import load_local_model
from model2vec.utils import ProgressParallel, load_local_model

PathLike = Union[Path, str]


logger = getLogger(__name__)


@@ -171,6 +172,8 @@ def encode_as_sequence(
max_length: int | None = None,
batch_size: int = 1024,
show_progress_bar: bool = False,
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
) -> list[np.ndarray] | np.ndarray:
"""
Encode a list of sentences as a list of numpy arrays of tokens.
@@ -186,24 +189,42 @@ def encode_as_sequence(
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param show_progress_bar: Whether to show the progress bar.
:param use_multiprocessing: Whether to use multiprocessing.
By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
:param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
:return: The encoded sentences with an embedding per token.
"""
was_single = False
if isinstance(sentences, str):
sentences = [sentences]
was_single = True

out_array: list[np.ndarray] = []
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progress_bar,
):
out_array.extend(self._encode_batch_as_sequence(batch, max_length))
# Prepare all batches
sentence_batches = list(self._batch(sentences, batch_size))
total_batches = math.ceil(len(sentences) / batch_size)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
# Disable parallelism for tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
)
out_array: list[np.ndarray] = []
for r in results:
out_array.extend(r)
else:
out_array = []
for batch in tqdm(
sentence_batches,
total=total_batches,
disable=not show_progress_bar,
):
out_array.extend(self._encode_batch_as_sequence(batch, max_length))

if was_single:
return out_array[0]

return out_array
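
A hedged sketch of the new encode_as_sequence parameters (the model path is illustrative): inputs larger than the threshold fan out over joblib workers, while smaller inputs keep the sequential tqdm loop.

```python
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # illustrative
sentences = ["an example sentence"] * 20_000

# 20_000 sentences exceed the default threshold of 10_000, so batches are
# dispatched in parallel and TOKENIZERS_PARALLELISM is disabled first.
token_sequences = model.encode_as_sequence(sentences, use_multiprocessing=True)
```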

def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None) -> list[np.ndarray]:
@@ -224,6 +245,8 @@ def encode(
show_progress_bar: bool = False,
max_length: int | None = 512,
batch_size: int = 1024,
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
**kwargs: Any,
) -> np.ndarray:
"""
@@ -237,6 +260,9 @@ def encode(
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param use_multiprocessing: Whether to use multiprocessing.
By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
:param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
:param **kwargs: Any additional arguments. These are ignored.
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
"""
@@ -245,19 +271,34 @@ def encode(
sentences = [sentences]
was_single = True

out_arrays: list[np.ndarray] = []
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
# Prepare all batches
sentence_batches = list(self._batch(sentences, batch_size))
total_batches = math.ceil(len(sentences) / batch_size)

out_array = np.concatenate(out_arrays, axis=0)
ids = self.tokenize(sentences=sentences, max_length=max_length)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
# Disable parallelism for tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
)
out_array = np.concatenate(results, axis=0)
else:
# Don't use multiprocessing
out_arrays: list[np.ndarray] = []
for batch in tqdm(
sentence_batches,
total=total_batches,
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
out_array = np.concatenate(out_arrays, axis=0)

if was_single:
return out_array[0]

return out_array
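
The same switch applies to encode; a sketch of forcing the sequential path (model path again illustrative), which can be preferable where spawning joblib workers is costly:

```python
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # illustrative
embeddings = model.encode(
    ["first sentence", "second sentence"],
    use_multiprocessing=False,  # never fan out, regardless of input size
)
```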

def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
44 changes: 42 additions & 2 deletions model2vec/utils.py
@@ -1,18 +1,56 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from pathlib import Path
from typing import Iterator, Protocol, cast
from typing import Any, Iterator, Protocol, cast

import numpy as np
import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm

logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
"""A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
"""
Initialize the ProgressParallel object.
:param use_tqdm: Whether to show the progress bar.
:param total: Total number of tasks (batches) you expect to process. If None,
it updates the total dynamically to the number of dispatched tasks.
:param *args: Additional arguments to pass to `Parallel.__init__`.
:param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
"""
self._use_tqdm = use_tqdm
self._total = total
super().__init__(*args, **kwargs)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Create a tqdm context."""
with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
return super().__call__(*args, **kwargs)

def print_progress(self) -> None:
"""Hook called by joblib as tasks complete. We update the tqdm bar here."""
if self._total is None:
# If no fixed total was given, we dynamically set the total
self._pbar.total = self.n_dispatched_tasks
# Move the bar to the number of completed tasks
self._pbar.n = self.n_completed_tasks
self._pbar.refresh()
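
A self-contained sketch of ProgressParallel (the worker function is illustrative): it behaves like joblib.Parallel but drives a tqdm bar from the print_progress hook.

```python
from joblib import delayed

from model2vec.utils import ProgressParallel


def square(x: int) -> int:
    return x * x


# n_jobs is forwarded to joblib.Parallel; use_tqdm and total are consumed here.
results = ProgressParallel(n_jobs=2, use_tqdm=True, total=100)(
    delayed(square)(i) for i in range(100)
)
assert results == [i * i for i in range(100)]
```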


class SafeOpenProtocol(Protocol):
"""Protocol to fix safetensors safe open."""

@@ -22,6 +60,7 @@ def get_tensor(self, key: str) -> np.ndarray:


_MODULE_MAP = (("scikit-learn", "sklearn"),)
_DIVIDERS = re.compile(r"[=<>!]+")


def get_package_extras(package: str, extra: str) -> Iterator[str]:
@@ -38,7 +77,8 @@ def get_package_extras(package: str, extra: str) -> Iterator[str]:
# Extract and clean the extra requirement
found_extra = rest[0].split("==")[-1].strip(" \"'")
if found_extra == extra:
yield name.strip()
prefix, *_ = _DIVIDERS.split(name)
yield prefix.strip()
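
What the _DIVIDERS change fixes, in a standalone sketch: requirement strings that carry version constraints, such as the new "torch<2.5.0" pin in the Makefile, now reduce to the bare package name before being yielded.

```python
import re

_DIVIDERS = re.compile(r"[=<>!]+")

for requirement in ("torch<2.5.0", "tokenizers>=0.20", "numpy"):
    prefix, *_ = _DIVIDERS.split(requirement)
    print(prefix.strip())  # -> torch, tokenizers, numpy
```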


def importable(module: str, extra: str) -> None:
2 changes: 1 addition & 1 deletion model2vec/version.py
@@ -1,2 +1,2 @@
__version_triple__ = (0, 3, 2)
__version_triple__ = (0, 3, 3)
__version__ = ".".join(map(str, __version_triple__))
11 changes: 6 additions & 5 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "model2vec"
description = "Distill a Small Static Model from any Sentence Transformer"
description = "The Fastest State-of-the-Art Static Embeddings in the World"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -23,12 +23,14 @@ classifiers = [
]

dependencies = [
"jinja2",
"joblib",
"numpy",
"rich",
"tqdm",
"tokenizers>=0.20",
"safetensors",
"setuptools",
"tokenizers>=0.20",
"tqdm",
]

[build-system]
@@ -49,11 +51,10 @@ dev = [
"mypy",
"pre-commit",
"pytest",
"pytest-coverage",
"pytest-cov",
"ruff",
]
distill = ["torch", "transformers", "scikit-learn"]

onnx = ["onnx", "torch"]

[project.urls]