train.py
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
from time import time
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder
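# Several architectures are imported above so the backbone can be swapped when
# constructing the solver below (e.g. Unet, LinkNet34 or DinkNet101 instead of DinkNet34).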
SHAPE = (1024, 1024)
ROOT = 'dataset/train/'
# Wrap in list() so the results can be indexed and reused under Python 3.
imagelist = list(filter(lambda x: x.find('sat') != -1, os.listdir(ROOT)))
trainlist = list(map(lambda x: x[:-8], imagelist))
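# Expected layout (inferred from the listing above): dataset/train/ holds image
# files whose names contain 'sat' (e.g. '<id>_sat.jpg'); stripping the last
# 8 characters yields the sample id that ImageFolder (data.py) uses to pair
# each image with its mask.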
NAME = 'log01_dink34'
BATCHSIZE_PER_CARD = 4

solver = MyFrame(DinkNet34, dice_bce_loss, 2e-4)
# Global batch size scales with the number of visible GPUs (assumes >= 1 GPU).
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, ROOT)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=4)
mylog = open('logs/' + NAME + '.log', 'w')
tic = time()
no_optim = 0                  # epochs since the training loss last improved
total_epoch = 300
train_epoch_best_loss = 100.  # sentinel: any real epoch loss will be lower
for epoch in range(1, total_epoch + 1):
    data_loader_iter = iter(data_loader)
    train_epoch_loss = 0
    for img, mask in data_loader_iter:
        solver.set_input(img, mask)
        train_loss = solver.optimize()
        train_epoch_loss += train_loss
    train_epoch_loss /= len(data_loader_iter)

    # Report epoch statistics to both the log file and stdout.
    print('********', file=mylog)
    print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
    print('train_loss:', train_epoch_loss, file=mylog)
    print('SHAPE:', SHAPE, file=mylog)
    print('********')
    print('epoch:', epoch, '    time:', int(time() - tic))
    print('train_loss:', train_epoch_loss)
    print('SHAPE:', SHAPE)

    # Save weights whenever the epoch loss improves; otherwise count the stall.
    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
    else:
        no_optim = 0
        train_epoch_best_loss = train_epoch_loss
        solver.save('weights/' + NAME + '.th')

    # Stop early after more than 6 epochs without improvement.
    if no_optim > 6:
        print('early stop at %d epoch' % epoch, file=mylog)
        print('early stop at %d epoch' % epoch)
        break

    # After more than 3 stalled epochs, reload the best weights and divide the
    # learning rate by 5; give up once it has decayed below 5e-7.
    if no_optim > 3:
        if solver.old_lr < 5e-7:
            break
        solver.load('weights/' + NAME + '.th')
        solver.update_lr(5.0, factor=True, mylog=mylog)
    mylog.flush()

print('Finish!', file=mylog)
print('Finish!')
mylog.close()
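
# --------------------------------------------------------------------------
# Illustration only: a minimal sketch of the solver interface used above,
# i.e. the methods and attributes the training loop relies on (set_input,
# optimize, save, load, update_lr, old_lr). The real implementation is the
# MyFrame class in framework.py; the internals below (Adam optimizer,
# CUDA-or-CPU device handling, a loss callable invoked as loss_fn(mask, pred))
# are assumptions for readability, not the repository's code.
# (Reuses the `torch` import at the top of this file.)
# --------------------------------------------------------------------------
class TrainerSketch:
    def __init__(self, net_class, loss_fn, lr):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = net_class().to(self.device)
        self.loss_fn = loss_fn
        self.old_lr = lr
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

    def set_input(self, img, mask):
        self.img = img.to(self.device)
        self.mask = mask.to(self.device)

    def optimize(self):
        # One gradient step on the current batch; returns the loss as a float.
        self.optimizer.zero_grad()
        pred = self.net(self.img)
        loss = self.loss_fn(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save(self, path):
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, rate, factor=False, mylog=None):
        # factor=True divides the current learning rate by `rate`.
        new_lr = self.old_lr / rate if factor else rate
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        if mylog is not None:
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr), file=mylog)
        self.old_lr = new_lr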