diff --git a/callbacks/__init__.py b/callbacks/__init__.py
index 086a901..81213aa 100644
--- a/callbacks/__init__.py
+++ b/callbacks/__init__.py
@@ -1,3 +1,4 @@
 from callbacks.csv_logging import CSVLogging
 from callbacks.epoch_results_logging import EpochResultsLogging
-from callbacks.early_stopping import EarlyStopping
\ No newline at end of file
+from callbacks.early_stopping import EarlyStopping
+from callbacks.checkpoint import Checkpoint
\ No newline at end of file
diff --git a/callbacks/checkpoint.py b/callbacks/checkpoint.py
new file mode 100644
index 0000000..0ca0ff7
--- /dev/null
+++ b/callbacks/checkpoint.py
@@ -0,0 +1,60 @@
+from callbacks.callback import Callback
+import os
+import torch
+
+class Checkpoint(Callback):
+    """
+    Callback class for saving model checkpoints during training.
+
+    Args:
+        checkpoint_dir (str): Directory to save the checkpoints.
+        model (torch.nn.Module): The model to be saved.
+        optimizer (torch.optim.Optimizer): The optimizer to be saved.
+        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to be saved. Default is None.
+        save_freq (int, optional): Frequency of saving checkpoints. Default is 1.
+        verbose (bool, optional): Whether to print the checkpoint save path. Default is False.
+    """
+
+    def __init__(self, checkpoint_dir, model, optimizer, scheduler=None, save_freq=1, verbose=False):
+        self.checkpoint_dir = checkpoint_dir
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.save_freq = save_freq
+        self.verbose = verbose
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+    def on_epoch_end(self, epoch, logs=None):
+        """
+        Callback function called at the end of each epoch.
+
+        Args:
+            epoch (int): The current epoch number.
+            logs (dict, optional): Dictionary containing training and validation losses. Default is None.
+        """
+        if (epoch + 1) % self.save_freq == 0:
+            self.save_checkpoint(epoch, logs)
+
+    def save_checkpoint(self, epoch, logs=None):
+        """
+        Save the model checkpoint.
+
+        Args:
+            epoch (int): The current epoch number.
+            logs (dict, optional): Dictionary containing training and validation losses. Default is None.
+ """ + state = { + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': logs.get('training_losses', []), + 'val_losses': logs.get('validation_losses', []), + } + if self.scheduler: + state['scheduler_state_dict'] = self.scheduler.state_dict() + + save_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') + torch.save(state, save_path) + if self.verbose: + print(f"Checkpoint saved at {save_path}") diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py index 12044d2..c8d8ea4 100644 --- a/tests/test_checkpoints.py +++ b/tests/test_checkpoints.py @@ -1,19 +1,20 @@ import pytest import os +import yaml +import torch from trainers import get_trainer from utils.metrics import Accuracy, Precision, Recall, F1Score from datasets.transformations import get_transforms from datasets.dataset import get_dataset from models import get_model -import torch -import yaml +from callbacks import Checkpoint CONFIG_TEST = {} with open("./config/config_test.yaml", 'r') as file: CONFIG_TEST = yaml.safe_load(file) -def test_checkpoint_functionality(): +def test_checkpoint(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") transforms = get_transforms(CONFIG_TEST) @@ -34,42 +35,48 @@ def test_checkpoint_functionality(): model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device) criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.Adam - optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']} + optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG_TEST['training']['learning_rate']) metrics = [Accuracy(), Precision(), Recall(), F1Score()] trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device) checkpoint_dir = "./outputs/checkpoints/" os.makedirs(checkpoint_dir, exist_ok=True) - + checkpoint_callback = Checkpoint( + checkpoint_dir=checkpoint_dir, + model=model, + optimizer=optimizer, + save_freq=5, + verbose=False + ) + trainer.build( criterion=criterion, - optimizer_class=optimizer, - optimizer_params=optimizer_params, + optimizer_class=torch.optim.Adam, + optimizer_params={'lr': CONFIG_TEST['training']['learning_rate']}, metrics=metrics ) + + # Train the model and automatically save the checkpoint at the specified interval trainer.train( train_loader=train_loader, num_epochs=6, - checkpoint_dir=checkpoint_dir, + valid_loader=None, + callbacks=[checkpoint_callback] ) checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth') - assert os.path.exists(checkpoint_dir), "Checkpoint file was not created." + assert os.path.exists(checkpoint_path), "Checkpoint file was not created." + # Zero out the model parameters to simulate a restart for param in model.parameters(): param.data.zero_() + # Load the checkpoint trainer.load_checkpoint(checkpoint_path) - trainer.train( - train_loader=train_loader, - num_epochs=2, - checkpoint_dir=checkpoint_dir, - ) - + # Continue training or perform evaluation _, metrics_results = trainer.evaluate(test_loader, verbose=False) assert all([v >= 0 for v in metrics_results.values()]), "Metrics after resuming are not valid." 
 
-test_checkpoint_functionality()
+test_checkpoint()
diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py
index c55a2a9..51b3e07 100644
--- a/trainers/base_trainer.py
+++ b/trainers/base_trainer.py
@@ -1,6 +1,4 @@
-import os
 import torch
-from utils.plotting import plot_loss
 from abc import ABC, abstractmethod
 import time
 from typing import Tuple
@@ -33,31 +31,8 @@ def __init__(self, model, device):
         self.optimizer = None
         self.scheduler = None
         self.metrics = []
-
-    def save_checkpoint(self, save_path, epoch, train_losses, val_losses, logs) -> None:
-        """
-        Saves the current state of the training process.
-        Args:
-            save_path (str): Directory to save checkpoint files.
-            epoch (int): Current epoch number.
-            train_losses (list): List of training losses up to the current epoch.
-            val_losses (list): List of validation losses up to the current epoch.
-            logs (dict): Dictionary containing other metric values.
-        """
-        state = {
-            'epoch': epoch,
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'train_losses': train_losses,
-            'val_losses': val_losses,
-            'logs': logs
-        }
-        if self.scheduler:
-            state['scheduler_state_dict'] = self.scheduler.state_dict()
-
-        torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth'))
-    def load_checkpoint(self, load_path):
+    def load_checkpoint(self, load_path) -> dict:
         """
         Loads a checkpoint and resumes training or evaluation.
         Args:
@@ -110,7 +85,7 @@ def _train_epoch(self, train_loader, epoch, num_epochs) -> float:
         raise NotImplementedError(
             "The train_epoch method must be implemented by the subclass.")
 
-    def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
+    def train(self, train_loader, num_epochs, valid_loader=None, callbacks=None) -> None:
         """
         Train the model for a given number of epochs, calculating metrics at the end
         of each epoch for both training and validation sets.
@@ -119,8 +94,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, che
            train_loader: The data loader for the training set.
            num_epochs (int): The number of epochs to train the model.
            valid_loader: The data loader for the validation set (optional).
-            plot_path: The path to save the training plot (optional).
-            checkpoint_dir: The directory to save model checkpoints (optional).
            callbacks: List of callback objects to use during training (optional).
""" logs = {} @@ -139,6 +112,9 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, che for epoch in range(num_epochs): epoch_start_time = time.time() + for callback in callbacks: + callback.on_epoch_begin(epoch, logs=logs) + logs['epoch'] = epoch epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs) training_epoch_losses.append(epoch_loss_train) @@ -157,15 +133,11 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, che logs['val_metrics'] = {} for callback in callbacks: - callback.on_epoch_begin(epoch, logs=logs) callback.on_epoch_end(epoch, logs=logs) epoch_time = time.time() - epoch_start_time times.append(epoch_time) - if checkpoint_dir and (epoch + 1) % 5 == 0: - self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, logs) - if not all(callback.should_continue(logs=logs) for callback in callbacks): print(f"Training stopped early at epoch {epoch + 1}.") break @@ -178,9 +150,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, che elapsed_time = time.time() - start_time print(f"Training completed in: {elapsed_time:.2f} seconds") - if plot_path is not None: - plot_loss(training_epoch_losses, validation_epoch_losses, plot_path) - def predict(self, instance) -> torch.Tensor: """ Predict the output of the model for a given instance.