Skip to content

Commit

Permalink
Merge pull request #37 from ArneBinder/improve_new_models
Browse files Browse the repository at this point in the history
streamline `SimpleGenerativeModel`
  • Loading branch information
ArneBinder authored Jan 19, 2024
2 parents 8e6d5b2 + 03dedff commit 7320355
Show file tree
Hide file tree
Showing 23 changed files with 297 additions and 454 deletions.
5 changes: 4 additions & 1 deletion src/pie_modules/models/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .model import Model
from .has_taskmodule import HasTaskmodule
from .model_with_boilerplate import ModelWithBoilerplate
from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule
from .stages import TESTING, TRAINING, VALIDATION
20 changes: 20 additions & 0 deletions src/pie_modules/models/common/has_taskmodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Dict, Optional

from pytorch_ie import AutoTaskModule, TaskModule

from pie_modules.models.interface import RequiresTaskmoduleConfig


class HasTaskmodule(RequiresTaskmoduleConfig):
    """Mixin for models that optionally carry a taskmodule.

    Args:
        taskmodule_config: The config for the taskmodule, as obtained from the
            taskmodule.config property. If omitted, no taskmodule is created and
            self.taskmodule stays None.
    """

    def __init__(self, taskmodule_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        # Build the taskmodule from its config when one is provided; otherwise leave
        # the attribute unset (None) so subclasses can check for availability.
        self.taskmodule: Optional[TaskModule] = (
            AutoTaskModule.from_config(taskmodule_config)
            if taskmodule_config is not None
            else None
        )
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import abc
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
import logging
from typing import Generic, Optional, Tuple, TypeVar

from pytorch_ie import PyTorchIEModel
from typing_extensions import TypeAlias

from pie_modules.models.mixins import WithMetricsFromTaskModule
from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule
from .stages import TESTING, TRAINING, VALIDATION

InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
Expand All @@ -15,25 +15,18 @@
]
StepOutputType = TypeVar("StepOutputType")

TRAINING = "train"
VALIDATION = "val"
TEST = "test"
logger = logging.getLogger(__name__)


