From 6af9b1572756e37a41f544ab71b7da7f3303c462 Mon Sep 17 00:00:00 2001
From: wiederm
Date: Tue, 5 Nov 2024 16:47:38 +0100
Subject: [PATCH 1/6] check rank

---
 modelforge/train/training.py | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index b10ba83f..bc152ff6 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -1251,19 +1251,6 @@ def on_validation_epoch_end(self):
             self.val_indices,
         )
 
-    def on_train_start(self):
-        """Log the GPU name to Weights & Biases at the start of training."""
-        if isinstance(self.logger, pL.loggers.WandbLogger) and self.global_rank == 0:
-            if torch.cuda.is_available():
-                gpu_name = torch.cuda.get_device_name(0)
-            else:
-                gpu_name = "CPU"
-            # Log GPU name to W&B
-            self.logger.experiment.config.update({"GPU": gpu_name})
-            self.logger.experiment.log({"GPU Name": gpu_name})
-        else:
-            log.warning("Weights & Biases logger not found; GPU name not logged.")
-
     def on_train_epoch_start(self):
         """Start the epoch timer."""
         self.epoch_start_time = time.time()
@@ -1281,7 +1268,9 @@ def _log_time(self):
 
     def on_train_epoch_end(self):
         """Logs metrics, learning rate, histograms, and figures at the end of the training epoch."""
-        if self.global_rank == 0:
+        print(self.global_rank)
+        if self.trainer.is_global_zero:
+
             self._log_metrics(self.loss_metrics, "loss")
             self._log_learning_rate()
             self._log_time()

From fbd5907850deb686d7e8bc636a578fa2cbf02e9c Mon Sep 17 00:00:00 2001
From: wiederm
Date: Wed, 6 Nov 2024 18:26:09 +0100
Subject: [PATCH 2/6] for debugging

---
 modelforge/train/training.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index bc152ff6..10eb6599 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -1269,6 +1269,7 @@ def _log_time(self):
     def on_train_epoch_end(self):
         """Logs metrics, learning rate, histograms, and figures at the end of the training epoch."""
         print(self.global_rank)
+        print(self.trainer.is_global_zero)
         if self.trainer.is_global_zero:
 
             self._log_metrics(self.loss_metrics, "loss")

From 312f898d837e2d1938b91dd01de3d4b60d09b6cf Mon Sep 17 00:00:00 2001
From: wiederm
Date: Wed, 6 Nov 2024 20:25:25 +0100
Subject: [PATCH 3/6] update rank zero logging logic

---
 modelforge/train/training.py | 38 +++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index 10eb6599..a0dda77e 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -664,9 +664,8 @@ def training_step(
             if self.training_parameter.log_norm:
                 if key == "total_loss":
                     continue  # Skip total loss for gradient norm logging
-
                 grad_norm = compute_grad_norm(metric.mean(), self)
-                self.log(f"grad_norm/{key}", grad_norm)
+                self.log(f"grad_norm/{key}", grad_norm, sync_dist=True)
 
         # Save energy predictions and targets
         self._update_predictions(
@@ -1157,7 +1157,10 @@ def _log_figures_for_each_phase(
         # Log outlier error counts for non-training phases
         if phase != "train":
             self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
-            self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
+            self.log_dict(
+                self.outlier_errors_over_epochs,
+                on_epoch=True,
+            )
 
     def _identify__and_log_top_k_errors(
         self,
@@ -1268,18 +1270,7 @@ def _log_time(self):
 
     def on_train_epoch_end(self):
         """Logs metrics, learning rate, histograms, and figures at the end of the training epoch."""
-        print(self.global_rank)
-        print(self.trainer.is_global_zero)
-        if self.trainer.is_global_zero:
-
-            self._log_metrics(self.loss_metrics, "loss")
-            self._log_learning_rate()
-            self._log_time()
-            self._log_histograms()
-            # log the weights of the different loss components
-            for key, weight in self.loss.weights_scheduling.items():
-                self.log(f"loss/{key}/weight", weight[self.current_epoch])
-
+        self._log_metrics(self.loss_metrics, "loss")
         # this performs gather operations and logs only at rank == 0
         self._log_figures_for_each_phase(
             self.train_preds,
@@ -1301,12 +1292,27 @@ def on_train_epoch_end(self):
            self.train_indices,
         )
 
+        self._log_learning_rate()
+        self._log_time()
+        self._log_histograms()
+        # log the weights of the different loss components
+        for key, weight in self.loss.weights_scheduling.items():
+            self.log(
+                f"loss/{key}/weight",
+                weight[self.current_epoch],
+                rank_zero_only=True,
+            )
+
     def _log_learning_rate(self):
         """Logs the current learning rate."""
         sch = self.lr_schedulers()
         try:
             self.log(
-                "lr", sch.get_last_lr()[0], on_epoch=True, prog_bar=True, sync_dist=True
+                "lr",
+                sch.get_last_lr()[0],
+                on_epoch=True,
+                prog_bar=True,
+                rank_zero_only=True,
             )
         except AttributeError:
             pass
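
The two `self.log` flags this series keeps trading off have very different
semantics in PyTorch Lightning: `sync_dist=True` triggers a collective
reduction across all ranks, while `rank_zero_only=True` records the value on
global rank 0 without any cross-rank communication. A minimal sketch of the
difference (illustrative only; `DemoModule` and its toy loss are not part of
this codebase):

    import torch
    import pytorch_lightning as pl

    class DemoModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).pow(2).mean()
            # Reduced (e.g. averaged) across ranks: every rank must reach this call.
            self.log("loss/synced", loss, sync_dist=True)
            # Recorded on global rank 0 only: no collective involved.
            self.log("loss/rank0", loss, rank_zero_only=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-3)
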
From 8b23f1876631d5f62491baeddb6c7dc7e9a441e6 Mon Sep 17 00:00:00 2001
From: wiederm
Date: Thu, 7 Nov 2024 10:42:14 +0100
Subject: [PATCH 4/6] checking if this solves the issue

---
 modelforge/train/training.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index a0dda77e..8e49b157 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -1158,8 +1158,7 @@ def _log_figures_for_each_phase(
         if phase != "train":
             self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
             self.log_dict(
-                self.outlier_errors_over_epochs,
-                on_epoch=True,
+                self.outlier_errors_over_epochs, on_epoch=True, sync_dist=True
             )
 
     def _identify__and_log_top_k_errors(
         self,
@@ -1300,7 +1299,7 @@ def on_train_epoch_end(self):
             self.log(
                 f"loss/{key}/weight",
                 weight[self.current_epoch],
-                rank_zero_only=True,
+                sync_dist=True,
             )
 
     def _log_learning_rate(self):
@@ -1312,7 +1311,7 @@ def _log_learning_rate(self):
                 sch.get_last_lr()[0],
                 on_epoch=True,
                 prog_bar=True,
-                rank_zero_only=True,
+                sync_dist=True,
             )
         except AttributeError:
             pass

From 667909fac266c52ebf7218e2865a1ae1001790bc Mon Sep 17 00:00:00 2001
From: wiederm
Date: Thu, 7 Nov 2024 11:09:05 +0100
Subject: [PATCH 5/6] this fixes the issue

---
 modelforge/train/training.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index 8e49b157..bb5c91f5 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -1157,9 +1157,7 @@ def _log_figures_for_each_phase(
         # Log outlier error counts for non-training phases
         if phase != "train":
             self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
-            self.log_dict(
-                self.outlier_errors_over_epochs, on_epoch=True, sync_dist=True
-            )
+            self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
 
     def _identify__and_log_top_k_errors(
         self,
@@ -1299,7 +1297,6 @@ def on_train_epoch_end(self):
             self.log(
                 f"loss/{key}/weight",
                 weight[self.current_epoch],
-                sync_dist=True,
             )
 
     def _log_learning_rate(self):
@@ -1311,7 +1308,6 @@ def _log_learning_rate(self):
                 sch.get_last_lr()[0],
                 on_epoch=True,
                 prog_bar=True,
-                sync_dist=True,
             )
         except AttributeError:
             pass
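
A note on why PATCH 5 helps: patches 3 and 4 combined a
`trainer.is_global_zero` guard with `sync_dist=True`, a pattern that can
deadlock under DDP, since the collective reduction is only reached by rank 0
while the other ranks wait at the next synchronization point. PATCH 5 drops
the collective from the guarded calls. A sketch of the hazardous pattern next
to two safe ones (illustrative only; the method names are made up and each
stands in for on_train_epoch_end):

    import pytorch_lightning as pl

    class EpochEndPatterns(pl.LightningModule):
        """Illustrative only: three epoch-end logging patterns under DDP."""

        def hazardous(self):
            # DEADLOCK RISK: sync_dist=True issues a collective, but only
            # rank 0 enters this branch; the other ranks never join it.
            if self.trainer.is_global_zero:
                self.log("lr", 1e-3, sync_dist=True)

        def safe_synced(self):
            # Every rank participates; Lightning reduces across ranks.
            self.log("lr", 1e-3, sync_dist=True)

        def safe_rank_zero(self):
            # No collective at all; the value is kept on rank 0 only.
            if self.trainer.is_global_zero:
                self.log("lr", 1e-3, rank_zero_only=True)
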
From afb0c592944fe58fb0c82d33dd80d142b3fe8da0 Mon Sep 17 00:00:00 2001
From: wiederm
Date: Thu, 7 Nov 2024 11:29:56 +0100
Subject: [PATCH 6/6] also log epoch

---
 modelforge/train/training.py | 40 +++++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index bb5c91f5..c96dd28a 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -1157,7 +1157,9 @@ def _log_figures_for_each_phase(
         # Log outlier error counts for non-training phases
         if phase != "train":
             self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
-            self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
+            self.log_dict(
+                self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True
+            )
 
     def _identify__and_log_top_k_errors(
         self,
@@ -1197,7 +1199,9 @@ def _identify__and_log_top_k_errors(
             if key not in self.outlier_errors_over_epochs:
                 self.outlier_errors_over_epochs[key] = 0
             self.outlier_errors_over_epochs[key] += 1
-            log.info(f"{phase} : Outlier error {error} at index {idx}.")
+            log.info(
+                f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}."
+            )
 
     def _clear_error_tracking(self, preds, targets, incides):
         """
@@ -1293,24 +1297,28 @@ def on_train_epoch_end(self):
         self._log_time()
         self._log_histograms()
         # log the weights of the different loss components
-        for key, weight in self.loss.weights_scheduling.items():
-            self.log(
-                f"loss/{key}/weight",
-                weight[self.current_epoch],
-            )
+        if self.trainer.is_global_zero:
+            for key, weight in self.loss.weights_scheduling.items():
+                self.log(
+                    f"loss/{key}/weight",
+                    weight[self.current_epoch],
+                    rank_zero_only=True,
+                )
 
     def _log_learning_rate(self):
         """Logs the current learning rate."""
         sch = self.lr_schedulers()
-        try:
-            self.log(
-                "lr",
-                sch.get_last_lr()[0],
-                on_epoch=True,
-                prog_bar=True,
-            )
-        except AttributeError:
-            pass
+        if self.trainer.is_global_zero:
+            try:
+                self.log(
+                    "lr",
+                    sch.get_last_lr()[0],
+                    on_epoch=True,
+                    prog_bar=True,
+                    rank_zero_only=True,
+                )
+            except AttributeError:
+                pass
 
     def _log_metrics(self, metrics: ModuleDict, phase: str):
         """
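
After PATCH 6 the pattern is consistent: anything under a
`trainer.is_global_zero` guard uses `rank_zero_only=True` and never
`sync_dist=True`, and the outlier messages now carry the current epoch. For
side effects that go through plain logging rather than `self.log`, such as
the `log.info` call above, Lightning also ships a `rank_zero_only` decorator;
a minimal sketch (`announce_outliers` is a hypothetical helper, not from the
patch):

    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def announce_outliers(epoch: int, phase: str, error: float, idx: int) -> None:
        # Runs on global rank 0 only; a no-op on every other rank.
        print(f"{epoch}: {phase} : Outlier error {error} at index {idx}.")
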