From 094080488ec741ecb706c6ce074c2454c4c1ebfe Mon Sep 17 00:00:00 2001
From: wiederm
Date: Sun, 27 Oct 2024 14:32:30 +0100
Subject: [PATCH] also add histogram logging to training step

---
 .../tests/data/training_defaults/default.toml |   2 +-
 modelforge/tests/test_training.py             |   2 +-
 modelforge/train/training.py                  | 101 ++++++++++--------
 3 files changed, 57 insertions(+), 48 deletions(-)

diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml
index d4864180..4f13abc6 100644
--- a/modelforge/tests/data/training_defaults/default.toml
+++ b/modelforge/tests/data/training_defaults/default.toml
@@ -8,7 +8,7 @@ monitor = "val/per_system_energy/rmse" # Common monitor key
 plot_frequency = 1
 # ------------------------------------------------------------ #
 [training.experiment_logger]
-logger_name = "wandb" # this will set which logger to use
+logger_name = "tensorboard" # this will set which logger to use
 [training.experiment_logger.tensorboard_configuration]
 save_dir = "logs"
 # ------------------------------------------------------------ #
diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py
index a36bd0c6..c5f196d2 100644
--- a/modelforge/tests/test_training.py
+++ b/modelforge/tests/test_training.py
@@ -179,7 +179,7 @@ def test_train_with_lightning(loss, potential_name, dataset_name, prep_temp_dir)
     # train potential
     get_trainer(config).train_potential().save_checkpoint("test.chp")  # save checkpoint
     # continue training from checkpoint
-    get_trainer(config).train_potential()
+    #get_trainer(config).train_potential()


 def test_train_from_single_toml_file(prep_temp_dir):
diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index 8184ed18..270170e5 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -2,7 +2,7 @@
 This module contains classes and functions for training neural network
 potentials using PyTorch Lightning.
 """
-from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
+from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple, Literal

 import lightning.pytorch as pL
 import torch
@@ -422,6 +422,8 @@ def __init__(
             training_parameter.loss_parameter.loss_components, is_loss=True
         )

+        self.train_preds: Dict[int, torch.Tensor] = {}
+        self.train_targets: Dict[int, torch.Tensor] = {}
         self.val_preds:Dict[int, torch.Tensor] = {}
         self.val_targets:Dict[int, torch.Tensor] = {}
         self.test_preds: Dict[int, torch.Tensor] = {}
@@ -535,13 +537,15 @@ def training_step(
         # Compute the mean loss for optimization
         total_loss = loss_dict["total_loss"].mean()

+        self.train_preds.update({batch_idx: predict_target['per_system_energy_predict'].detach()})
+        self.train_targets.update({batch_idx: predict_target['per_system_energy_true'].detach()})
         return total_loss

     def on_after_backward(self):
         # After backward pass
         for name, param in self.potential.named_parameters():
-            if param.grad is not None and False:
+            if param.grad is not None:
                 log.debug(
                     f"Parameter: {name}, Gradient Norm: {param.grad.norm().item()}"
                 )

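The training_step change above caches detached per-system energy predictions and targets, keyed by batch index, so they can be turned into epoch-level regression and histogram figures later. A minimal, self-contained sketch of that caching pattern; the toy_* names are illustrative only and not part of modelforge:

    from typing import Dict

    import torch

    # Per-batch caches keyed by batch index, mirroring self.train_preds / self.train_targets.
    train_preds: Dict[int, torch.Tensor] = {}
    train_targets: Dict[int, torch.Tensor] = {}

    def toy_training_step(batch_idx: int, predicted: torch.Tensor, true: torch.Tensor) -> None:
        # detach() drops the autograd graph so the cached tensors do not keep activations alive
        train_preds[batch_idx] = predicted.detach()
        train_targets[batch_idx] = true.detach()

    def toy_epoch_end() -> torch.Tensor:
        # Concatenate the per-batch tensors in batch order, then clear the caches in place
        preds = torch.cat([train_preds[i] for i in sorted(train_preds)])
        targets = torch.cat([train_targets[i] for i in sorted(train_targets)])
        train_preds.clear()
        train_targets.clear()
        return targets - preds  # per-system energy errors, ready for a histogram
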
@@ -628,36 +632,7 @@ def _get_tensors(self, preds: Dict[int, torch.Tensor], targets: Dict[int, torch.
     def on_validation_epoch_end(self):
         """Logs metrics at the end of the validation epoch."""
         self._log_metrics(self.val_metrics, "val")
-
-        # Gather across processes
-        gathered_preds , gathered_targets, max_length, pad_size = self._get_tensors(self.val_preds, self.val_targets)
-        # Clear the dictionaries
-        self.val_preds = {}
-        self.val_targets = {}
-
-        # Proceed only on main process
-        if self.global_rank == 0:
-            # Remove padding
-            total_length = max_length * self.trainer.world_size
-            gathered_preds = gathered_preds.reshape(total_length)[:total_length - pad_size * self.trainer.world_size]
-            gathered_targets = gathered_targets.reshape(total_length)[:total_length - pad_size * self.trainer.world_size]
-            errors = (gathered_targets - gathered_preds)
-            if errors.size == 0:
-                log.warning("Errors array is empty.")
-
-            # Create regression plot
-            regression_fig = self._create_regression_plot(
-                gathered_targets,
-                gathered_preds,
-                title=f'Validation Regression Plot - Epoch {self.current_epoch}'
-            )
-            # Generate error histogram plot
-            histogram_fig = self._create_error_histogram(
-                errors,
-                title=f'Validation Error Histogram - Epoch {self.current_epoch}'
-            )
-
-            self._log_plots('val', regression_fig, histogram_fig)
+        self._log_figures_for_each_phase(self.val_preds, self.val_targets, 'val')

     def _log_plots(self, phase: str, regression_fig, histogram_fig):
         """
@@ -730,7 +705,9 @@ def _create_regression_plot(self, targets:torch.Tensor, predictions:torch.Tensor
         """
         import matplotlib.pyplot as plt
         fig, ax = plt.subplots()
-        ax.scatter(targets.cpu().numpy(), predictions.cpu().numpy(), alpha=0.5)
+        targets = targets.cpu().numpy()
+        predictions = predictions.cpu().numpy()
+        ax.scatter(targets, predictions, alpha=0.5)
         ax.plot([targets.min(), targets.max()], [targets.min(), targets.max()], 'r--')
         ax.set_xlabel('True Values')
         ax.set_ylabel('Predicted Values')
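The regression and histogram figures are handed to self._log_plots(phase, regression_fig, histogram_fig), whose body is unchanged and therefore not shown in this patch. For the tensorboard logger selected in default.toml, logging a matplotlib figure typically goes through the underlying SummaryWriter; a small sketch of that pattern, where log_figure_to_tensorboard is an illustrative helper and not a modelforge function:

    from matplotlib.figure import Figure
    from torch.utils.tensorboard import SummaryWriter

    def log_figure_to_tensorboard(writer: SummaryWriter, tag: str, fig: Figure, step: int) -> None:
        # add_figure renders the figure to an image and writes it to the event file;
        # by default it also closes the figure so it does not accumulate in memory.
        writer.add_figure(tag, fig, global_step=step)

    # Inside a LightningModule backed by TensorBoardLogger, self.logger.experiment is a SummaryWriter:
    # log_figure_to_tensorboard(self.logger.experiment, "val/regression", regression_fig, self.current_epoch)
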
@@ -739,22 +716,48 @@ def _create_regression_plot(self, targets:torch.Tensor, predictions:torch.Tensor
         ax.set_title(title)
         return fig

     def _create_error_histogram(self, errors:torch.Tensor, title='Error Histogram'):
         import matplotlib.pyplot as plt
-        fig, ax = plt.subplots()
-        ax.hist(errors.cpu().numpy().flatten(), bins=50, alpha=0.75)
+        import numpy as np
+        errors_np = errors.cpu().numpy().flatten()
+
+        # Compute mean and standard deviation
+        mean_error = np.mean(errors_np)
+        std_error = np.std(errors_np)
+
+        fig, ax = plt.subplots(figsize=(8, 6))
+        bins = 50
+
+        # Plot histogram and get bin data
+        counts, bin_edges, patches = ax.hist(errors_np, bins=bins, alpha=0.75, edgecolor='black')
+
+        # Set y-axis to log scale
+        ax.set_yscale('log')
+
+        # Highlight outlier bins beyond 3 standard deviations
+        for count, edge_left, edge_right, patch in zip(counts, bin_edges[:-1], bin_edges[1:], patches):
+            if (edge_left < mean_error - 3 * std_error) or (edge_right > mean_error + 3 * std_error):
+                patch.set_facecolor('red')
+            else:
+                patch.set_facecolor('blue')
+
+        # Add vertical lines for mean and standard deviations
+        ax.axvline(mean_error, color='k', linestyle='dashed', linewidth=1, label='Mean')
+        ax.axvline(mean_error + 3 * std_error, color='r', linestyle='dashed', linewidth=1, label='±3 Std Dev')
+        ax.axvline(mean_error - 3 * std_error, color='r', linestyle='dashed', linewidth=1)
+
         ax.set_xlabel('Error')
-        ax.set_ylabel('Frequency')
+        ax.set_ylabel('Frequency (Log Scale)')
         ax.set_title(title)
+        ax.legend()
+
         return fig

-    def on_test_epoch_end(self):
-        """Logs metrics at the end of the test epoch."""
-        self._log_metrics(self.test_metrics, "test")
+    def _log_figures_for_each_phase(self, preds: Dict[int, torch.Tensor], targets: Dict[int, torch.Tensor], phase: Literal['train', 'val', 'test']):
         # Gather across processes
-        gathered_preds, gathered_targets, max_length, pad_size = self._get_tensors(self.test_preds, self.test_targets)
+        gathered_preds, gathered_targets, max_length, pad_size = self._get_tensors(preds, targets)
         # Clear the dictionaries
-        self.test_preds = {}
-        self.test_targets = {}
-
+        preds.clear()
+        targets.clear()
+
         # Proceed only on main process
         if self.global_rank == 0:
             # Remove padding
@@ -769,21 +772,27 @@ def on_test_epoch_end(self):
             regression_fig = self._create_regression_plot(
                 gathered_targets,
                 gathered_preds,
-                title=f'Test Regression Plot - Epoch {self.current_epoch}'
+                title=f'{phase.capitalize()} Regression Plot - Epoch {self.current_epoch}'
             )
             # Generate error histogram plot
             histogram_fig = self._create_error_histogram(
                 errors,
-                title=f'Test Error Histogram - Epoch {self.current_epoch}'
+                title=f'{phase.capitalize()} Error Histogram - Epoch {self.current_epoch}'
             )

-            self._log_plots("test", regression_fig, histogram_fig)
+            self._log_plots(phase, regression_fig, histogram_fig)
+
+    def on_test_epoch_end(self):
+        """Logs metrics at the end of the test epoch."""
+        self._log_metrics(self.test_metrics, "test")
+        self._log_figures_for_each_phase(self.test_preds, self.test_targets, 'test')

     def on_train_epoch_end(self):
         """Logs metrics at the end of the training epoch."""
         self._log_metrics(self.loss_metrics, "loss")
         self._log_learning_rate()
         self._log_histograms()
+        self._log_figures_for_each_phase(self.train_preds, self.train_targets, 'train')

     def _log_learning_rate(self):
         """Logs the current learning rate."""
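_get_tensors itself is untouched by the patch, so only its call sites appear above; the pattern it relies on is to pad each rank's tensor to a common length, all_gather it, and trim the padding afterwards, which is what _log_figures_for_each_phase does on rank 0. A rough sketch of that pattern, assuming a LightningModule-style all_gather and a 1-D tensor per rank; gather_with_padding is illustrative and not modelforge code:

    from typing import Tuple

    import torch
    import torch.nn.functional as F

    def gather_with_padding(module, local: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        # Ranks may hold different numbers of systems, but all_gather needs equal shapes,
        # so first agree on the maximum length across ranks ...
        length = torch.tensor(local.numel(), device=local.device)
        max_length = int(module.all_gather(length).max())
        # ... then right-pad the local tensor and gather; the result is (world_size, max_length)
        pad_size = max_length - local.numel()
        gathered = module.all_gather(F.pad(local, (0, pad_size)))
        return gathered, max_length, pad_size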