import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision import transforms

import lpips

from Modeling.DerainDataset import *
from Modeling.utils import *
from Modeling.SSIM import SSIM
from Modeling.network import *
from Modeling.fpn import *
from option import *
from loss_fun import *

# Select GPUs before any CUDA context is created, then pick the device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # adjust to the GPUs you are using
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_ids = list(range(torch.cuda.device_count()))

# LPIPS perceptual metric; net can be 'alex', 'vgg', or 'squeeze'
loss_fn_lpips = lpips.LPIPS(net='alex').to(device)
def SegLoss(out_train, device, device_ids):
    """Segmentation-aware loss: score the derained image with a frozen FPN
    segmenter against a pseudo ground truth."""
    num_of_SegClass = 21  # PASCAL VOC classes
    seg = fpn(num_of_SegClass).to(device)
    seg = nn.DataParallel(seg, device_ids=device_ids)
    # Freeze the segmenter before the forward pass: it acts as a fixed loss
    # network, while gradients still flow back to out_train.
    for param in seg.parameters():
        param.requires_grad = False
    seg_criterion = FocalLoss(gamma=2).to(device)
    # Segmentation prediction on the derained output
    seg_output = seg(out_train)
    # Pseudo target derived from the prediction itself (no ground-truth masks)
    target = get_NoGT_target(seg_output).to(device)
    seg_loss = seg_criterion(seg_output, target)
    return seg_loss
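# Rebuilding the FPN on every SegLoss call is expensive. A minimal cached
# variant (a sketch only; it assumes the same fpn/FocalLoss interfaces used above):
#
#   _seg = None
#   def get_seg(device, device_ids, num_classes=21):
#       global _seg
#       if _seg is None:
#           _seg = nn.DataParallel(fpn(num_classes).to(device), device_ids=device_ids)
#           for p in _seg.parameters():
#               p.requires_grad = False  # fixed loss network, never trained
#       return _seg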
def LpipsLoss(out_train, target_train, device):
    """LPIPS perceptual loss between the derained output and the clean target."""
    # Min-max normalize both images to [0, 1]
    new_out_train = (out_train - torch.min(out_train)) / (torch.max(out_train) - torch.min(out_train))
    new_target_train = (target_train - torch.min(target_train)) / (torch.max(target_train) - torch.min(target_train))
    # Resize to a fixed 256x256 resolution before scoring
    resize = transforms.Resize([256, 256])
    new_target_train = resize(new_target_train)
    new_out_train = resize(new_out_train)
    # LPIPS handles batches directly; normalize=True maps [0, 1] inputs into the
    # [-1, 1] range the network expects. Keeping the result a tensor lets
    # gradients flow back through out_train.
    return loss_fn_lpips(new_target_train.to(device), new_out_train.to(device), normalize=True).mean()
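# Quick sanity check for LpipsLoss (illustrative shapes only, not part of training):
#
#   with torch.no_grad():
#       a = torch.rand(2, 3, 100, 100, device=device)
#       b = torch.rand(2, 3, 100, 100, device=device)
#       print(float(LpipsLoss(a, b, device)))  # small non-negative scalar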
def train():
    print('Using device: %s, GPU ids: %s' % (device, device_ids))
    if opt.preprocess:
        if opt.data_path.find('RainTrainH') != -1:
            prepare_data_RainTrainH(data_path=opt.data_path, patch_size=100, stride=80)
        elif opt.data_path.find('RainTrainL') != -1:
            prepare_data_RainTrainL(data_path=opt.data_path, patch_size=100, stride=80)
        elif opt.data_path.find('Rain12600') != -1:
            prepare_data_Rain12600(data_path=opt.data_path, patch_size=100, stride=100)
        else:
            print('unknown dataset: please define a prepare_data function in DerainDataset.py')
    print('Loading synthetic rainy dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=0,
                              batch_size=opt.batch_size,
                              shuffle=True)
    print("# of training samples: %d\n" % len(dataset_train))
    # Build the deraining model
    model = SAPNet(recurrent_iter=opt.recurrent_iter,
                   use_dilation=opt.use_dilation).to(device)
    model = nn.DataParallel(model, device_ids=device_ids)
    print_network(model)
    # SSIM metric and contrastive loss
    criterion = SSIM().to(device)
    loss_C = ContrastLoss().to(device)
    # Optimizer and stepwise learning-rate schedule
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.2)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
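    # With opt.milestone = [30, 50], for example, MultiStepLR would cut the LR
    # to 0.2*opt.lr at epoch 30 and 0.04*opt.lr at epoch 50 (values illustrative;
    # the actual milestones come from option.py).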
    # Resume from the latest checkpoint, if any
    initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    if initial_epoch > 0:
        print('resuming by loading epoch %d' % initial_epoch)
        model.load_state_dict(torch.load(os.path.join(opt.save_path, 'net_epoch%d.pth' % initial_epoch)))
    # Start training
    for epoch in range(initial_epoch, opt.epochs):
        # Passing the epoch keeps the LR schedule aligned after a resume
        # (this form is deprecated in recent PyTorch releases)
        scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])
        if opt.use_stage1:
            # Phase 1 training (synthetic rain/clean image pairs)
            for i, (input_train, target_train) in enumerate(loader_train, 0):
                model.train()
                optimizer.zero_grad()
                input_train = input_train.to(device)
                target_train = target_train.to(device)
                # Derained estimate; all tensors are [batch_size, 3, 100, 100]
                out_train, _ = model(input_train)
                pixel_metric = criterion(target_train, out_train)
                # Negative SSIM loss (training maximizes SSIM)
                loss_ssim = -pixel_metric
                # Contrastive loss, scaled by 4
                loss_contrast = 4 * loss_C(out_train, target_train, input_train) if opt.use_contrast else 0
                # LPIPS perceptual loss, scaled by 10
                loss_lpips = 10 * LpipsLoss(out_train, target_train, device) if opt.use_lpis else 0
                # Segmentation-aware loss, enabled after epoch 50
                loss_seg = SegLoss(out_train, device, device_ids) if (opt.use_seg_stage1 and epoch > 50) else 0
                # Total loss
                loss = loss_ssim + 0.1 * loss_contrast + 0.1 * loss_lpips + 0.1 * loss_seg
                # Backward pass and parameter update
                loss.backward()
                optimizer.step()
                # Evaluate PSNR on the current batch without tracking gradients
                model.eval()
                with torch.no_grad():
                    out_train, _ = model(input_train)
                    out_train = torch.clamp(out_train, 0., 1.)
                    psnr_train = batch_PSNR(out_train, target_train, 1.)
                if i % 50 == 0:
                    print("[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f, PSNR: %.4f" %
                          (epoch + 1, i + 1, len(loader_train), loss.item(), pixel_metric.item(), psnr_train))
        # Save the latest weights every epoch, plus periodic checkpoints
        os.makedirs(opt.save_path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(opt.save_path, 'net_latest.pth'))
        if epoch % opt.save_freq == 0:
            torch.save(model.state_dict(), os.path.join(opt.save_path, 'net_epoch%d.pth' % (epoch + 1)))
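        # Note: DataParallel prefixes state_dict keys with 'module.'; load these
        # checkpoints with the same wrapping (as done above), or strip the prefix
        # for single-module inference.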
if __name__ == "__main__":
    train()
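# Example invocation (flag names assumed to follow option.py; values are illustrative):
#   python train.py --preprocess --data_path datasets/train/RainTrainH --batch_size 16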