class Model(
PyTorchIEModel,
WithMetricsFromTaskModule[InputType, TargetType, OutputType],
class ModelWithBoilerplate(
ModelWithMetricsFromTaskModule[InputType, TargetType, OutputType],
Generic[InputType, OutputType, TargetType, StepOutputType],
abc.ABC,
):
def __init__(
self,
taskmodule_config: Optional[Dict[str, Any]] = None,
metric_stages: List[str] = [TRAINING, VALIDATION, TEST],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.setup_metrics(metric_stages=metric_stages, taskmodule_config=taskmodule_config)
"""A PyTorchIEModel that adds boilerplate code for training, validation, and testing.
Especially, it handles updating the metrics and logging of losses and metric results. Also see
ModelWithMetricsFromTaskModule for more details on how metrics are handled.
"""

def get_loss_from_outputs(self, outputs: OutputType) -> StepOutputType:
if hasattr(outputs, "loss"):
Expand Down Expand Up @@ -65,11 +58,9 @@ def _step(
assert targets is not None, "targets has to be available for training"

outputs = self(inputs=inputs, targets=targets)

self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage)

loss = self.get_loss_from_outputs(outputs=outputs)
self.log_loss(stage=stage, loss=loss)
self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage)

return loss

Expand All @@ -80,10 +71,10 @@ def validation_step(self, batch: StepInputType, batch_idx: int) -> StepOutputTyp
return self._step(stage=VALIDATION, batch=batch)

def test_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
return self._step(stage=TEST, batch=batch)
return self._step(stage=TESTING, batch=batch)

def predict_step(
self, batch: StepInputType, batch_idx: int, dataloader_idx: int
self, batch: StepInputType, batch_idx: int, dataloader_idx: int = 0
) -> TargetType:
inputs, targets = batch
return self.predict(inputs=inputs)
Expand All @@ -95,4 +86,4 @@ def on_validation_epoch_end(self) -> None:
self.log_metric(stage=VALIDATION)

def on_test_epoch_end(self) -> None:
self.log_metric(stage=TEST)
self.log_metric(stage=TESTING)
150 changes: 150 additions & 0 deletions src/pie_modules/models/common/model_with_metrics_from_taskmodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
from typing import Dict, Generic, List, Optional, Set, TypeVar, Union

from pytorch_ie import PyTorchIEModel
from torchmetrics import Metric, MetricCollection

from .has_taskmodule import HasTaskmodule
from .stages import TESTING, TRAINING, VALIDATION

InputType = TypeVar("InputType")
TargetType = TypeVar("TargetType")
OutputType = TypeVar("OutputType")

logger = logging.getLogger(__name__)


class ModelWithMetricsFromTaskModule(
    HasTaskmodule, PyTorchIEModel, Generic[InputType, TargetType, OutputType]
):
    """A PyTorchIEModel that adds metrics from a taskmodule.

    The metrics are added to the model as attributes with the names metric_{stage} via the
    setup_metrics method, where stage is one of "train", "val", or "test". The metrics are
    updated with the update_metric method and logged with the log_metric method.

    Args:
        metric_stages: The stages for which to set up metrics. Each entry must be one of
            "train", "val", or "test".
        metric_intervals: A dict mapping metric stages to the number of steps between metric
            calculation. If not provided, the metrics are calculated at the end of each epoch.
        metric_call_predict: Whether to call predict() and use its result for metric calculation
            instead of the (decoded) model output. This is useful, for instance, for generative models
            that define special logic to produce predictions, e.g. beam search, which requires multiple
            passes through the model. If True, predict() is called for all metric stages. If False (default),
            the model outputs are passed to decode() and that is used for all metric stages. If a list of
            metric stages is provided, predict() is called for these stages and the (decoded) model
            outputs for the remaining stages.
    """

    def __init__(
        self,
        # NOTE(review): the mutable default is never mutated by this class, so sharing it
        # across instances is harmless here.
        metric_stages: List[str] = [TRAINING, VALIDATION, TESTING],
        metric_intervals: Optional[Dict[str, int]] = None,
        metric_call_predict: Union[bool, List[str]] = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.setup_metrics(metric_stages=metric_stages)

        self.metric_intervals = metric_intervals or {}
        missed_stages = set(self.metric_intervals) - set(metric_stages)
        if len(missed_stages) > 0:
            logger.warning(
                f"There are stages in metric_intervals that are not in metric_stages: "
                f"{missed_stages}. Available metric stages: {metric_stages}."
            )

        # Normalize metric_call_predict into the set of stage names for which predict()
        # is used instead of decode() when updating metrics.
        self.metric_call_predict: Set[str]
        if isinstance(metric_call_predict, bool):
            self.metric_call_predict = set(metric_stages) if metric_call_predict else set()
        else:
            self.metric_call_predict = set(metric_call_predict)
            missed_stages = self.metric_call_predict - set(metric_stages)
            if len(missed_stages) > 0:
                logger.warning(
                    f"There are stages in metric_call_predict that are not in metric_stages: "
                    f"{missed_stages}. Available metric stages: {metric_stages}."
                )

    def setup_metrics(self, metric_stages: List[str]) -> None:
        """Set up metrics for the given stages if a taskmodule is available.

        Args:
            metric_stages: The stages for which to set up metrics. Must be one of "train", "val", or
                "test".
        """
        if self.taskmodule is not None:
            for stage in metric_stages:
                metric = self.taskmodule.configure_model_metric(stage=stage)
                if metric is not None:
                    self._set_metric(stage=stage, metric=metric)
                else:
                    logger.warning(
                        f"The taskmodule {self.taskmodule.__class__.__name__} does not define a metric for stage "
                        f"'{stage}'."
                    )
        elif len(metric_stages) > 0:
            logger.warning(
                "No taskmodule is available, so no metrics are set up. "
                "Please provide a taskmodule_config to enable metrics for stages "
                f"{metric_stages}."
            )

    def _get_metric(
        self, stage: str, batch_idx: int = 0
    ) -> Optional[Union[Metric, MetricCollection]]:
        """Return the metric for the given stage, or None when either no metric is set up
        for that stage or the batch index does not fall on the stage's metric interval."""
        metric_interval = self.metric_intervals.get(stage, 1)
        if (batch_idx + 1) % metric_interval == 0:
            return getattr(self, f"metric_{stage}", None)
        else:
            return None

    def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None:
        """Attach the metric to the model as attribute metric_{stage}."""
        setattr(self, f"metric_{stage}", metric)

    def update_metric(
        self,
        stage: str,
        inputs: InputType,
        targets: TargetType,
        outputs: OutputType,
    ) -> None:
        """Update the metric for the given stage, if one is set up for it.

        Predictions are obtained via predict() when the stage is in metric_call_predict
        (note that this causes the model to be called a second time); otherwise they are
        decoded from the provided outputs via decode(). Finally, the metric is updated
        with the predictions and targets.

        Args:
            stage: The stage for which to update the metric. Must be one of "train", "val", or "test".
            inputs: The inputs to the model.
            targets: The targets for the inputs.
            outputs: The outputs of the model. They are decoded into predictions unless the
                stage is configured (via metric_call_predict) to use predict() instead.
        """

        metric = self._get_metric(stage=stage)
        if metric is not None:
            if stage in self.metric_call_predict:
                predictions = self.predict(inputs=inputs)
            else:
                predictions = self.decode(inputs=inputs, outputs=outputs)
            metric.update(predictions, targets)

    def log_metric(self, stage: str, reset: bool = True) -> None:
        """Compute and log the metric for the given stage, then reset it (unless reset=False).

        Dict-valued results are logged per key as metric/{key}/{stage}; scalar results as
        metric/{metric_name}/{stage}, where metric_name is the metric's name attribute or
        its class name.
        """

        metric = self._get_metric(stage=stage)
        if metric is not None:
            values = metric.compute()
            log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
            if isinstance(values, dict):
                for key, value in values.items():
                    self.log(f"metric/{key}/{stage}", value, **log_kwargs)
            else:
                metric_name = getattr(metric, "name", None) or type(metric).__name__
                self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs)
            if reset:
                metric.reset()
3 changes: 3 additions & 0 deletions src/pie_modules/models/common/stages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Canonical stage-name constants shared by the models in this package for
# metric setup ("metric_{stage}" attributes) and loss/metric logging.
TRAINING = "train"
VALIDATION = "val"
TESTING = "test"
1 change: 0 additions & 1 deletion src/pie_modules/models/mixins/__init__.py

This file was deleted.

124 changes: 0 additions & 124 deletions src/pie_modules/models/mixins/with_metrics_from_taskmodule.py

This file was deleted.

Loading

0 comments on commit 7320355

Please sign in to comment.