From afb0c592944fe58fb0c82d33dd80d142b3fe8da0 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 7 Nov 2024 11:29:56 +0100 Subject: [PATCH] also log epoch --- modelforge/train/training.py | 40 +++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/modelforge/train/training.py b/modelforge/train/training.py index bb5c91f5..c96dd28a 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -1157,7 +1157,9 @@ def _log_figures_for_each_phase( # Log outlier error counts for non-training phases if phase != "train": self._identify__and_log_top_k_errors(errors, gathered_indices, phase) - self.log_dict(self.outlier_errors_over_epochs, on_epoch=True) + self.log_dict( + self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True + ) def _identify__and_log_top_k_errors( self, @@ -1197,7 +1199,9 @@ def _identify__and_log_top_k_errors( if key not in self.outlier_errors_over_epochs: self.outlier_errors_over_epochs[key] = 0 self.outlier_errors_over_epochs[key] += 1 - log.info(f"{phase} : Outlier error {error} at index {idx}.") + log.info( + f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}." + ) def _clear_error_tracking(self, preds, targets, incides): """ @@ -1293,24 +1297,28 @@ def on_train_epoch_end(self): self._log_time() self._log_histograms() # log the weights of the different loss components - for key, weight in self.loss.weights_scheduling.items(): - self.log( - f"loss/{key}/weight", - weight[self.current_epoch], - ) + if self.trainer.is_global_zero: + for key, weight in self.loss.weights_scheduling.items(): + self.log( + f"loss/{key}/weight", + weight[self.current_epoch], + rank_zero_only=True, + ) def _log_learning_rate(self): """Logs the current learning rate.""" sch = self.lr_schedulers() - try: - self.log( - "lr", - sch.get_last_lr()[0], - on_epoch=True, - prog_bar=True, - ) - except AttributeError: - pass + if self.trainer.is_global_zero: + try: + self.log( + "lr", + sch.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + rank_zero_only=True, + ) + except AttributeError: + pass def _log_metrics(self, metrics: ModuleDict, phase: str): """