also add histogram logging to training step
wiederm committed Oct 27, 2024
1 parent 6b30fc5 commit 0940804
Showing 3 changed files with 57 additions and 48 deletions.
2 changes: 1 addition & 1 deletion modelforge/tests/data/training_defaults/default.toml
@@ -8,7 +8,7 @@ monitor = "val/per_system_energy/rmse" # Common monitor key
 plot_frequency = 1
 # ------------------------------------------------------------ #
 [training.experiment_logger]
-logger_name = "wandb" # this will set which logger to use
+logger_name = "tensorboard" # this will set which logger to use
 [training.experiment_logger.tensorboard_configuration]
 save_dir = "logs"
 # ------------------------------------------------------------ #
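For context, a minimal sketch of how a logger_name setting like this can be turned into a Lightning logger. The build_logger helper and the config path handling are illustrative assumptions, not modelforge's actual API:

    import tomllib  # Python 3.11+
    from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

    def build_logger(config_path: str):
        # Read [training.experiment_logger] and return the matching logger.
        with open(config_path, "rb") as f:
            cfg = tomllib.load(f)["training"]["experiment_logger"]
        if cfg["logger_name"] == "tensorboard":
            tb_cfg = cfg.get("tensorboard_configuration", {})
            return TensorBoardLogger(save_dir=tb_cfg.get("save_dir", "logs"))
        return WandbLogger()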
2 changes: 1 addition & 1 deletion modelforge/tests/test_training.py
@@ -179,7 +179,7 @@ def test_train_with_lightning(loss, potential_name, dataset_name, prep_temp_dir)
     # train potential
     get_trainer(config).train_potential().save_checkpoint("test.chp")  # save checkpoint
     # continue training from checkpoint
-    get_trainer(config).train_potential()
+    # get_trainer(config).train_potential()


 def test_train_from_single_toml_file(prep_temp_dir):
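The now-commented line previously resumed training from the saved checkpoint. In plain Lightning (modelforge's get_trainer wraps its own machinery, so this is the underlying pattern rather than the project API), resuming looks roughly like:

    import lightning.pytorch as pl

    trainer = pl.Trainer(max_epochs=2)
    # `model` and `datamodule` are assumed to be defined elsewhere;
    # ckpt_path restores weights, optimizer state, and epoch counters.
    trainer.fit(model, datamodule=datamodule, ckpt_path="test.chp")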
101 changes: 55 additions & 46 deletions modelforge/train/training.py
@@ -2,7 +2,7 @@
 This module contains classes and functions for training neural network potentials using PyTorch Lightning.
 """

-from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
+from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple, Literal

 import lightning.pytorch as pL
 import torch
@@ -422,6 +422,8 @@ def __init__(
             training_parameter.loss_parameter.loss_components, is_loss=True
         )

+        self.train_preds: Dict[int, torch.Tensor] = {}
+        self.train_targets: Dict[int, torch.Tensor] = {}
         self.val_preds: Dict[int, torch.Tensor] = {}
         self.val_targets: Dict[int, torch.Tensor] = {}
         self.test_preds: Dict[int, torch.Tensor] = {}
@@ -535,13 +537,15 @@ def training_step(

         # Compute the mean loss for optimization
         total_loss = loss_dict["total_loss"].mean()
+        self.train_preds[batch_idx] = predict_target["per_system_energy_predict"].detach()
+        self.train_targets[batch_idx] = predict_target["per_system_energy_true"].detach()

         return total_loss

     def on_after_backward(self):
         # After backward pass
         for name, param in self.potential.named_parameters():
-            if param.grad is not None and False:
+            if param.grad is not None:
                 log.debug(
                     f"Parameter: {name}, Gradient Norm: {param.grad.norm().item()}"
                 )
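The new lines cache per-batch predictions and targets for the epoch-end plots. A small self-contained sketch of why .detach() matters here (shapes are made up):

    import torch

    preds_by_batch: dict[int, torch.Tensor] = {}
    for batch_idx in range(3):
        # Stand-in for per-system energy predictions that carry gradients.
        pred = torch.randn(8, requires_grad=True) * 2.0
        # Detaching stores plain values, so the autograd graph of each
        # training step can be freed instead of living until epoch end.
        preds_by_batch[batch_idx] = pred.detach()
    epoch_preds = torch.cat(list(preds_by_batch.values()))
    print(epoch_preds.shape)  # torch.Size([24])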
@@ -628,36 +632,7 @@ def _get_tensors(self, preds: Dict[int, torch.Tensor], targets: Dict[int, torch.Tensor]
     def on_validation_epoch_end(self):
         """Logs metrics at the end of the validation epoch."""
         self._log_metrics(self.val_metrics, "val")
-
-        # Gather across processes
-        gathered_preds, gathered_targets, max_length, pad_size = self._get_tensors(self.val_preds, self.val_targets)
-        # Clear the dictionaries
-        self.val_preds = {}
-        self.val_targets = {}
-
-        # Proceed only on main process
-        if self.global_rank == 0:
-            # Remove padding
-            total_length = max_length * self.trainer.world_size
-            gathered_preds = gathered_preds.reshape(total_length)[:total_length - pad_size * self.trainer.world_size]
-            gathered_targets = gathered_targets.reshape(total_length)[:total_length - pad_size * self.trainer.world_size]
-            errors = gathered_targets - gathered_preds
-            if errors.size == 0:
-                log.warning("Errors array is empty.")
-
-            # Create regression plot
-            regression_fig = self._create_regression_plot(
-                gathered_targets,
-                gathered_preds,
-                title=f'Validation Regression Plot - Epoch {self.current_epoch}'
-            )
-            # Generate error histogram plot
-            histogram_fig = self._create_error_histogram(
-                errors,
-                title=f'Validation Error Histogram - Epoch {self.current_epoch}'
-            )
-
-            self._log_plots('val', regression_fig, histogram_fig)
+        self._log_figures_for_each_phase(self.val_preds, self.val_targets, 'val')

     def _log_plots(self, phase: str, regression_fig, histogram_fig):
         """
Expand Down Expand Up @@ -730,7 +705,9 @@ def _create_regression_plot(self, targets:torch.Tensor, predictions:torch.Tensor
"""
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.scatter(targets.cpu().numpy(), predictions.cpu().numpy(), alpha=0.5)
targets = targets.cpu().numpy()
predictions = predictions.cpu().numpy()
ax.scatter(targets, predictions, alpha=0.5)
ax.plot([targets.min(), targets.max()], [targets.min(), targets.max()], 'r--')
ax.set_xlabel('True Values')
ax.set_ylabel('Predicted Values')
@@ -739,22 +716,48 @@ def _create_regression_plot(self, targets: torch.Tensor, predictions: torch.Tensor

     def _create_error_histogram(self, errors: torch.Tensor, title='Error Histogram'):
         import matplotlib.pyplot as plt
-        fig, ax = plt.subplots()
-        ax.hist(errors.cpu().numpy().flatten(), bins=50, alpha=0.75)
+        import numpy as np
+        errors_np = errors.cpu().numpy().flatten()
+
+        # Compute mean and standard deviation
+        mean_error = np.mean(errors_np)
+        std_error = np.std(errors_np)
+
+        fig, ax = plt.subplots(figsize=(8, 6))
+        bins = 50
+
+        # Plot histogram and get bin data
+        counts, bin_edges, patches = ax.hist(errors_np, bins=bins, alpha=0.75, edgecolor='black')
+
+        # Set y-axis to log scale
+        ax.set_yscale('log')
+
+        # Highlight outlier bins beyond 3 standard deviations
+        for count, edge_left, edge_right, patch in zip(counts, bin_edges[:-1], bin_edges[1:], patches):
+            if (edge_left < mean_error - 3 * std_error) or (edge_right > mean_error + 3 * std_error):
+                patch.set_facecolor('red')
+            else:
+                patch.set_facecolor('blue')
+
+        # Add vertical lines for mean and standard deviations
+        ax.axvline(mean_error, color='k', linestyle='dashed', linewidth=1, label='Mean')
+        ax.axvline(mean_error + 3 * std_error, color='r', linestyle='dashed', linewidth=1, label='±3 Std Dev')
+        ax.axvline(mean_error - 3 * std_error, color='r', linestyle='dashed', linewidth=1)
+
         ax.set_xlabel('Error')
-        ax.set_ylabel('Frequency')
+        ax.set_ylabel('Frequency (Log Scale)')
         ax.set_title(title)
+        ax.legend()

         return fig

-    def on_test_epoch_end(self):
-        """Logs metrics at the end of the test epoch."""
-        self._log_metrics(self.test_metrics, "test")
+    def _log_figures_for_each_phase(self, preds: Dict[int, torch.Tensor], target: Dict[int, torch.Tensor], phase: Literal['train', 'val', 'test']):
         # Gather across processes
-        gathered_preds, gathered_targets, max_length, pad_size = self._get_tensors(self.test_preds, self.test_targets)
+        gathered_preds, gathered_targets, max_length, pad_size = self._get_tensors(preds, target)
         # Clear the dictionaries
-        self.test_preds = {}
-        self.test_targets = {}
+        preds.clear()  # clear in place so the stored buffers are actually emptied
+        target.clear()

         # Proceed only on main process
         if self.global_rank == 0:
             # Remove padding
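A quick way to eyeball the new histogram with synthetic errors: since _create_error_histogram never reads self, it can be exercised unbound. TrainingAdapter is an assumed name for the class these methods live on:

    import torch
    from modelforge.train.training import TrainingAdapter  # assumed class name

    # Normal errors plus three planted outliers that should land in red bins.
    errors = torch.cat([torch.randn(2000), torch.tensor([6.0, -7.0, 9.0])])
    fig = TrainingAdapter._create_error_histogram(None, errors, title="Demo Error Histogram")
    fig.savefig("demo_error_histogram.png")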
@@ -769,21 +772,27 @@ def on_test_epoch_end(self):
             regression_fig = self._create_regression_plot(
                 gathered_targets,
                 gathered_preds,
-                title=f'Test Regression Plot - Epoch {self.current_epoch}'
+                title=f'{phase.capitalize()} Regression Plot - Epoch {self.current_epoch}'
             )
             # Generate error histogram plot
             histogram_fig = self._create_error_histogram(
                 errors,
-                title=f'Test Error Histogram - Epoch {self.current_epoch}'
+                title=f'{phase.capitalize()} Error Histogram - Epoch {self.current_epoch}'
             )

-            self._log_plots("test", regression_fig, histogram_fig)
+            self._log_plots(phase, regression_fig, histogram_fig)

+    def on_test_epoch_end(self):
+        """Logs metrics at the end of the test epoch."""
+        self._log_metrics(self.test_metrics, "test")
+        self._log_figures_for_each_phase(self.test_preds, self.test_targets, 'test')
+
     def on_train_epoch_end(self):
         """Logs metrics at the end of the training epoch."""
         self._log_metrics(self.loss_metrics, "loss")
         self._log_learning_rate()
         self._log_histograms()
+        self._log_figures_for_each_phase(self.train_preds, self.train_targets, 'train')

     def _log_learning_rate(self):
         """Logs the current learning rate."""
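A companion check for the phase-generic regression plot, using the same unbound-call trick with synthetic targets and mildly noisy predictions (names assumed as above):

    import torch
    from modelforge.train.training import TrainingAdapter  # assumed class name

    targets = torch.randn(500)
    predictions = targets + 0.1 * torch.randn(500)
    fig = TrainingAdapter._create_regression_plot(
        None, targets, predictions, title="Demo Regression Plot"
    )
    fig.savefig("demo_regression_plot.png")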
