diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index c30aa2b0..f73d3edb 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -65,7 +65,7 @@ def _end_of_epoch(self, phase: str, pl_module): ) class_names = pl_module.hparams.classification_dict.values() for metric_name, metric in self.metrics_by_class[phase].items(): - values = metric.compute() + values = metric.to(pl_module.device).compute() for value, class_name in zip(values, class_names): metric_name_for_log = f"{phase}/{metric_name}/{class_name}" self.log(