Merge pull request #1 from HernandoR/dev/optim_model
Dev/optim model
HernandoR authored Jun 2, 2023
2 parents 4f9cfb8 + eb457cd commit fa7c19c
Showing 47 changed files with 9,772 additions and 1,407 deletions.
22 changes: 14 additions & 8 deletions .gitignore
@@ -100,18 +100,24 @@ ENV/
# mypy
.mypy_cache/


# input data, saved log, checkpoints
data/
wandb/
input/
output/
data/*
wandb/*
input/*
output/*
saved/
datasets/
datasets/*



model/*/*


# editor, os cache directory
.vscode/
.idea/
__MACOSX/
.vscode/*
.idea/*
__MACOSX/*

# personal config files
*.cfg
10 changes: 6 additions & 4 deletions README.md
@@ -22,6 +22,8 @@ The submittable (or newest) version is in the root folder / "[baseline.ipynb](dev

The following are the instructions for the template.


Data URL: https://gist.github.com/nat/e7266a5c765686b7976df10d3a85041b
------


@@ -79,7 +81,7 @@ Try `python train.py -c config.json` to run code.

Config files are in `.json` format:

```javascript
```json5
{
"name": "Mnist_LeNet", // training session name
"n_gpu": 1, // number of GPUs to use for training.
@@ -96,7 +98,7 @@ Config files are in `.json` format:
"data_dir": "data/", // dataset path
"batch_size": 64, // batch size
"shuffle": true, // shuffle training data before splitting
"validation_split": 0.1 // size of validation dataset. float(portion) or int(number of samples)
"validation_split": 0.1, // size of validation dataset. float(portion) or int(number of samples)
"num_workers": 2, // number of cpu processes to be used for data loading
}
},
@@ -125,8 +127,8 @@ Config files are in `.json` format:
"save_freq": 1, // save checkpoints every save_freq epochs
"verbosity": 2, // 0: quiet, 1: per epoch, 2: full

"monitor": "min val_loss" // mode and metric for model performance monitoring. set 'off' to disable.
"early_stop": 10 // number of epochs to wait before early stop. set 0 to disable.
"monitor": "min val_loss", // mode and metric for model performance monitoring. set 'off' to disable.
"early_stop": 10, // number of epochs to wait before early stop. set 0 to disable.

"tensorboard": true, // enable tensorboard visualization
}
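
As a quick illustration of how the `monitor` and `early_stop` keys above are consumed (a minimal sketch, not part of this commit; the dict is a pared-down copy of the excerpt with the comments removed, since plain JSON does not allow them, and it mirrors the parsing done in `base/base_trainer.py` below):

```python
# Minimal sketch of consuming the "monitor" / "early_stop" trainer keys.
config = {
    "trainer": {
        "epochs": 100,
        "monitor": "min val_loss",   # "<mode> <metric>", or "off" to disable
        "early_stop": 10,            # 0 disables early stopping
    }
}

monitor = config["trainer"].get("monitor", "off")
if monitor != "off":
    mnt_mode, mnt_metric = monitor.split()   # e.g. ("min", "val_loss")
    assert mnt_mode in ("min", "max")

early_stop = config["trainer"].get("early_stop", 0)
if early_stop <= 0:
    early_stop = float("inf")                # never stop early

print(mnt_mode, mnt_metric, early_stop)      # min val_loss 10
```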
3 changes: 3 additions & 0 deletions base/__init__.py
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
77 changes: 77 additions & 0 deletions base/base_data_loader.py
@@ -0,0 +1,77 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler

from logger import Loggers

Logger = Loggers.get_logger(__name__)


class BaseDataLoader(DataLoader):
"""
Base class for all data loaders
"""

def __init__(self, dataset, batch_size: int,
shuffle: bool, validation_split: int,
num_workers: int, collate_fn=default_collate, **kwargs):
self.validation_split = validation_split
self.shuffle = shuffle

self.batch_idx = 0
self.n_samples = len(dataset)

self.sampler, self.valid_sampler = self._split_sampler(
self.validation_split)

self.init_kwargs = {
'dataset': dataset,
'batch_size': batch_size,
'shuffle': self.shuffle,
'collate_fn': collate_fn,
'num_workers': num_workers
}
self.init_kwargs.update(kwargs)
super().__init__(sampler=self.sampler, **self.init_kwargs)

def _split_sampler(self, split):
if split == 0.0:
return None, None

idx_full = np.arange(self.n_samples)

np.random.seed(0)
np.random.shuffle(idx_full)

if isinstance(split, int):
assert split > 0
            assert split < self.n_samples, (
                "validation set size is configured to be larger than the entire dataset.")
            if split < 1.0:
                split = int(split * self.n_samples)
                Logger.info(
                    f"got a fractional validation split, converting to {split} samples")
len_valid = split
else:
len_valid = int(self.n_samples * split)

valid_idx = idx_full[0:len_valid]
train_idx = np.delete(idx_full, np.arange(0, len_valid))

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# turn off shuffle option which is mutually exclusive with sampler
self.shuffle = False
self.n_samples = len(train_idx)

return train_sampler, valid_sampler

def split_validation(self):
if self.valid_sampler is None:
return None
else:
return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
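
A minimal usage sketch (not part of this commit) showing how the train/validation split is typically exercised; the toy `TensorDataset` is an assumption, the import path follows the file layout in this diff, and running it requires the project's `logger` module to be importable:

```python
import torch
from torch.utils.data import TensorDataset

from base.base_data_loader import BaseDataLoader  # import path assumed from this diff

# Toy dataset: 100 samples with 10 features each and a binary label.
features = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(features, labels)

# Hold out 20% of the samples for validation; shuffle is turned off
# internally once the split sampler takes over.
train_loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                              validation_split=0.2, num_workers=0)
valid_loader = train_loader.split_validation()

print(len(train_loader.sampler), len(valid_loader.sampler))  # 80 20
```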


26 changes: 26 additions & 0 deletions base/base_model.py
@@ -0,0 +1,26 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
"""
Base class for all models
"""

@abstractmethod
def forward(self, *inputs):
"""
Forward pass logic
:return: Model output
"""
raise NotImplementedError

def __str__(self):
"""
Model prints with number of trainable parameters
"""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return super().__str__() + '\nTrainable parameters: {}'.format(params)
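
A minimal subclass sketch (not part of this commit) showing the parameter-count line that `__str__` appends; the model name and layer sizes are illustrative only and not taken from this repository:

```python
import torch
import torch.nn as nn

from base.base_model import BaseModel  # import path assumed from this diff


class TinyMnistModel(BaseModel):
    """Illustrative model; sizes chosen only to make the example concrete."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.net(x)


model = TinyMnistModel()
print(model)                                   # ends with "Trainable parameters: 101770"
print(model(torch.randn(4, 1, 28, 28)).shape)  # torch.Size([4, 10])
```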
168 changes: 168 additions & 0 deletions base/base_trainer.py
@@ -0,0 +1,168 @@
from abc import abstractmethod
from pathlib import Path

import torch
import wandb
from numpy import inf

from logger import Loggers


class BaseTrainer:
"""
Base class for all trainers
"""

def __init__(self, model, criterion, metric_ftns, optimizer, config):
self.config = config
self.logger = Loggers.get_logger('trainer')
self.model = model
self.model_id = config.model_id
self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer

cfg_trainer = config['trainer']
self.epochs = cfg_trainer['epochs']
self.save_period = cfg_trainer['save_period']
self.monitor = cfg_trainer.get('monitor', 'off')

# configuration to monitor model performance and save best
if self.monitor == 'off':
self.mnt_mode = 'off'
self.mnt_best = 0
else:
self.mnt_mode, self.mnt_metric = self.monitor.split()
self.mnt_mode = self.mnt_mode.lower()
assert self.mnt_mode in ['min', 'max']

self.mnt_best = inf if self.mnt_mode == 'min' else -inf
self.early_stop = cfg_trainer.get('early_stop', inf)
if self.early_stop <= 0:
self.early_stop = inf

self.start_epoch = 1

self.checkpoint_dir = Path(config['PATHS']['CP_DIR'])

if config.resume_path is not None:
self._resume_checkpoint(config.resume_path)

@abstractmethod
def _train_epoch(self, epoch):
"""
Training logic for an epoch
:param epoch: Current epoch number
"""
raise NotImplementedError

def train(self):
"""
Full training logic
"""

not_improved_count = 0
for epoch in range(self.start_epoch, self.epochs + 1):
# train epoch
# return metrics that may or may not be logged
# TODO find how to config the mnt_metric
result = self._train_epoch(epoch)

# save logged information into log dict
log = {'epoch': epoch}
log.update(result)

# print logged information to the screen
for key, value in log.items():
# self.logger.info(' {:15s}: {}'.format(str(key), value))
self.logger.info(f" {key:15s}: {value}")

# evaluate model performance according to configured metric, save the best checkpoint as model_best
best = False
if self.mnt_mode != 'off':
try:
# check whether model performance improved or not, according to specified metric(mnt_metric)
improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
(self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
except KeyError:
self.logger.warning("Warning: Metric '{}' is not found. "
"Model performance monitoring is disabled.".format(self.mnt_metric))
self.mnt_mode = 'off'
improved = False

if improved:
self.mnt_best = log[self.mnt_metric]
not_improved_count = 0
best = True
else:
not_improved_count += 1
self.logger.info("Early stop count: {}".format(not_improved_count))

if not_improved_count > self.early_stop:
self.logger.info("Validation performance didn\'t improve for {} epochs. "
"Training stops.".format(self.early_stop))
wandb.run.summary["early_stop"] = True

break

if epoch % self.save_period == 0:
self._save_checkpoint(epoch, save_best=best)

wandb.save(str(self.checkpoint_dir / f'{self.model_id}_best.pth'))
wandb.finish()

def _save_checkpoint(self, epoch, save_best=False):
"""
Saving checkpoints
:param epoch: current epoch number
:param log: logging information of the epoch
:param save_best: if True, rename the saved checkpoint to 'model_best.pth'
"""
arch = type(self.model).__name__
state = {
'arch': arch,
'epoch': epoch,
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'monitor_best': self.mnt_best,
# 'config': self.config
'config': {
k: v for k, v in self.config.items() if k in ['model', 'optimizer', 'trainer']
}
}
filename = str(self.checkpoint_dir / f'{self.model_id}_checkpoints.pth')
torch.save(state, filename)
self.logger.info("Saving checkpoint: {} ...".format(filename))
if save_best:
best_path = str(self.checkpoint_dir / f'{self.model_id}_best.pth')
torch.save(state, best_path)
self.logger.info("Saving current best: model_best.pth ...")

def _resume_checkpoint(self, resume_path):
"""
Resume from saved checkpoints
:param resume_path: Checkpoint path to be resumed
"""
resume_path = str(resume_path)
self.logger.info("Loading checkpoint: {} ...".format(resume_path))
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint['epoch'] + 1
self.mnt_best = checkpoint['monitor_best']

# load architecture params from checkpoint.
if checkpoint['config']['model'] != self.config['arch']:
self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
"checkpoint. This may yield an exception while state_dict is being loaded.")
self.model.load_state_dict(checkpoint['model'])

# load optimizer state from checkpoint only when optimizer type is not changed.
if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
"Optimizer parameters not being resumed.")
else:
self.optimizer.load_state_dict(checkpoint['optimizer'])

self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))