Skip to content

Commit

Permalink
Move metric to gpu before compute to avoid error in ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesGaydon committed Apr 25, 2024
1 parent 6a874b2 commit f7b7f73
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion myria3d/callbacks/metric_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _end_of_epoch(self, phase: str, pl_module):
)
class_names = pl_module.hparams.classification_dict.values()
for metric_name, metric in self.metrics_by_class[phase].items():
values = metric.compute()
values = metric.to(pl_module.device).compute()
for value, class_name in zip(values, class_names):
metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
self.log(
Expand Down

0 comments on commit f7b7f73

Please sign in to comment.