fix logging nested metric dicts
ArneBinder committed Jan 23, 2024
1 parent cfa3c15 commit a472ed8
Showing 1 changed file with 4 additions and 1 deletion.
@@ -4,6 +4,8 @@
 from pytorch_ie import PyTorchIEModel
 from torchmetrics import Metric, MetricCollection
 
+from pie_modules.utils import flatten_dict
+
 from .has_taskmodule import HasTaskmodule
 from .stages import TESTING, TRAINING, VALIDATION

@@ -141,7 +143,8 @@ def log_metric(self, stage: str, reset: bool = True) -> None:
             values = metric.compute()
             log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
             if isinstance(values, dict):
-                for key, value in values.items():
+                values_flat = flatten_dict(values, sep="/")
+                for key, value in values_flat.items():
                     self.log(f"metric/{key}/{stage}", value, **log_kwargs)
             else:
                 metric_name = getattr(metric, "name", None) or type(metric).__name__
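Context for the fix: Lightning's self.log expects a scalar (or scalar tensor) per key, so a metric whose compute() returns a nested dict (e.g. a MetricCollection with per-class sub-scores) could not be logged key by key before this change. Below is a minimal sketch of the behavior flatten_dict is assumed to have here, with nested keys joined by sep; this is an illustrative re-implementation, not the actual pie_modules.utils helper imported in the diff above.

from typing import Any, Dict


def flatten_dict_sketch(d: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
    # Recursively join nested keys with `sep` so every value sits at the top level.
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten_dict_sketch(value, sep=sep).items():
                flat[f"{key}{sep}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


# A nested result like {"f1": {"micro": 0.9, "macro": 0.8}} becomes
# {"f1/micro": 0.9, "f1/macro": 0.8}, so log_metric can emit each value
# as a scalar under "metric/f1/micro/<stage>" and so on.
print(flatten_dict_sketch({"f1": {"micro": 0.9, "macro": 0.8}}))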
