Merge pull request #37 from ArneBinder/improve_new_models
streamline `SimpleGenerativeModel`
Showing 23 changed files with 297 additions and 454 deletions.
@@ -1 +1,4 @@
from .model import Model
from .has_taskmodule import HasTaskmodule
from .model_with_boilerplate import ModelWithBoilerplate
from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule
from .stages import TESTING, TRAINING, VALIDATION
@@ -0,0 +1,20 @@
from typing import Any, Dict, Optional

from pytorch_ie import AutoTaskModule, TaskModule

from pie_modules.models.interface import RequiresTaskmoduleConfig


class HasTaskmodule(RequiresTaskmoduleConfig):
    """A mixin class for models that have a taskmodule.

    Args:
        taskmodule_config: The config for the taskmodule which can be obtained from the
            taskmodule.config property.
    """

    def __init__(self, taskmodule_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        self.taskmodule: Optional[TaskModule] = None
        if taskmodule_config is not None:
            self.taskmodule = AutoTaskModule.from_config(taskmodule_config)
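
The mixin above enables a simple round trip: serialize a prepared taskmodule to its config via the taskmodule.config property, pass that config to the model, and let AutoTaskModule.from_config rebuild the taskmodule inside the model. A minimal, hypothetical sketch (the subclass MyModel is illustrative, and the import path is only inferred from the relative imports in this diff):

# Hypothetical usage sketch; the import path is inferred from the relative imports
# in this diff and may differ in the actual package layout.
from pie_modules.models.common.has_taskmodule import HasTaskmodule


class MyModel(HasTaskmodule):
    """Illustrative model that mixes in HasTaskmodule."""


# taskmodule = ...  # any prepared pytorch_ie TaskModule instance
# model = MyModel(taskmodule_config=taskmodule.config)
# model.taskmodule is then reconstructed via AutoTaskModule.from_config(taskmodule.config)
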
src/pie_modules/models/common/model_with_metrics_from_taskmodule.py
150 changes: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import logging
from typing import Dict, Generic, List, Optional, Set, TypeVar, Union

from pytorch_ie import PyTorchIEModel
from torchmetrics import Metric, MetricCollection

from .has_taskmodule import HasTaskmodule
from .stages import TESTING, TRAINING, VALIDATION

InputType = TypeVar("InputType")
TargetType = TypeVar("TargetType")
OutputType = TypeVar("OutputType")

logger = logging.getLogger(__name__)


class ModelWithMetricsFromTaskModule(
    HasTaskmodule, PyTorchIEModel, Generic[InputType, TargetType, OutputType]
):
    """A PyTorchIEModel that adds metrics from a taskmodule.

    The metrics are added to the model as attributes with the names metric_{stage} via the
    setup_metrics method, where stage is one of "train", "val", or "test". The metrics are
    updated with the update_metric method and logged with the log_metric method.

    Args:
        metric_stages: The stages for which to set up metrics. Each entry must be one of
            "train", "val", or "test".
        metric_intervals: A dict mapping metric stages to the number of steps between metric
            calculation. If not provided, the metrics are calculated at the end of each epoch.
        metric_call_predict: Whether to call predict() and use its result for metric calculation
            instead of the (decoded) model output. This is useful, for instance, for generative
            models that define special logic to produce predictions, e.g. beam search, which
            requires multiple passes through the model. If True, predict() is called for all
            metric stages. If False (the default), the model outputs are passed to decode() and
            the decoded result is used for all metric stages. If a list of metric stages is
            provided, predict() is called for these stages and the (decoded) model outputs are
            used for the remaining stages.
    """

    def __init__(
        self,
        metric_stages: List[str] = [TRAINING, VALIDATION, TESTING],
        metric_intervals: Optional[Dict[str, int]] = None,
        metric_call_predict: Union[bool, List[str]] = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.setup_metrics(metric_stages=metric_stages)

        self.metric_intervals = metric_intervals or {}
        missed_stages = set(self.metric_intervals) - set(metric_stages)
        if len(missed_stages) > 0:
            logger.warning(
                f"There are stages in metric_intervals that are not in metric_stages: "
                f"{missed_stages}. Available metric stages: {metric_stages}."
            )

        self.metric_call_predict: Set[str]
        if isinstance(metric_call_predict, bool):
            self.metric_call_predict = set(metric_stages) if metric_call_predict else set()
        else:
            self.metric_call_predict = set(metric_call_predict)
            missed_stages = self.metric_call_predict - set(metric_stages)
            if len(missed_stages) > 0:
                logger.warning(
                    f"There are stages in metric_call_predict that are not in metric_stages: "
                    f"{missed_stages}. Available metric stages: {metric_stages}."
                )

    def setup_metrics(self, metric_stages: List[str]) -> None:
        """Set up metrics for the given stages if a taskmodule is available.

        Args:
            metric_stages: The stages for which to set up metrics. Each entry must be one of
                "train", "val", or "test".
        """
        if self.taskmodule is not None:
            for stage in metric_stages:
                metric = self.taskmodule.configure_model_metric(stage=stage)
                if metric is not None:
                    self._set_metric(stage=stage, metric=metric)
                else:
                    logger.warning(
                        f"The taskmodule {self.taskmodule.__class__.__name__} does not define a metric for stage "
                        f"'{stage}'."
                    )
        elif len(metric_stages) > 0:
            logger.warning(
                "No taskmodule is available, so no metrics are set up. "
                "Please provide a taskmodule_config to enable metrics for stages "
                f"{metric_stages}."
            )

    def _get_metric(
        self, stage: str, batch_idx: int = 0
    ) -> Optional[Union[Metric, MetricCollection]]:
        metric_interval = self.metric_intervals.get(stage, 1)
        if (batch_idx + 1) % metric_interval == 0:
            return getattr(self, f"metric_{stage}", None)
        else:
            return None

    def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None:
        setattr(self, f"metric_{stage}", metric)

    def update_metric(
        self,
        stage: str,
        inputs: InputType,
        targets: TargetType,
        outputs: OutputType,
    ) -> None:
        """Update the metric for the given stage. If the stage is in metric_call_predict, the
        predictions are obtained by directly calling the predict method with the inputs (note
        that this causes the model to be called a second time). Otherwise, the predictions are
        decoded from the outputs. Finally, the metric is updated with the predictions and
        targets.

        Args:
            stage: The stage for which to update the metric. Must be one of "train", "val", or
                "test".
            inputs: The inputs to the model.
            targets: The targets for the inputs.
            outputs: The outputs of the model. They are decoded into predictions unless the
                stage is in metric_call_predict, in which case the predictions are obtained by
                directly calling the predict method on the inputs.
        """

        metric = self._get_metric(stage=stage)
        if metric is not None:
            if stage in self.metric_call_predict:
                predictions = self.predict(inputs=inputs)
            else:
                predictions = self.decode(inputs=inputs, outputs=outputs)
            metric.update(predictions, targets)

    def log_metric(self, stage: str, reset: bool = True) -> None:
        """Log the metric for the given stage and reset it."""

        metric = self._get_metric(stage=stage)
        if metric is not None:
            values = metric.compute()
            log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
            if isinstance(values, dict):
                for key, value in values.items():
                    self.log(f"metric/{key}/{stage}", value, **log_kwargs)
            else:
                metric_name = getattr(metric, "name", None) or type(metric).__name__
                self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs)
            if reset:
                metric.reset()
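
The update_metric and log_metric helpers above are intended to be called from the Lightning step and epoch-end hooks. A sketch of how a subclass might wire them up for the "train" stage follows; the subclass MyTaggingModel, its forward/loss handling, and the hook placement are illustrative assumptions rather than part of this commit (in the repository, ModelWithBoilerplate presumably provides this wiring), and decode()/predict() are assumed to be implemented elsewhere:

# Illustrative sketch only: the subclass and its step logic are assumptions, not part
# of this diff. It assumes forward() returns an output object with a .loss attribute
# and that decode()/predict() are provided by another base class.
from pie_modules.models.common.model_with_metrics_from_taskmodule import (
    ModelWithMetricsFromTaskModule,
)


class MyTaggingModel(ModelWithMetricsFromTaskModule):
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)  # assumed forward pass
        # update the "train" metric with decoded (or predicted) results and targets
        self.update_metric(stage="train", inputs=inputs, targets=targets, outputs=outputs)
        return outputs.loss

    def on_train_epoch_end(self):
        # compute, log, and reset the "train" metric once per epoch
        self.log_metric(stage="train")

Validation and test would be wired analogously with stage="val" and stage="test".
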
@@ -0,0 +1,3 @@
TRAINING = "train"
VALIDATION = "val"
TESTING = "test"
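
These stage constants are the values used in metric_stages and metric_call_predict, and the keys of metric_intervals, in ModelWithMetricsFromTaskModule above. Continuing the illustrative MyTaggingModel sketch, a hypothetical factory might configure the model like this (the function and its argument are assumptions; the keyword arguments are those defined in the diff above):

from pytorch_ie import TaskModule


def build_model(taskmodule: TaskModule) -> MyTaggingModel:
    # Hypothetical factory: `taskmodule` stands for any prepared pytorch_ie TaskModule.
    return MyTaggingModel(
        taskmodule_config=taskmodule.config,  # lets the model rebuild the taskmodule and its metrics
        metric_stages=["train", "val"],       # i.e. [TRAINING, VALIDATION]
        metric_intervals={"train": 100},      # per the docstring: calculate the train metric only every 100 steps
        metric_call_predict=["val"],          # call predict() (e.g. beam search) for "val" only
    )
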
src/pie_modules/models/mixins/with_metrics_from_taskmodule.py
124 changes: 0 additions & 124 deletions
This file was deleted.