Commit: add early_stopping
zdaiot committed Oct 9, 2019
1 parent 4e9f025 commit 07de91a
Showing 3 changed files with 71 additions and 6 deletions.
train_classify.py (16 changes: 11 additions & 5 deletions)
@@ -1,21 +1,21 @@
from torch import optim
import torch
import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
import os
import codecs, json
import time
import pickle
import random

from config import get_classify_config
from solver import Solver
from models.model import ClassifyResNet
from utils.loss import ClassifyLoss
from datasets.steel_dataset import classify_provider
from utils.cal_classify_accuracy import Meter
from utils.set_seed import seed_torch
from utils.easy_stopping import EarlyStopping


class TrainVal():
@@ -67,6 +67,7 @@ def train(self, train_loader, valid_loader):
        optimizer = optim.Adam(self.model.module.parameters(), self.lr, weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epoch + 10)
        global_step = 0
        es = EarlyStopping(mode='min', patience=10)

        for epoch in range(self.epoch):
            epoch += 1
@@ -94,7 +95,12 @@
            global_step += len(train_loader)

            # Print the log info
            average_loss = epoch_loss / len(tbar)
            print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, average_loss))

            # Early stopping
            if es.step(average_loss):
                break

            # Validate the model
            class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
train_segment.py (10 changes: 9 additions & 1 deletion)
@@ -15,6 +15,7 @@
from config import get_seg_config
from solver import Solver
from utils.loss import MultiClassesSoftBCEDiceLoss
from utils.easy_stopping import EarlyStopping


class TrainVal():
@@ -75,6 +76,8 @@ def train(self, train_loader, valid_loader):
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epoch + 10)
        global_step = 0

        es = EarlyStopping(mode='min', patience=10)

        for epoch in range(self.epoch):
            epoch += 1
            epoch_loss = 0
@@ -105,7 +108,12 @@
            global_step += len(train_loader)

            # Print the log info
            average_loss = epoch_loss / len(tbar)
            print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, average_loss))

            # Early stopping
            if es.step(average_loss):
                break

            # Validate the model
            loss_valid, dice_valid, iou_valid = self.validation(valid_loader)
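Both training scripts wire the stopper into the epoch loop in the same way: compute the epoch's average loss, hand it to es.step(), and break out of the loop when it returns True. A minimal standalone sketch of that pattern (the loss values here are synthetic stand-ins for a real per-epoch training loss):

from utils.easy_stopping import EarlyStopping

num_epochs = 30
es = EarlyStopping(mode='min', patience=10)

for epoch in range(1, num_epochs + 1):
    # Synthetic loss: improves for four epochs, then plateaus
    average_loss = 1.0 / epoch if epoch < 5 else 0.25
    print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, num_epochs, average_loss))
    if es.step(average_loss):  # True once `patience` epochs pass without improvement
        print('Early stopping at epoch %d' % epoch)
        break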
utils/easy_stopping.py (51 changes: 51 additions & 0 deletions)
@@ -0,0 +1,51 @@
import numpy as np


class EarlyStopping(object):
    """Stop training once a monitored metric has stopped improving.

    mode: 'min' if lower is better (e.g. loss), 'max' if higher is better.
    min_delta: smallest change that counts as an improvement.
    patience: number of non-improving epochs tolerated before stopping.
    percentage: if True, interpret min_delta as a percentage of the best value.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # A patience of 0 disables early stopping entirely
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one epoch's metric; return True when training should stop."""
        if self.best is None:
            # First call: just remember the initial value
            self.best = metrics
            return False

        if np.isnan(metrics):
            # A NaN metric means training has diverged; stop immediately
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # Here min_delta is a percentage of the current best value
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (best * min_delta / 100)
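A quick sanity check of the semantics above (assuming utils/easy_stopping.py is on the import path): a plateauing loss trips the stopper after exactly patience non-improving epochs, and percentage mode measures min_delta relative to the current best value.

from utils.easy_stopping import EarlyStopping

es = EarlyStopping(mode='min', min_delta=0, patience=3)
for i, loss in enumerate([1.0, 0.8, 0.8, 0.8, 0.8]):
    print('epoch %d: loss=%.2f stop=%s' % (i, loss, es.step(loss)))
# Prints stop=False four times, then stop=True on the third epoch without improvement.

# In percentage mode, an improvement must beat the best value by min_delta percent:
es_pct = EarlyStopping(mode='min', min_delta=5, percentage=True, patience=3)
es_pct.step(1.00)         # first call only records best = 1.00
print(es_pct.step(0.96))  # False (no stop yet), but 0.96 is not < 0.95, so it counts as a bad epoch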
