diff --git a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py index f3fc0ce6b..06ffe2ce9 100644 --- a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py +++ b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py @@ -4,6 +4,8 @@ from pytorch_ie import PyTorchIEModel from torchmetrics import Metric, MetricCollection +from pie_modules.utils import flatten_dict + from .has_taskmodule import HasTaskmodule from .stages import TESTING, TRAINING, VALIDATION @@ -141,7 +143,8 @@ def log_metric(self, stage: str, reset: bool = True) -> None: values = metric.compute() log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True} if isinstance(values, dict): - for key, value in values.items(): + values_flat = flatten_dict(values, sep="/") + for key, value in values_flat.items(): self.log(f"metric/{key}/{stage}", value, **log_kwargs) else: metric_name = getattr(metric, "name", None) or type(metric).__name__