feat: support transformer quantization and dtype change
percevalw committed Feb 9, 2024
1 parent 51c6316 · commit 0de3307
Showing 2 changed files with 35 additions and 3 deletions.
edspdf/pipes/classifiers/trainable.py (3 additions, 1 deletion)
```diff
@@ -181,7 +181,9 @@ def forward(self, batch: Dict) -> Dict:
         output = {"loss": 0, "mask": embeddings.mask}
 
         # Label prediction / learning
-        logits = self.classifier(embeddings).refold("line")
+        logits = self.classifier(embeddings.to(self.classifier.weight.dtype)).refold(
+            "line"
+        )
         if "labels" in batch:
             targets = batch["labels"].refold(logits.data_dims)
             output["label_loss"] = (
```
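The `trainable.py` change guards against dtype mismatches: when the transformer backbone is quantized or loaded in reduced precision, the embeddings can come out as `float16` while the classifier head keeps `float32` weights, and the matrix multiplication then fails. A minimal sketch of the failure mode and the fix, using plain tensors with illustrative shapes rather than edspdf's folded tensors:

```python
import torch

# A float32 classifier head fed float16 embeddings, as can happen when
# the transformer backbone is loaded in reduced precision.
classifier = torch.nn.Linear(768, 5)  # weights are float32 by default
embeddings = torch.randn(2, 768, dtype=torch.float16)

# classifier(embeddings) would raise:
# RuntimeError: mat1 and mat2 must have the same dtype

# The fix applied in this commit: cast the input to the weight dtype first.
logits = classifier(embeddings.to(classifier.weight.dtype))
print(logits.dtype)  # torch.float32
```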
edspdf/pipes/embeddings/huggingface_embedding.py (32 additions, 2 deletions)
```diff
@@ -1,15 +1,21 @@
 import math
+import sys
 from typing import Optional, Set
 
 import torch
+from confit import validate_arguments
 from foldedtensor import as_folded_tensor
 from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig_
 from typing_extensions import Literal
 
 from edspdf import TrainablePipe, registry
 from edspdf.pipeline import Pipeline
 from edspdf.pipes.embeddings import EmbeddingOutput
 from edspdf.structures import PDFDoc
 
+BitsAndBytesConfig = validate_arguments(BitsAndBytesConfig_)
+
+
 def compute_contextualization_scores(windows):
     ramp = torch.arange(0, windows.shape[1], 1)
```
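The new module-level line wraps the `transformers` config class with confit's `validate_arguments`, so that values coming from a declarative pipeline config are validated and coerced before the underlying `BitsAndBytesConfig_` constructor runs. A hedged sketch of the intended effect (the field value is illustrative):

```python
from confit import validate_arguments
from transformers import BitsAndBytesConfig as BitsAndBytesConfig_

# Same wrapping as in the commit: arguments are checked by confit before
# the real transformers.BitsAndBytesConfig is instantiated.
BitsAndBytesConfig = validate_arguments(BitsAndBytesConfig_)

config = BitsAndBytesConfig(load_in_8bit=True)
print(config.load_in_8bit)  # True
```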
```diff
@@ -108,6 +114,11 @@ class HuggingfaceEmbedding(TrainablePipe[EmbeddingOutput]):
         The maximum number of tokens that can be processed by the model on a single
         device. This does not affect the results but can be used to reduce the memory
         usage of the model, at the cost of a longer processing time.
+    quantization_config: Optional[BitsAndBytesConfig]
+        The quantization configuration to use when loading the model
+    kwargs:
+        Additional keyword arguments to pass to the Huggingface
+        `AutoModel.from_pretrained` method
     """
 
     def __init__(
@@ -119,7 +130,9 @@ def __init__(
         window: int = 510,
         stride: int = 255,
         line_pooling: Literal["mean", "max", "sum"] = "mean",
-        max_tokens_per_device: int = 128 * 128,
+        max_tokens_per_device: int = sys.maxsize,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+        **kwargs,
     ):
         super().__init__(pipeline, name)
         self.use_image = use_image
```
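Together with the forwarded `**kwargs`, the new `quantization_config` parameter lets the backbone be loaded quantized straight from the pipe; note that the default `max_tokens_per_device` also changes from `128 * 128` to `sys.maxsize`, i.e. no per-device cap unless the user sets one. A hedged usage sketch, assuming the usual defaults for `pipeline` and `name`, with an illustrative model name and keyword arguments:

```python
import torch
from edspdf.pipes.embeddings.huggingface_embedding import (
    BitsAndBytesConfig,
    HuggingfaceEmbedding,
)

embedding = HuggingfaceEmbedding(
    model="microsoft/layoutlmv3-base",  # illustrative model name
    use_image=False,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    # extra keyword arguments are forwarded to AutoModel.from_pretrained
    low_cpu_mem_usage=True,
)
```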
```diff
@@ -129,7 +142,11 @@ def __init__(
             else None
         )
         self.tokenizer = AutoTokenizer.from_pretrained(model)
-        self.hf_model = AutoModel.from_pretrained(model)
+        self.hf_model = AutoModel.from_pretrained(
+            model,
+            quantization_config=quantization_config,
+            **kwargs,
+        )
         self.output_size = self.hf_model.config.hidden_size
         self.window = window
         self.stride = stride
```
```diff
@@ -335,3 +352,16 @@ def forward(self, batch):
             mode=self.line_pooling,
         )
         return {"embeddings": line_embedding}
+
+    def to_disk(self, path, *, exclude: Optional[Set[str]]):
+        repr_id = object.__repr__(self)
+        if repr_id in exclude:
+            return
+        for obj in (self.tokenizer, self.image_processor, self.hf_model):
+            if obj is not None:
+                obj.save_pretrained(path)
+        for param in self.hf_model.parameters():
+            exclude.add(object.__repr__(param))
+        cfg = super().to_disk(path, exclude=exclude) or {}
+        cfg["model"] = f"./{path.as_posix()}"
+        return cfg
```
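The new `to_disk` override exports the Huggingface artifacts in their native `save_pretrained` format, adds the backbone's parameters to `exclude` so the generic tensor serialization skips them, and rewrites the pipe's `model` config entry to point at the exported directory, so a reload uses the local copy instead of the hub. A hedged sketch of the resulting round trip, reusing the `embedding` pipe from the sketch above (the path is illustrative):

```python
from pathlib import Path

path = Path("artifacts/embedding")
path.mkdir(parents=True, exist_ok=True)

# Tokenizer, image processor (if any) and model are written via
# save_pretrained; the returned config now points at the local export.
cfg = embedding.to_disk(path, exclude=set())
assert cfg["model"] == f"./{path.as_posix()}"
```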
