Checkpoint refactored to callback advances #23
pab1s committed Apr 22, 2024
1 parent e52deb0 commit b585c94
Showing 4 changed files with 91 additions and 54 deletions.
callbacks/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,3 +1,4 @@
from callbacks.csv_logging import CSVLogging
from callbacks.epoch_results_logging import EpochResultsLogging
from callbacks.early_stopping import EarlyStopping
from callbacks.early_stopping import EarlyStopping
from callbacks.checkpoint import Checkpoint
callbacks/checkpoint.py (60 changes: 60 additions & 0 deletions)
@@ -0,0 +1,60 @@
from callbacks.callback import Callback
import os
import torch

class Checkpoint(Callback):
    """
    Callback class for saving model checkpoints during training.

    Args:
        checkpoint_dir (str): Directory to save the checkpoints.
        model (torch.nn.Module): The model to be saved.
        optimizer (torch.optim.Optimizer): The optimizer to be saved.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to be saved. Default is None.
        save_freq (int, optional): Frequency of saving checkpoints. Default is 1.
        verbose (bool, optional): Whether to print the checkpoint save path. Default is False.
    """

    def __init__(self, checkpoint_dir, model, optimizer, scheduler=None, save_freq=1, verbose=False):
        self.checkpoint_dir = checkpoint_dir
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_freq = save_freq
        self.verbose = verbose
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """
        Callback function called at the end of each epoch.

        Args:
            epoch (int): The current epoch number.
            logs (dict, optional): Dictionary containing training and validation losses. Default is None.
        """
        if (epoch + 1) % self.save_freq == 0:
            self.save_checkpoint(epoch, logs)

    def save_checkpoint(self, epoch, logs=None):
        """
        Save the model checkpoint.

        Args:
            epoch (int): The current epoch number.
            logs (dict, optional): Dictionary containing training and validation losses. Default is None.
        """
        state = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': logs.get('training_losses', []),
            'val_losses': logs.get('validation_losses', []),
        }
        if self.scheduler:
            state['scheduler_state_dict'] = self.scheduler.state_dict()

        save_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(state, save_path)
        if self.verbose:
            print(f"Checkpoint saved at {save_path}")
tests/test_checkpoints.py (41 changes: 24 additions & 17 deletions)
@@ -1,19 +1,20 @@
import pytest
import os
import yaml
import torch
from trainers import get_trainer
from utils.metrics import Accuracy, Precision, Recall, F1Score
from datasets.transformations import get_transforms
from datasets.dataset import get_dataset
from models import get_model
import torch
import yaml
from callbacks import Checkpoint

CONFIG_TEST = {}

with open("./config/config_test.yaml", 'r') as file:
    CONFIG_TEST = yaml.safe_load(file)

def test_checkpoint_functionality():
def test_checkpoint():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transforms = get_transforms(CONFIG_TEST)
@@ -34,42 +35,48 @@ def test_checkpoint_functionality():
    model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam
    optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']}
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG_TEST['training']['learning_rate'])
    metrics = [Accuracy(), Precision(), Recall(), F1Score()]

    trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device)

    checkpoint_dir = "./outputs/checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_callback = Checkpoint(
        checkpoint_dir=checkpoint_dir,
        model=model,
        optimizer=optimizer,
        save_freq=5,
        verbose=False
    )

    trainer.build(
        criterion=criterion,
        optimizer_class=optimizer,
        optimizer_params=optimizer_params,
        optimizer_class=torch.optim.Adam,
        optimizer_params={'lr': CONFIG_TEST['training']['learning_rate']},
        metrics=metrics
    )

    # Train the model and automatically save the checkpoint at the specified interval
    trainer.train(
        train_loader=train_loader,
        num_epochs=6,
        checkpoint_dir=checkpoint_dir,
        valid_loader=None,
        callbacks=[checkpoint_callback]
    )

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth')
    assert os.path.exists(checkpoint_dir), "Checkpoint file was not created."
    assert os.path.exists(checkpoint_path), "Checkpoint file was not created."

    # Zero out the model parameters to simulate a restart
    for param in model.parameters():
        param.data.zero_()

    # Load the checkpoint
    trainer.load_checkpoint(checkpoint_path)

    trainer.train(
        train_loader=train_loader,
        num_epochs=2,
        checkpoint_dir=checkpoint_dir,
    )

    # Continue training or perform evaluation
    _, metrics_results = trainer.evaluate(test_loader, verbose=False)
    assert all([v >= 0 for v in metrics_results.values()]), "Metrics after resuming are not valid."

test_checkpoint_functionality()
test_checkpoint()
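Note (illustrative, not part of this commit): trainer.load_checkpoint is called above but its body is outside this diff. Based only on the keys that Checkpoint.save_checkpoint writes, restoring one of these files by hand could look like the sketch below; the placeholder model and optimizer are assumptions and must match the architecture that produced the checkpoint. With save_freq=5 and num_epochs=6, the first run writes exactly one file, checkpoint_epoch_5.pth, which is why the test asserts on that path.

import torch
import torch.nn as nn

# Placeholders; in the test these are the real model and optimizer, and the
# architecture has to match the one that was saved.
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Key names mirror Checkpoint.save_checkpoint above; this is a sketch, not the
# repository's trainer.load_checkpoint implementation.
checkpoint = torch.load("./outputs/checkpoints/checkpoint_epoch_5.pth", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1      # resume with the following epoch
train_losses = checkpoint['train_losses']  # losses recorded up to the save point
val_losses = checkpoint['val_losses']
# If a scheduler was passed to Checkpoint, 'scheduler_state_dict' is also present.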
trainers/base_trainer.py (41 changes: 5 additions & 36 deletions)
@@ -1,6 +1,4 @@
import os
import torch
from utils.plotting import plot_loss
from abc import ABC, abstractmethod
import time
from typing import Tuple
@@ -33,31 +31,8 @@ def __init__(self, model, device):
        self.optimizer = None
        self.scheduler = None
        self.metrics = []

    def save_checkpoint(self, save_path, epoch, train_losses, val_losses, logs) -> None:
        """
        Saves the current state of the training process.
        Args:
            save_path (str): Directory to save checkpoint files.
            epoch (int): Current epoch number.
            train_losses (list): List of training losses up to the current epoch.
            val_losses (list): List of validation losses up to the current epoch.
            logs (dict): Dictionary containing other metric values.
        """
        state = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'logs': logs
        }
        if self.scheduler:
            state['scheduler_state_dict'] = self.scheduler.state_dict()

        torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth'))

    def load_checkpoint(self, load_path):
    def load_checkpoint(self, load_path) -> dict:
        """
        Loads a checkpoint and resumes training or evaluation.
        Args:
@@ -110,7 +85,7 @@ def _train_epoch(self, train_loader, epoch, num_epochs) -> float:
        raise NotImplementedError(
            "The train_epoch method must be implemented by the subclass.")

    def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
    def train(self, train_loader, num_epochs, valid_loader=None, callbacks=None) -> None:
        """
        Train the model for a given number of epochs, calculating metrics at the end of each epoch
        for both training and validation sets.
@@ -119,8 +94,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
            train_loader: The data loader for the training set.
            num_epochs (int): The number of epochs to train the model.
            valid_loader: The data loader for the validation set (optional).
            plot_path: The path to save the training plot (optional).
            checkpoint_dir: The directory to save model checkpoints (optional).
            callbacks: List of callback objects to use during training (optional).
        """
        logs = {}
@@ -139,6 +112,9 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
        for epoch in range(num_epochs):
            epoch_start_time = time.time()

            for callback in callbacks:
                callback.on_epoch_begin(epoch, logs=logs)

            logs['epoch'] = epoch
            epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs)
            training_epoch_losses.append(epoch_loss_train)
@@ -157,15 +133,11 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
                logs['val_metrics'] = {}

            for callback in callbacks:
                callback.on_epoch_begin(epoch, logs=logs)
                callback.on_epoch_end(epoch, logs=logs)

            epoch_time = time.time() - epoch_start_time
            times.append(epoch_time)

            if checkpoint_dir and (epoch + 1) % 5 == 0:
                self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, logs)

            if not all(callback.should_continue(logs=logs) for callback in callbacks):
                print(f"Training stopped early at epoch {epoch + 1}.")
                break
@@ -178,9 +150,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
        elapsed_time = time.time() - start_time
        print(f"Training completed in: {elapsed_time:.2f} seconds")

        if plot_path is not None:
            plot_loss(training_epoch_losses, validation_epoch_losses, plot_path)

    def predict(self, instance) -> torch.Tensor:
        """
        Predict the output of the model for a given instance.
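Interface note (illustrative, not part of this commit): after this refactor, train() delegates all checkpointing to callbacks and calls three hooks per epoch: on_epoch_begin before _train_epoch, on_epoch_end after the epoch's losses and metrics are collected, and should_continue to decide whether to stop early. callbacks/callback.py is not shown in this diff; a minimal sketch of the base class those calls assume, with no-op defaults, might look like this:

class Callback:
    """Sketch of the hook interface the refactored train() relies on.
    The repository's callbacks/callback.py may differ in detail."""

    def on_epoch_begin(self, epoch, logs=None):
        # Called at the start of each epoch, before _train_epoch.
        pass

    def on_epoch_end(self, epoch, logs=None):
        # Called after the epoch finishes; Checkpoint saves its file here.
        pass

    def should_continue(self, logs=None):
        # If any callback returns False, training stops early.
        return True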