Commit 8e6d5b2

Merge pull request #35 from ArneBinder/refactor_token_classification

refactor token classification

ArneBinder authored Jan 19, 2024
2 parents 5f96e5a + 19e9a83
Showing 24 changed files with 1,823 additions and 775 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -16,14 +16,15 @@ Available models:

- [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py)
- [SequenceClassificationModel](src/pie_modules/models/sequence_classification.py)
+- [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py)
- [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py)
- [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py)
- [SimpleGenerativeModel](src/pie_modules/models/simple_generative.py)

Available taskmodules:

- [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py)
-- [TokenClassificationTaskModule](src/pie_modules/taskmodules/token_classification.py)
+- [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py)
- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)
- [PointerNetworkTaskModuleForEnd2EndRE](src/pie_modules/taskmodules/pointer_network_for_end2end_re.py)
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.9"
-pytorch-ie = ">=0.29.5,<0.30.0"
+pytorch-ie = ">=0.29.8,<0.30.0"
pytorch-lightning = "^2.1.0"
torchmetrics = "^1"
pytorch-crf = ">=0.7.2"
1 change: 1 addition & 0 deletions src/pie_modules/models/__init__.py
@@ -2,6 +2,7 @@
from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel
from .simple_generative import SimpleGenerativeModel
from .simple_sequence_classification import SimpleSequenceClassificationModel
+from .simple_token_classification import SimpleTokenClassificationModel
from .token_classification_with_seq2seq_encoder_and_crf import (
TokenClassificationModelWithSeq2SeqEncoderAndCrf,
)
1 change: 1 addition & 0 deletions src/pie_modules/models/common/__init__.py
@@ -0,0 +1 @@
from .model import Model
98 changes: 98 additions & 0 deletions src/pie_modules/models/common/model.py
@@ -0,0 +1,98 @@
import abc
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

from pytorch_ie import PyTorchIEModel
from typing_extensions import TypeAlias

from pie_modules.models.mixins import WithMetricsFromTaskModule

InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
TargetType = TypeVar("TargetType")
StepInputType: TypeAlias = Tuple[
InputType,
Optional[TargetType],
]
StepOutputType = TypeVar("StepOutputType")

TRAINING = "train"
VALIDATION = "val"
TEST = "test"


class Model(
PyTorchIEModel,
WithMetricsFromTaskModule[InputType, TargetType, OutputType],
Generic[InputType, OutputType, TargetType, StepOutputType],
abc.ABC,
):
def __init__(
self,
taskmodule_config: Optional[Dict[str, Any]] = None,
metric_stages: List[str] = [TRAINING, VALIDATION, TEST],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.setup_metrics(metric_stages=metric_stages, taskmodule_config=taskmodule_config)

def get_loss_from_outputs(self, outputs: OutputType) -> StepOutputType:
if hasattr(outputs, "loss"):
return outputs.loss
else:
raise ValueError(
f"The model {self.__class__.__name__} does not define a 'loss' attribute in its output, "
"so the loss cannot be automatically extracted from the outputs. Please override the"
"get_loss_from_outputs() method for this model."
)

def log_loss(self, stage: str, loss: StepOutputType) -> None:
# show loss on each step only during training
self.log(
f"loss/{stage}",
loss,
on_step=(stage == TRAINING),
on_epoch=True,
prog_bar=True,
sync_dist=True,
)

def _step(
self,
stage: str,
batch: StepInputType,
) -> StepOutputType:
inputs, targets = batch
        assert targets is not None, "targets must be available to compute the loss"

outputs = self(inputs=inputs, targets=targets)

self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage)

loss = self.get_loss_from_outputs(outputs=outputs)
self.log_loss(stage=stage, loss=loss)

return loss

def training_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
return self._step(stage=TRAINING, batch=batch)

def validation_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
return self._step(stage=VALIDATION, batch=batch)

def test_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
return self._step(stage=TEST, batch=batch)

def predict_step(
self, batch: StepInputType, batch_idx: int, dataloader_idx: int
) -> TargetType:
inputs, targets = batch
return self.predict(inputs=inputs)

def on_train_epoch_end(self) -> None:
self.log_metric(stage=TRAINING)

def on_validation_epoch_end(self) -> None:
self.log_metric(stage=VALIDATION)

def on_test_epoch_end(self) -> None:
self.log_metric(stage=TEST)
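
To make the contract of the new Model base class concrete, here is a minimal sketch of a subclass (illustrative only, not part of this diff; the class name, feature size, and class count are made up, and it assumes PyTorchIEModel can be constructed without further arguments). A subclass only implements forward() so that it returns an output object exposing a loss attribute, plus the decode() and predict() methods required by WithMetricsFromTaskModule; the step methods, loss extraction, and loss/metric logging all come from the base class:

from typing import Optional

import torch
from torch import Tensor
from transformers.modeling_outputs import TokenClassifierOutput

from pie_modules.models.common import Model


class ToyModel(Model[Tensor, TokenClassifierOutput, Tensor, Tensor]):
    """Hypothetical model that classifies each of 4 input features into 3 classes."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.classifier = torch.nn.Linear(4, 3)

    def forward(self, inputs: Tensor, targets: Optional[Tensor] = None) -> TokenClassifierOutput:
        logits = self.classifier(inputs)
        # the base class reads this 'loss' attribute in get_loss_from_outputs()
        loss = None if targets is None else torch.nn.functional.cross_entropy(logits, targets)
        return TokenClassifierOutput(loss=loss, logits=logits)

    def decode(self, inputs: Tensor, outputs: TokenClassifierOutput) -> Tensor:
        return outputs.logits.argmax(dim=-1)

    def predict(self, inputs: Tensor, **kwargs) -> Tensor:
        return self.decode(inputs=inputs, outputs=self(inputs=inputs))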
1 change: 1 addition & 0 deletions src/pie_modules/models/mixins/__init__.py
@@ -0,0 +1 @@
from .with_metrics_from_taskmodule import WithMetricsFromTaskModule
124 changes: 124 additions & 0 deletions src/pie_modules/models/mixins/with_metrics_from_taskmodule.py
@@ -0,0 +1,124 @@
import abc
import logging
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pytorch_ie import AutoTaskModule
from pytorch_lightning import LightningModule
from torchmetrics import Metric, MetricCollection

from pie_modules.models.interface import RequiresTaskmoduleConfig

InputType = TypeVar("InputType")
TargetType = TypeVar("TargetType")
OutputType = TypeVar("OutputType")

TRAINING = "train"
VALIDATION = "val"
TEST = "test"

logger = logging.getLogger(__name__)


class WithMetricsFromTaskModule(
LightningModule, RequiresTaskmoduleConfig, Generic[InputType, TargetType, OutputType], abc.ABC
):
"""A mixin for LightningModules that adds metrics from a taskmodule.
The metrics are added to the LightningModule as attributes with the names metric_{stage} via
setup_metrics method, where stage is one of "train", "val", or "test". The metrics are updated
with the update_metric method and logged with the on_{stage}_epoch_end methods.
"""

def setup_metrics(
self, metric_stages: List[str], taskmodule_config: Optional[Dict[str, Any]] = None
) -> None:
"""Setup metrics for the given stages. If taskmodule_config is provided, the metrics are
configured from the taskmodule. Otherwise, no metrics are available.
Args:
metric_stages: The stages for which to setup metrics. Must be one of "train", "val", or
"test".
taskmodule_config: The config for the taskmodule which can be obtained from the
taskmodule.config property.
"""

for stage in [TRAINING, VALIDATION, TEST]:
self._set_metric(stage=stage, metric=None)
if taskmodule_config is not None:
taskmodule = AutoTaskModule.from_config(taskmodule_config)
for stage in metric_stages:
if stage not in [TRAINING, VALIDATION, TEST]:
raise ValueError(
f'metric_stages must only contain the values "{TRAINING}", "{VALIDATION}", and "{TEST}".'
)
metric = taskmodule.configure_model_metric(stage=stage)
if metric is not None:
self._set_metric(stage=stage, metric=metric)
else:
logger.warning(
f"The taskmodule {taskmodule.__class__.__name__} does not define a metric for stage "
f"'{stage}'."
)
else:
logger.warning("No taskmodule_config was provided. Metrics will not be available.")

def _get_metric(self, stage: str) -> Optional[Union[Metric, MetricCollection]]:
return getattr(self, f"metric_{stage}")

def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None:
setattr(self, f"metric_{stage}", metric)

@abc.abstractmethod
def predict(self, inputs: InputType, **kwargs) -> TargetType:
"""Predict the target for the given inputs."""
pass

@abc.abstractmethod
def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
"""Decode the outputs of the model into the target format."""
pass

def update_metric(
self,
stage: str,
inputs: InputType,
targets: TargetType,
outputs: Optional[OutputType] = None,
) -> None:
"""Update the metric for the given stage. If outputs is provided, the predictions are
decoded from the outputs. Otherwise, the predictions are obtained by directly calling the
predict method with the inputs (note that this causes the model to be called a second
time). Finally, the metric is updated with the predictions and targets.
Args:
stage: The stage for which to update the metric. Must be one of "train", "val", or "test".
inputs: The inputs to the model.
targets: The targets for the inputs.
outputs: The outputs of the model. They are decoded into predictions if provided. If
outputs is None, the predictions are obtained by directly calling the predict method
on the inputs.
"""

metric = self._get_metric(stage=stage)
if metric is not None:
if outputs is not None:
predictions = self.decode(inputs=inputs, outputs=outputs)
else:
predictions = self.predict(inputs=inputs)
metric.update(predictions, targets)

def log_metric(self, stage: str, reset: bool = True) -> None:
"""Log the metric for the given stage and reset it."""

metric = self._get_metric(stage=stage)
if metric is not None:
values = metric.compute()
log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
if isinstance(values, dict):
for key, value in values.items():
self.log(f"metric/{key}/{stage}", value, **log_kwargs)
else:
metric_name = getattr(metric, "name", None) or type(metric).__name__
self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs)
if reset:
metric.reset()
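
For context, this is the taskmodule side of the contract that setup_metrics relies on: AutoTaskModule.from_config(taskmodule_config) rebuilds the taskmodule, and its configure_model_metric(stage) is expected to return a torchmetrics Metric or MetricCollection (or None to disable the metric for that stage). A hypothetical sketch, not taken from this diff (the taskmodule class, metric choices, and num_classes are made up):

from typing import Optional, Union

from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score


class MyTaskModule:  # hypothetical; a real taskmodule would derive from TaskModule
    def configure_model_metric(self, stage: str) -> Optional[Union[Metric, MetricCollection]]:
        if stage == "train":
            # keep the train metric cheap, it is updated on every batch
            return MulticlassAccuracy(num_classes=5, ignore_index=-100)
        # richer metrics for validation and test
        return MetricCollection(
            {
                "accuracy": MulticlassAccuracy(num_classes=5, ignore_index=-100),
                "f1": MulticlassF1Score(num_classes=5, ignore_index=-100),
            }
        )

Since MetricCollection.compute() returns a dict, log_metric would log the collection as metric/accuracy/val and metric/f1/val, while a bare Metric is logged under its class name, e.g. metric/MulticlassAccuracy/train.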
95 changes: 95 additions & 0 deletions src/pie_modules/models/simple_token_classification.py
@@ -0,0 +1,95 @@
import logging
from typing import Optional, Tuple

import torch
from pytorch_ie.core import PyTorchIEModel
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from torch import FloatTensor, LongTensor
from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding
from transformers.modeling_outputs import TokenClassifierOutput
from typing_extensions import TypeAlias

from pie_modules.models.common import Model

# model inputs / outputs / targets
InputType: TypeAlias = BatchEncoding
OutputType: TypeAlias = TokenClassifierOutput
TargetType: TypeAlias = LongTensor
# step inputs / outputs
StepInputType: TypeAlias = Tuple[InputType, Optional[TargetType]]
StepOutputType: TypeAlias = FloatTensor


logger = logging.getLogger(__name__)


@PyTorchIEModel.register()
class SimpleTokenClassificationModel(
Model[InputType, OutputType, TargetType, StepOutputType],
RequiresModelNameOrPath,
RequiresNumClasses,
):
"""A simple token classification model that wraps a (pretrained) model loaded with
AutoModelForTokenClassification from the transformers library.
The model is trained with a cross-entropy loss function and uses the Adam optimizer.
Note that for training, the labels for the special tokens (as well as for padding tokens)
are expected to have the value label_pad_id (-100 by default, which is the default ignore_index
value for the CrossEntropyLoss). The predictions for these tokens are also replaced with
label_pad_id to match the training labels for correct metric calculation. Therefore, the model
requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs.
Args:
model_name_or_path: The name or path of the pretrained transformer model to use.
num_classes: The number of classes to predict.
learning_rate: The learning rate to use for training.
label_pad_id: The label id to use for padding labels (at the padding token positions
as well as for the special tokens).
"""

def __init__(
self,
model_name_or_path: str,
num_classes: int,
learning_rate: float = 1e-5,
label_pad_id: int = -100,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.save_hyperparameters()

self.learning_rate = learning_rate
self.label_pad_id = label_pad_id
self.num_classes = num_classes

config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes)
if self.is_from_pretrained:
self.model = AutoModelForTokenClassification.from_config(config=config)
else:
self.model = AutoModelForTokenClassification.from_pretrained(
model_name_or_path, config=config
)

def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
inputs_without_special_tokens_mask = {
k: v for k, v in inputs.items() if k != "special_tokens_mask"
}
return self.model(labels=targets, **inputs_without_special_tokens_mask)

def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
# get the max index for each token from the logits
tags_tensor = torch.argmax(outputs.logits, dim=-1).to(torch.long)

        # mask out the padding tokens
tags_tensor = tags_tensor.masked_fill(inputs["attention_mask"] == 0, self.label_pad_id)

# mask out the special tokens
tags_tensor = tags_tensor.masked_fill(
inputs["special_tokens_mask"] == 1, self.label_pad_id
)
return tags_tensor

def configure_optimizers(self) -> OptimizerLRScheduler:
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
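
A minimal inference sketch for the new model (illustrative only, not part of this diff; the checkpoint name and number of classes are made up). Note that the tokenizer must be called with return_special_tokens_mask=True, since decode() needs the special_tokens_mask to replace predictions at special-token positions with label_pad_id, while forward() strips it before calling the wrapped transformer:

import torch
from transformers import AutoTokenizer

from pie_modules.models import SimpleTokenClassificationModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = SimpleTokenClassificationModel(model_name_or_path="bert-base-cased", num_classes=5)
model.eval()

inputs = tokenizer(
    ["PIE makes information extraction easy."],
    return_tensors="pt",
    return_special_tokens_mask=True,  # required by decode(), removed again in forward()
)
with torch.no_grad():
    outputs = model(inputs=inputs)
tags = model.decode(inputs=inputs, outputs=outputs)
# padding and special-token positions now hold label_pad_id (-100),
# all other positions hold a predicted class index
print(tags)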