simple_token_classification.py
import logging
from typing import MutableMapping, Optional, Tuple, Union

import torch
from pytorch_ie.core import PyTorchIEModel
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from torch import FloatTensor, LongTensor
from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding
from transformers.modeling_outputs import TokenClassifierOutput
from typing_extensions import TypeAlias

from pie_modules.models.common import ModelWithBoilerplate

# model inputs / outputs / targets
InputType: TypeAlias = BatchEncoding
OutputType: TypeAlias = TokenClassifierOutput
TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
# step inputs (batch) / outputs (loss)
StepInputType: TypeAlias = Tuple[InputType, TargetType]
StepOutputType: TypeAlias = FloatTensor

logger = logging.getLogger(__name__)


@PyTorchIEModel.register()
class SimpleTokenClassificationModel(
    ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
    RequiresModelNameOrPath,
    RequiresNumClasses,
):
    """A simple token classification model that wraps a (pretrained) model loaded with
    AutoModelForTokenClassification from the transformers library.

    The model is trained with a cross-entropy loss function and uses the Adam optimizer.
    Note that for training, the labels for the special tokens (as well as for padding tokens)
    are expected to have the value label_pad_id (-100 by default, which is the default
    ignore_index value for the CrossEntropyLoss). The predictions for these tokens are also
    replaced with label_pad_id to match the training labels for correct metric calculation.
    Therefore, the model requires the special_tokens_mask and attention_mask (for padding)
    to be passed as inputs.

    Args:
        model_name_or_path: The name or path of the pretrained transformer model to use.
        num_classes: The number of classes to predict.
        learning_rate: The learning rate to use for training.
        label_pad_id: The label id to use for padding labels (at the padding token positions
            as well as for the special tokens).
    """

    def __init__(
        self,
        model_name_or_path: str,
        num_classes: int,
        learning_rate: float = 1e-5,
        label_pad_id: int = -100,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.label_pad_id = label_pad_id
        self.num_classes = num_classes

        config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes)
        if self.is_from_pretrained:
            # the weights will be restored from a saved checkpoint, so only build the architecture
            self.model = AutoModelForTokenClassification.from_config(config=config)
        else:
            # start from the pretrained transformer weights
            self.model = AutoModelForTokenClassification.from_pretrained(
                model_name_or_path, config=config
            )

    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        # the wrapped transformers model does not accept the special_tokens_mask, so remove it
        # (it is only needed in decode() to mask out predictions at special-token positions)
        inputs_without_special_tokens_mask = {
            k: v for k, v in inputs.items() if k != "special_tokens_mask"
        }

        return self.model(**inputs_without_special_tokens_mask, **(targets or {}))

    def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
        # get the max index for each token from the logits
        tags_tensor = torch.argmax(outputs.logits, dim=-1).to(torch.long)

        # mask out the padding tokens
        tags_tensor = tags_tensor.masked_fill(inputs["attention_mask"] == 0, self.label_pad_id)

        # mask out the special tokens
        tags_tensor = tags_tensor.masked_fill(
            inputs["special_tokens_mask"] == 1, self.label_pad_id
        )
        probabilities = torch.softmax(outputs.logits, dim=-1)

        return {"labels": tags_tensor, "probabilities": probabilities}

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
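

# ------------------------------------------------------------------------------
# Usage sketch: a minimal, illustrative example of instantiating the model and
# decoding predictions. The checkpoint name "bert-base-cased", num_classes=5, and
# the example sentence are assumptions for illustration; any checkpoint compatible
# with AutoModelForTokenClassification should work.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = SimpleTokenClassificationModel(
        model_name_or_path="bert-base-cased",  # assumed checkpoint, for illustration only
        num_classes=5,
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # return_special_tokens_mask=True is required because decode() uses it to mask
    # out predictions at special-token positions
    batch = tokenizer(
        ["Karl lives in Berlin."],
        return_special_tokens_mask=True,
        return_tensors="pt",
        padding=True,
    )

    model.eval()
    with torch.no_grad():
        outputs = model(batch)
    decoded = model.decode(inputs=batch, outputs=outputs)

    # decoded["labels"]: predicted class ids, with label_pad_id at padding and
    # special-token positions; decoded["probabilities"]: per-token class distributions
    print(decoded["labels"])
    print(decoded["probabilities"].shape)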