Merge pull request #1 from HernandoR/dev/optim_model
Dev/optim model
Showing 47 changed files with 9,772 additions and 1,407 deletions.
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
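
These wildcard re-exports let downstream code pull all three base classes from the package in one import. A usage sketch (the package name `base` is an assumption inferred from the relative imports, it is not shown in the diff):

from base import BaseDataLoader, BaseModel, BaseTrainer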
@@ -0,0 +1,77 @@
from typing import Union

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler

from logger import Loggers

Logger = Loggers.get_logger(__name__)


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """

    def __init__(self, dataset, batch_size: int,
                 shuffle: bool, validation_split: Union[int, float],
                 num_workers: int, collate_fn=default_collate, **kwargs):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(
            self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        self.init_kwargs.update(kwargs)
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # fixed seed so the train/validation split is reproducible across runs
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            # an int split is an absolute number of validation samples
            assert split > 0
            assert split < self.n_samples, \
                "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            # a float split is a fraction of the dataset
            len_valid = int(self.n_samples * split)
            Logger.info(
                f"got a fraction for validation split, converted to {len_valid} samples")

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
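
For context, a minimal subclass sketch showing how the split machinery is meant to be used. Everything here (ToyDataLoader, the random tensors) is illustrative, not part of the diff:

import torch
from torch.utils.data import TensorDataset

class ToyDataLoader(BaseDataLoader):
    # hypothetical loader over an in-memory dataset; 10% held out for validation
    def __init__(self, batch_size=32, shuffle=True, validation_split=0.1, num_workers=0):
        dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))
        super().__init__(dataset, batch_size, shuffle, validation_split, num_workers)

train_loader = ToyDataLoader()
valid_loader = train_loader.split_validation()  # DataLoader over the held-out 10%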
@@ -0,0 +1,26 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic
        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
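
A subclass only needs to implement forward; printing the model then reports its trainable parameter count via the __str__ override above. A sketch (ToyModel and its layer sizes are illustrative assumptions):

import torch.nn as nn

class ToyModel(BaseModel):
    def __init__(self, in_dim=8, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 16), nn.ReLU(), nn.Linear(16, n_classes))

    def forward(self, x):
        return self.net(x)

print(ToyModel())  # ends with "Trainable parameters: 178"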
@@ -0,0 +1,168 @@
from abc import abstractmethod
from pathlib import Path

import torch
import wandb
from numpy import inf

from logger import Loggers


class BaseTrainer:
    """
    Base class for all trainers
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = Loggers.get_logger('trainer')
        self.model = model
        self.model_id = config.model_id
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            self.mnt_mode = self.mnt_mode.lower()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = Path(config['PATHS']['CP_DIR'])

        if config.resume_path is not None:
            self._resume_checkpoint(config.resume_path)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Current epoch number
        """
        raise NotImplementedError
    def train(self):
        """
        Full training logic
        """

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            # train epoch
            # return metrics that may or may not be logged
            # TODO find how to config the mnt_metric
            result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info(f" {key:15s}: {value}")

            # evaluate model performance according to configured metric,
            # save the best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not,
                    # according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                    self.logger.info("Early stop count: {}".format(not_improved_count))

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    wandb.run.summary["early_stop"] = True
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

        wandb.save(str(self.checkpoint_dir / f'{self.model_id}_best.pth'))
        wandb.finish()
    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, also save a copy of the checkpoint as '<model_id>_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # save only the config sections relevant for resuming, not the full config
            'config': {
                k: v for k, v in self.config.items() if k in ['model', 'optimizer', 'trainer']
            }
        }
        filename = str(self.checkpoint_dir / f'{self.model_id}_checkpoints.pth')
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / f'{self.model_id}_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: {} ...".format(best_path))
    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint; the saved config stores the
        # 'model' section (see _save_checkpoint), so compare against that key.
        if checkpoint['config']['model'] != self.config['model']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['model'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))