From 07de91aa24ac336590170bea48bcc4847364fa4e Mon Sep 17 00:00:00 2001 From: zdaiot Date: Wed, 9 Oct 2019 23:28:50 +0800 Subject: [PATCH] add early_stopping --- train_classify.py | 16 ++++++++----- train_segment.py | 10 ++++++++- utils/easy_stopping.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 utils/easy_stopping.py diff --git a/train_classify.py b/train_classify.py index 63cc3a8..2196e94 100755 --- a/train_classify.py +++ b/train_classify.py @@ -1,21 +1,21 @@ from torch import optim import torch import tqdm -from config import get_classify_config -from solver import Solver from torch.utils.tensorboard import SummaryWriter import datetime import os import codecs, json import time +import pickle +from config import get_classify_config +from solver import Solver from models.model import ClassifyResNet from utils.loss import ClassifyLoss from datasets.steel_dataset import classify_provider from utils.cal_classify_accuracy import Meter from utils.set_seed import seed_torch -import pickle -import random +from utils.easy_stopping import EarlyStopping class TrainVal(): @@ -67,6 +67,7 @@ def train(self, train_loader, valid_loader): optimizer = optim.Adam(self.model.module.parameters(), self.lr, weight_decay=self.weight_decay) lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epoch+10) global_step = 0 + es = EarlyStopping(mode='min', patience=10) for epoch in range(self.epoch): epoch += 1 @@ -94,7 +95,12 @@ def train(self, train_loader, valid_loader): global_step += len(train_loader) # Print the log info - print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, epoch_loss/len(tbar))) + average_loss = epoch_loss / len(tbar) + print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, average_loss)) + + # 提前终止 + if es.step(average_loss): + break # 验证模型 class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = 
class EarlyStopping(object):
    """Stop training when a monitored metric has stopped improving.

    Call :meth:`step` once per epoch with the metric value; it returns
    True when training should be terminated.

    Parameters
    ----------
    mode : str
        'min' if a lower metric is better (e.g. loss), 'max' if a
        higher metric is better (e.g. accuracy).
    min_delta : float
        Minimum change that counts as an improvement.  An absolute
        difference, or a percentage of the current best when
        ``percentage`` is True.
    patience : int
        Number of consecutive non-improving epochs tolerated before
        ``step`` returns True.  A patience of 0 disables early
        stopping entirely (``step`` always returns False).
    percentage : bool
        Interpret ``min_delta`` as a percentage of the best value.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: everything counts as an
            # improvement and step() never requests a stop (this
            # instance attribute shadows the bound method).
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one epoch's metric value.

        Returns
        -------
        bool
            True when training should stop: immediately on a NaN
            metric, or after ``patience`` consecutive epochs without
            improvement over the best value seen so far.
        """
        # BUGFIX: check NaN *before* the first-epoch branch.  Previously
        # a NaN on the very first call was stored as ``best``; every
        # later comparison against NaN is False, so the NaN guard never
        # fired and the intended immediate stop was lost.
        if np.isnan(metrics):
            return True

        if self.best is None:
            # First observed value becomes the baseline.
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Build the comparison predicate ``is_better(candidate, best)``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                    best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                    best * min_delta / 100)