transformer_text_classification.py
import logging
from typing import Any, Dict, MutableMapping, Optional, Tuple

import torchmetrics
from torch import Tensor, nn
from torch.optim import AdamW
from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup
from typing_extensions import TypeAlias

from pytorch_ie.core import PyTorchIEModel
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses

ModelInputType: TypeAlias = MutableMapping[str, Any]
ModelOutputType: TypeAlias = Dict[str, Any]
ModelStepInputType = Tuple[
    ModelInputType,
    Optional[Tensor],
]

TRAINING = "train"
VALIDATION = "val"
TEST = "test"

logger = logging.getLogger(__name__)


@PyTorchIEModel.register()
class TransformerTextClassificationModel(
    PyTorchIEModel, RequiresModelNameOrPath, RequiresNumClasses
):
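    """A transformer-based text classification model.

    The encoder's last hidden state at the first token position (the [CLS]
    token for BERT-like models) is passed through dropout and a linear layer
    to produce one logit per class. Training uses cross-entropy loss, or
    BCE-with-logits when ``multi_label=True``.
    """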

    def __init__(
        self,
        model_name_or_path: str,
        num_classes: int,
        tokenizer_vocab_size: Optional[int] = None,
        ignore_index: Optional[int] = None,
        learning_rate: float = 1e-5,
        task_learning_rate: float = 1e-4,
        warmup_proportion: float = 0.1,
        freeze_model: bool = False,
        multi_label: bool = False,
        t_total: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if t_total is not None:
            logger.warning(
                "t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead"
            )
        self.save_hyperparameters(ignore=["t_total"])

        self.learning_rate = learning_rate
        self.task_learning_rate = task_learning_rate
        self.warmup_proportion = warmup_proportion

        config = AutoConfig.from_pretrained(model_name_or_path)
        if self.is_from_pretrained:
            # when restoring from a saved pytorch-ie model, the transformer weights come from
            # that checkpoint, so only the architecture is instantiated here
            self.model = AutoModel.from_config(config=config)
        else:
            self.model = AutoModel.from_pretrained(model_name_or_path, config=config)

        if freeze_model:
            for param in self.model.parameters():
                param.requires_grad = False

        if tokenizer_vocab_size is not None:
            self.model.resize_token_embeddings(tokenizer_vocab_size)

        classifier_dropout = (
            config.classifier_dropout
            if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

        self.loss_fct = nn.BCEWithLogitsLoss() if multi_label else nn.CrossEntropyLoss()

        # torchmetrics expects num_labels for the multilabel task and num_classes for multiclass
        metric_kwargs = (
            {"task": "multilabel", "num_labels": num_classes}
            if multi_label
            else {"task": "multiclass", "num_classes": num_classes}
        )
        self.f1 = nn.ModuleDict(
            {
                f"stage_{stage}": torchmetrics.F1Score(ignore_index=ignore_index, **metric_kwargs)
                for stage in [TRAINING, VALIDATION, TEST]
            }
        )

    def forward(self, inputs: ModelInputType) -> ModelOutputType:
        output = self.model(**inputs)
        hidden_state = output.last_hidden_state
        # pool by taking the embedding of the first token ([CLS] for BERT-like models)
        cls_embeddings = hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(cls_embeddings))
        return {"logits": logits}

    def step(self, stage: str, batch: ModelStepInputType):
        inputs, target = batch
        assert target is not None, "target has to be available for training"

        logits = self(inputs)["logits"]

        loss = self.loss_fct(logits, target)
        self.log(f"{stage}/loss", loss, on_step=(stage == TRAINING), on_epoch=True, prog_bar=True)

        f1 = self.f1[f"stage_{stage}"]
        f1(logits, target)
        self.log(f"{stage}/f1", f1, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch: ModelStepInputType, batch_idx: int):
        return self.step(stage=TRAINING, batch=batch)

    def validation_step(self, batch: ModelStepInputType, batch_idx: int):
        return self.step(stage=VALIDATION, batch=batch)

    def test_step(self, batch: ModelStepInputType, batch_idx: int):
        return self.step(stage=TEST, batch=batch)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        if self.warmup_proportion > 0.0:
            # linear warmup over a fraction of the total training steps estimated by the trainer
            stepping_batches = self.trainer.estimated_stepping_batches
            scheduler = get_linear_schedule_with_warmup(
                optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
            )
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return optimizer
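

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a Hugging Face tokenizer matching the encoder; "bert-base-cased"
# and num_classes=2 are example values, not defaults of this model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    model = TransformerTextClassificationModel(
        model_name_or_path="bert-base-cased", num_classes=2
    )
    model.eval()  # disable dropout for a deterministic forward pass

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    inputs = tokenizer(["A short example sentence."], return_tensors="pt")

    with torch.no_grad():
        logits = model(inputs)["logits"]  # shape: (batch_size, num_classes)
    print(logits.shape)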