diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index b10ba83f..c96dd28a 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -664,9 +664,8 @@ def training_step(
             if self.training_parameter.log_norm:
                 if key == "total_loss":
                     continue  # Skip total loss for gradient norm logging
-                grad_norm = compute_grad_norm(metric.mean(), self)
-                self.log(f"grad_norm/{key}", grad_norm)
+                self.log(f"grad_norm/{key}", grad_norm, sync_dist=True)
 
         # Save energy predictions and targets
         self._update_predictions(
@@ -1158,7 +1157,9 @@ def _log_figures_for_each_phase(
         # Log outlier error counts for non-training phases
         if phase != "train":
             self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
-            self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
+            self.log_dict(
+                self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True
+            )
 
     def _identify__and_log_top_k_errors(
         self,
@@ -1198,7 +1199,9 @@ def _identify__and_log_top_k_errors(
                 if key not in self.outlier_errors_over_epochs:
                     self.outlier_errors_over_epochs[key] = 0
                 self.outlier_errors_over_epochs[key] += 1
-                log.info(f"{phase} : Outlier error {error} at index {idx}.")
+                log.info(
+                    f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}."
+                )
 
     def _clear_error_tracking(self, preds, targets, incides):
         """
@@ -1251,19 +1254,6 @@ def on_validation_epoch_end(self):
             self.val_indices,
         )
 
-    def on_train_start(self):
-        """Log the GPU name to Weights & Biases at the start of training."""
-        if isinstance(self.logger, pL.loggers.WandbLogger) and self.global_rank == 0:
-            if torch.cuda.is_available():
-                gpu_name = torch.cuda.get_device_name(0)
-            else:
-                gpu_name = "CPU"
-            # Log GPU name to W&B
-            self.logger.experiment.config.update({"GPU": gpu_name})
-            self.logger.experiment.log({"GPU Name": gpu_name})
-        else:
-            log.warning("Weights & Biases logger not found; GPU name not logged.")
-
     def on_train_epoch_start(self):
         """Start the epoch timer."""
         self.epoch_start_time = time.time()
@@ -1281,15 +1271,7 @@ def _log_time(self):
 
     def on_train_epoch_end(self):
         """Logs metrics, learning rate, histograms, and figures at the end of the training epoch."""
-        if self.global_rank == 0:
-            self._log_metrics(self.loss_metrics, "loss")
-            self._log_learning_rate()
-            self._log_time()
-            self._log_histograms()
-            # log the weights of the different loss components
-            for key, weight in self.loss.weights_scheduling.items():
-                self.log(f"loss/{key}/weight", weight[self.current_epoch])
-
+        self._log_metrics(self.loss_metrics, "loss")
         # this performs gather operations and logs only at rank == 0
         self._log_figures_for_each_phase(
             self.train_preds,
@@ -1311,15 +1293,32 @@ def on_train_epoch_end(self):
             self.train_indices,
         )
 
+        self._log_learning_rate()
+        self._log_time()
+        self._log_histograms()
+        # log the weights of the different loss components
+        if self.trainer.is_global_zero:
+            for key, weight in self.loss.weights_scheduling.items():
+                self.log(
+                    f"loss/{key}/weight",
+                    weight[self.current_epoch],
+                    rank_zero_only=True,
+                )
+
     def _log_learning_rate(self):
         """Logs the current learning rate."""
         sch = self.lr_schedulers()
-        try:
-            self.log(
-                "lr", sch.get_last_lr()[0], on_epoch=True, prog_bar=True, sync_dist=True
-            )
-        except AttributeError:
-            pass
+        if self.trainer.is_global_zero:
+            try:
+                self.log(
+                    "lr",
+                    sch.get_last_lr()[0],
+                    on_epoch=True,
+                    prog_bar=True,
+                    rank_zero_only=True,
+                )
+            except AttributeError:
+                pass
 
     def _log_metrics(self, metrics: ModuleDict, phase: str):
         """
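
For orientation, below is a minimal sketch (a toy LoggingSketch module, not modelforge code) of the two DDP-safe logging patterns this patch converges on: self.log(..., sync_dist=True) for per-rank scalars that should be all-reduced, so every rank must reach the call, and self.log(..., rank_zero_only=True) behind a trainer.is_global_zero guard for rank-0 bookkeeping such as the learning rate and loss weights, where a cross-rank reduction is meaningless and a rank-dependent call pattern could deadlock.

# Minimal sketch, assuming lightning >= 2.0; module, metric names, and data are toys.
import lightning.pytorch as pl
import torch


class LoggingSketch(pl.LightningModule):
    """Toy module contrasting sync_dist with rank_zero_only logging."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)

        # sync_dist=True: the scalar is reduced (mean by default) across
        # ranks, so every rank must reach this call; suited to per-rank
        # values such as gradient norms that should average to one number.
        self.log("grad_norm/example", loss.detach(), sync_dist=True)

        # rank_zero_only=True: no collective is issued and only rank 0
        # writes the value; suited to bookkeeping where reduction is
        # meaningless. The is_global_zero guard mirrors the patch above.
        if self.trainer.is_global_zero:
            self.log("lr", 1e-3, prog_bar=True, rank_zero_only=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)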