Merge pull request #35 from ArneBinder/refactor_token_classification
refactor token classification
Showing 24 changed files with 1,823 additions and 775 deletions.
@@ -0,0 +1 @@
from .model import Model
@@ -0,0 +1,98 @@
import abc
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

from pytorch_ie import PyTorchIEModel
from typing_extensions import TypeAlias

from pie_modules.models.mixins import WithMetricsFromTaskModule

InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
TargetType = TypeVar("TargetType")
StepInputType: TypeAlias = Tuple[
    InputType,
    Optional[TargetType],
]
StepOutputType = TypeVar("StepOutputType")

TRAINING = "train"
VALIDATION = "val"
TEST = "test"


class Model(
    PyTorchIEModel,
    WithMetricsFromTaskModule[InputType, TargetType, OutputType],
    Generic[InputType, OutputType, TargetType, StepOutputType],
    abc.ABC,
):
    def __init__(
        self,
        taskmodule_config: Optional[Dict[str, Any]] = None,
        metric_stages: List[str] = [TRAINING, VALIDATION, TEST],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.setup_metrics(metric_stages=metric_stages, taskmodule_config=taskmodule_config)

    def get_loss_from_outputs(self, outputs: OutputType) -> StepOutputType:
        if hasattr(outputs, "loss"):
            return outputs.loss
        else:
            raise ValueError(
                f"The model {self.__class__.__name__} does not define a 'loss' attribute in its output, "
                "so the loss cannot be automatically extracted from the outputs. Please override the "
                "get_loss_from_outputs() method for this model."
            )

    def log_loss(self, stage: str, loss: StepOutputType) -> None:
        # show loss on each step only during training
        self.log(
            f"loss/{stage}",
            loss,
            on_step=(stage == TRAINING),
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def _step(
        self,
        stage: str,
        batch: StepInputType,
    ) -> StepOutputType:
        inputs, targets = batch
        assert targets is not None, "targets has to be available for training"

        outputs = self(inputs=inputs, targets=targets)

        self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage)

        loss = self.get_loss_from_outputs(outputs=outputs)
        self.log_loss(stage=stage, loss=loss)

        return loss

    def training_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
        return self._step(stage=TRAINING, batch=batch)

    def validation_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
        return self._step(stage=VALIDATION, batch=batch)

    def test_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
        return self._step(stage=TEST, batch=batch)

    def predict_step(
        self, batch: StepInputType, batch_idx: int, dataloader_idx: int
    ) -> TargetType:
        inputs, targets = batch
        return self.predict(inputs=inputs)

    def on_train_epoch_end(self) -> None:
        self.log_metric(stage=TRAINING)

    def on_validation_epoch_end(self) -> None:
        self.log_metric(stage=VALIDATION)

    def on_test_epoch_end(self) -> None:
        self.log_metric(stage=TEST)
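
For orientation, here is a minimal sketch of what a concrete subclass of this common Model base class has to provide: just forward, decode, and predict, while _step, loss extraction, loss logging, and the metric hooks above are inherited unchanged. Everything below (ToyOutput, ToyModel, the feature and class counts) is hypothetical illustration, not part of this PR, and assumes PyTorchIEModel can be constructed without extra arguments.

from dataclasses import dataclass
from typing import Optional

import torch
from torch import FloatTensor, LongTensor


@dataclass
class ToyOutput:
    logits: FloatTensor
    loss: Optional[FloatTensor] = None  # picked up by get_loss_from_outputs()


class ToyModel(Model[FloatTensor, ToyOutput, LongTensor, FloatTensor]):
    def __init__(self, num_features: int = 10, num_classes: int = 2, **kwargs) -> None:
        super().__init__(**kwargs)
        self.classifier = torch.nn.Linear(num_features, num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, inputs: FloatTensor, targets: Optional[LongTensor] = None) -> ToyOutput:
        logits = self.classifier(inputs)
        loss = self.loss_fn(logits, targets) if targets is not None else None
        return ToyOutput(logits=logits, loss=loss)

    def decode(self, inputs: FloatTensor, outputs: ToyOutput) -> LongTensor:
        # turn logits into hard label predictions
        return torch.argmax(outputs.logits, dim=-1)

    def predict(self, inputs: FloatTensor, **kwargs) -> LongTensor:
        # run the model again and decode its outputs
        return self.decode(inputs=inputs, outputs=self(inputs=inputs))

With that in place, the inherited training_step/validation_step/test_step already log loss/{stage} and drive the taskmodule metrics whenever a taskmodule_config is passed.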
@@ -0,0 +1 @@
from .with_metrics_from_taskmodule import WithMetricsFromTaskModule
124 changes: 124 additions & 0 deletions
src/pie_modules/models/mixins/with_metrics_from_taskmodule.py
@@ -0,0 +1,124 @@
import abc
import logging
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pytorch_ie import AutoTaskModule
from pytorch_lightning import LightningModule
from torchmetrics import Metric, MetricCollection

from pie_modules.models.interface import RequiresTaskmoduleConfig

InputType = TypeVar("InputType")
TargetType = TypeVar("TargetType")
OutputType = TypeVar("OutputType")

TRAINING = "train"
VALIDATION = "val"
TEST = "test"

logger = logging.getLogger(__name__)


class WithMetricsFromTaskModule(
    LightningModule, RequiresTaskmoduleConfig, Generic[InputType, TargetType, OutputType], abc.ABC
):
    """A mixin for LightningModules that adds metrics from a taskmodule.

    The metrics are added to the LightningModule as attributes with the names metric_{stage} via the
    setup_metrics method, where stage is one of "train", "val", or "test". The metrics are updated
    with the update_metric method and logged with the on_{stage}_epoch_end methods.
    """

    def setup_metrics(
        self, metric_stages: List[str], taskmodule_config: Optional[Dict[str, Any]] = None
    ) -> None:
        """Set up metrics for the given stages. If taskmodule_config is provided, the metrics are
        configured from the taskmodule. Otherwise, no metrics are available.

        Args:
            metric_stages: The stages for which to set up metrics. Each entry must be one of
                "train", "val", or "test".
            taskmodule_config: The config for the taskmodule which can be obtained from the
                taskmodule.config property.
        """

        for stage in [TRAINING, VALIDATION, TEST]:
            self._set_metric(stage=stage, metric=None)
        if taskmodule_config is not None:
            taskmodule = AutoTaskModule.from_config(taskmodule_config)
            for stage in metric_stages:
                if stage not in [TRAINING, VALIDATION, TEST]:
                    raise ValueError(
                        f'metric_stages must only contain the values "{TRAINING}", "{VALIDATION}", and "{TEST}".'
                    )
                metric = taskmodule.configure_model_metric(stage=stage)
                if metric is not None:
                    self._set_metric(stage=stage, metric=metric)
                else:
                    logger.warning(
                        f"The taskmodule {taskmodule.__class__.__name__} does not define a metric for stage "
                        f"'{stage}'."
                    )
        else:
            logger.warning("No taskmodule_config was provided. Metrics will not be available.")

    def _get_metric(self, stage: str) -> Optional[Union[Metric, MetricCollection]]:
        return getattr(self, f"metric_{stage}")

    def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None:
        setattr(self, f"metric_{stage}", metric)

    @abc.abstractmethod
    def predict(self, inputs: InputType, **kwargs) -> TargetType:
        """Predict the target for the given inputs."""
        pass

    @abc.abstractmethod
    def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
        """Decode the outputs of the model into the target format."""
        pass

    def update_metric(
        self,
        stage: str,
        inputs: InputType,
        targets: TargetType,
        outputs: Optional[OutputType] = None,
    ) -> None:
        """Update the metric for the given stage. If outputs is provided, the predictions are
        decoded from the outputs. Otherwise, the predictions are obtained by directly calling the
        predict method with the inputs (note that this causes the model to be called a second
        time). Finally, the metric is updated with the predictions and targets.

        Args:
            stage: The stage for which to update the metric. Must be one of "train", "val", or "test".
            inputs: The inputs to the model.
            targets: The targets for the inputs.
            outputs: The outputs of the model. They are decoded into predictions if provided. If
                outputs is None, the predictions are obtained by directly calling the predict method
                on the inputs.
        """

        metric = self._get_metric(stage=stage)
        if metric is not None:
            if outputs is not None:
                predictions = self.decode(inputs=inputs, outputs=outputs)
            else:
                predictions = self.predict(inputs=inputs)
            metric.update(predictions, targets)

    def log_metric(self, stage: str, reset: bool = True) -> None:
        """Log the metric for the given stage and reset it."""

        metric = self._get_metric(stage=stage)
        if metric is not None:
            values = metric.compute()
            log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
            if isinstance(values, dict):
                for key, value in values.items():
                    self.log(f"metric/{key}/{stage}", value, **log_kwargs)
            else:
                metric_name = getattr(metric, "name", None) or type(metric).__name__
                self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs)
            if reset:
                metric.reset()
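
A side note on the two branches in log_metric above: torchmetrics returns a single tensor from Metric.compute() but a dict from MetricCollection.compute(), which is why both cases are handled. A small standalone illustration (not part of the diff, assuming a recent torchmetrics version):

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

preds = torch.tensor([0, 1, 2, 2])
targets = torch.tensor([0, 1, 1, 2])

# a single metric: compute() returns one tensor,
# logged as "metric/{class name}/{stage}" (or metric.name, if set)
accuracy = MulticlassAccuracy(num_classes=3)
accuracy.update(preds, targets)
print(accuracy.compute())

# a metric collection: compute() returns a dict,
# each entry is logged as "metric/{key}/{stage}"
collection = MetricCollection(
    {"accuracy": MulticlassAccuracy(num_classes=3), "f1": MulticlassF1Score(num_classes=3)}
)
collection.update(preds, targets)
print(collection.compute())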
@@ -0,0 +1,95 @@
import logging
from typing import Optional, Tuple

import torch
from pytorch_ie.core import PyTorchIEModel
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from torch import FloatTensor, LongTensor
from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding
from transformers.modeling_outputs import TokenClassifierOutput
from typing_extensions import TypeAlias

from pie_modules.models.common import Model

# model inputs / outputs / targets
InputType: TypeAlias = BatchEncoding
OutputType: TypeAlias = TokenClassifierOutput
TargetType: TypeAlias = LongTensor
# step inputs / outputs
StepInputType: TypeAlias = Tuple[InputType, Optional[TargetType]]
StepOutputType: TypeAlias = FloatTensor


logger = logging.getLogger(__name__)


@PyTorchIEModel.register()
class SimpleTokenClassificationModel(
    Model[InputType, OutputType, TargetType, StepOutputType],
    RequiresModelNameOrPath,
    RequiresNumClasses,
):
    """A simple token classification model that wraps a (pretrained) model loaded with
    AutoModelForTokenClassification from the transformers library.

    The model is trained with a cross-entropy loss function and uses the Adam optimizer.
    Note that for training, the labels for the special tokens (as well as for padding tokens)
    are expected to have the value label_pad_id (-100 by default, which is the default ignore_index
    value for the CrossEntropyLoss). The predictions for these tokens are also replaced with
    label_pad_id to match the training labels for correct metric calculation. Therefore, the model
    requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs.

    Args:
        model_name_or_path: The name or path of the pretrained transformer model to use.
        num_classes: The number of classes to predict.
        learning_rate: The learning rate to use for training.
        label_pad_id: The label id to use for padding labels (at the padding token positions
            as well as for the special tokens).
    """

    def __init__(
        self,
        model_name_or_path: str,
        num_classes: int,
        learning_rate: float = 1e-5,
        label_pad_id: int = -100,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.label_pad_id = label_pad_id
        self.num_classes = num_classes

        config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes)
        if self.is_from_pretrained:
            self.model = AutoModelForTokenClassification.from_config(config=config)
        else:
            self.model = AutoModelForTokenClassification.from_pretrained(
                model_name_or_path, config=config
            )

    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        inputs_without_special_tokens_mask = {
            k: v for k, v in inputs.items() if k != "special_tokens_mask"
        }
        return self.model(labels=targets, **inputs_without_special_tokens_mask)

    def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
        # get the max index for each token from the logits
        tags_tensor = torch.argmax(outputs.logits, dim=-1).to(torch.long)

        # mask out the padding tokens
        tags_tensor = tags_tensor.masked_fill(inputs["attention_mask"] == 0, self.label_pad_id)

        # mask out the special tokens
        tags_tensor = tags_tensor.masked_fill(
            inputs["special_tokens_mask"] == 1, self.label_pad_id
        )
        return tags_tensor

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
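
Finally, a rough usage sketch for the new model outside of a Lightning Trainer. The checkpoint name and num_classes are placeholders, and the import path assumes the class is re-exported from pie_modules.models; in the actual pipeline the taskmodule produces the BatchEncoding inputs (including the special_tokens_mask) and the target tag tensors.

import torch
from transformers import AutoTokenizer

from pie_modules.models import SimpleTokenClassificationModel

# placeholder checkpoint and label count, chosen only for illustration
model = SimpleTokenClassificationModel(model_name_or_path="bert-base-cased", num_classes=5)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

inputs = tokenizer(
    "Karl lives in Berlin.",
    return_special_tokens_mask=True,
    return_tensors="pt",
)

with torch.no_grad():
    # forward() strips the special_tokens_mask before calling the wrapped transformer
    outputs = model(inputs=inputs)
    # decode() fills padding and special-token positions with label_pad_id (-100),
    # mirroring the ignore_index convention used for the training labels
    tags = model.decode(inputs=inputs, outputs=outputs)
print(tags)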