train.py
import pathlib
import time

import matplotlib.pyplot as plt
import torch

from loss.loss_functions import *
from models.vgg16bn_disp import DepthNet
from preprocessing.data_transformations import get_split
# Device setup/recognition
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Available device is', device)
# Loading hyperparameters
from hyperparameters import *
w1, w2 = W1, W2 # Loss weights
lr = LR
batch_size = BATCH_SIZE
gamma, step = GAMMA, STEP
use_scheduler = USE_SCHEDULER
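# NOTE: the star import above is assumed to provide the remaining ALL_CAPS
# constants referenced below (MODEL_NAME, EPOCHS, LOSS, SMOOTH_L1, the
# image rescale/crop sizes and the augmentation settings).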
# Initialize loss list
training_loss = []
validation_loss = []
# Paths
model_path = 'models/' + MODEL_NAME
images_dir = 'images/' + model_path.split('/')[-1]
pathlib.Path(images_dir).mkdir(parents=True, exist_ok=True)
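# Plots and log files produced below are stored under images/<MODEL_NAME>;
# the trained weights themselves are written to models/<MODEL_NAME>.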
# Summary writer function
def summary_writer():
    summary = "Model name : " + str(MODEL_NAME) + \
              "\nNumber of epochs : " + str(EPOCHS) + \
              "\nLearning rate : " + str(lr) + \
              "\nBatch size : " + str(batch_size) + \
              "\nWeighted loss : " + str(w1) + ' ' + str(w2) + \
              "\nLRScheduler : " + str(USE_SCHEDULER) + \
              "\nLoss : " + LOSS + \
              "\nLRScheduler (gamma, step) : " + str(gamma) + ' ' + str(step) + \
              "\nL1 smooth : " + str(SMOOTH_L1) + \
              "\nRescaled image : " + str(IMG_HEIGHT_RESCALE) + ', ' + str(IMG_WIDTH_RESCALE) + \
              "\nCropped image : " + str(IMG_HEIGHT) + ', ' + str(IMG_WIDTH) + \
              "\nBorder size : " + str(BORDER_SIZE) + \
              "\nMixUp : " + str(MIXUP) + \
              "\nMixUp ratio : " + str(MIXUP_SIZE) + \
              "\nBlend : " + str(BLEND) + \
              "\nBlend ratio : " + str(BLEND_SIZE) + \
              "\nRotation : " + str(ROTATE) + \
              "\nRotation angle : " + str(ROTATION_ANGLE) + \
              "\nRotation ratio : " + str(ROTATION_SIZE) + \
              "\n"
    with open(images_dir + '/summary.txt', 'w') as summary_file:
        summary_file.write(summary)
# Train function
def train(batch_size, epochs):
    # Only the loss histories are reassigned here; the other module-level
    # names are read-only inside this function
    global training_loss, validation_loss
    # Loss function dictionary ('behru' is presumably the BerHu, i.e. reverse
    # Huber, loss from loss/loss_functions.py)
    loss_func = {'l1': l1_loss, 'l2': l2_loss, 'behru': behru_loss}
    print('Loading data...')
    # Loading dataset
    train_set, val_set, test_set = get_split(train=True)
    print('Model setup...')
    # Creating a model
    model = DepthNet().to(device)
    # Optimizer and LR scheduler setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step, gamma=gamma, verbose=True)
    print('Training...')
    best_loss = None
    training_loss = []
    validation_loss = []
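    # NOTE (assumption): the dataset API is inferred from its usage below;
    # initBatch(batch_size) returns the number of batches per epoch and
    # getBatch() yields one (image, ground-truth depth) tensor pair.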
    for epoch in range(epochs):
        torch.cuda.empty_cache()
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        # Prepare model for training
        model.train()
        # Initialize running losses
        running_loss_photo = 0
        running_loss_smooth = 0
        running_loss = 0
        # Number of training iterations
        N_train = train_set.initBatch(batch_size=batch_size)
        for itr in range(N_train):
            # Get images and depths
            tgt_img, gt_depth = train_set.getBatch()
            # Move tensors to device
            tgt_img = tgt_img.to(device).float()
            gt_depth = gt_depth.to(device).float()
            # Clear gradients
            optimizer.zero_grad()
            # Prediction
            disparities = model(tgt_img)
            depth = 1 / disparities
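            # NOTE: depth is recovered as the reciprocal of the predicted
            # disparity, which assumes the network output is strictly
            # positive; if it is not, clamping first, e.g.
            # disparities.clamp(min=1e-6), would avoid division by zero.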
            # Calculate loss
            loss_1 = loss_func[LOSS](gt_depth, depth)
            loss_3 = smooth_loss(depth)
            loss = weighted_loss(loss_1, loss_3, w1, w2)
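            # NOTE (assumption): weighted_loss is taken to combine the two
            # terms as w1 * loss_1 + w2 * loss_3; loss_1 compares predicted
            # depth with ground truth, even though the epoch summary below
            # labels it "photometric".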
            # Calculate gradients
            loss.backward()
            # Update weights
            optimizer.step()
            print('Iteration {}/{}, loss = {:.4f}'.format(itr + 1, N_train, loss.item()))
            # Update running losses (averaged over the epoch)
            running_loss_photo += loss_1.item() / N_train
            running_loss_smooth += loss_3.item() / N_train
            running_loss += loss.item() / N_train
        torch.cuda.empty_cache()
        # Print results on training dataset
        print('------------------------------------------------')
        print('########### Training results {}/{} #############'.format(epoch + 1, epochs))
        print('Photometric loss {:.4f}, Smooth loss {:.4f}, Overall loss {:.4f}'.format(running_loss_photo, running_loss_smooth, running_loss))
        print('------------------------------------------------')
        # Save training loss for current epoch
        training_loss.append(running_loss)
        # Reset running losses for validation
        running_loss_photo = 0
        running_loss_smooth = 0
        running_loss = 0
        # Prepare model for validation
        model.eval()
        # Number of validation iterations
        N_val = val_set.initBatch(batch_size=batch_size)
        for itr in range(N_val):
            # Get images and depths
            tgt_img, gt_depth = val_set.getBatch()
            # Move tensors to device
            tgt_img = tgt_img.to(device).float()
            gt_depth = gt_depth.to(device).float()
            with torch.no_grad():
                # Prediction
                disparities = model(tgt_img)
                depth = 1 / disparities
                # Calculate loss
                loss_1 = loss_func[LOSS](gt_depth, depth)
                loss_3 = smooth_loss(depth)
                loss = weighted_loss(loss_1, loss_3, w1, w2)
            print('Iteration {}/{}, loss = {:.4f}'.format(itr + 1, N_val, loss.item()))
            # Update running losses (averaged over the epoch)
            running_loss_photo += loss_1.item() / N_val
            running_loss_smooth += loss_3.item() / N_val
            running_loss += loss.item() / N_val
        torch.cuda.empty_cache()
        # Print results on validation dataset
        print('------------------------------------------------')
        print('########### Validation results {}/{} ###########'.format(epoch + 1, epochs))
        print('Photometric loss {:.4f}, Smooth loss {:.4f}, Overall loss {:.4f}'.format(running_loss_photo, running_loss_smooth, running_loss))
        print('------------------------------------------------')
        # Save validation loss for current epoch
        validation_loss.append(running_loss)
        # Saving the best model
        if (best_loss is None) or (validation_loss[-1] < best_loss):
            # Remember the best loss so far
            best_loss = validation_loss[-1]
            # Write to the log file
            with open(images_dir + '/best_validation_loss.txt', 'w') as log_file:
                log_file.write('loss : ' + str(round(best_loss, 4)) + '\nepoch ' + str(epoch + 1))
            # Save best model
            torch.save(model.state_dict(), model_path)
        # Save loss plot every fifth epoch
        if epoch % 5 == 0:
            fig = plt.figure(figsize=(16, 10), dpi=120)
            plt.plot(training_loss)
            plt.plot(validation_loss)
            plt.legend(['Training loss', 'Validation loss'])
            plt.title('Loss function')
            plt.grid(visible=True, which='minor')
            fig.savefig(images_dir + '/learning_curve.png', dpi=fig.dpi)
            plt.close(fig)
        # Step the LR scheduler once per epoch, if enabled
        if use_scheduler:
            scheduler.step()
    return (training_loss, validation_loss)
# Main
if __name__ == '__main__':
    summary_writer()
    start = time.time()
    try:
        loss = train(batch_size=batch_size, epochs=EPOCHS)
        fig = plt.figure(figsize=(16, 10), dpi=120)
        plt.plot(loss[0])
        plt.plot(loss[1])
        plt.legend(['Training loss', 'Validation loss'])
        plt.title('Loss function')
        plt.grid(visible=True, which='minor')
        fig.savefig(images_dir + '/learning_curve.png', dpi=fig.dpi)
        # plt.show()
    except KeyboardInterrupt:
        # Still save the curves gathered so far if training is interrupted
        fig = plt.figure(figsize=(16, 10), dpi=120)
        plt.plot(training_loss)
        plt.plot(validation_loss)
        plt.legend(['Training loss', 'Validation loss'])
        plt.title('Loss function')
        plt.grid(visible=True, which='minor')
        fig.savefig(images_dir + '/learning_curve.png', dpi=fig.dpi)
        # plt.show()
    end = time.time()
    print('Training time {:.1f}s'.format(end - start))
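
# Usage sketch (assuming hyperparameters.py and the dataset split consumed by
# preprocessing.data_transformations.get_split are configured as above):
#   python train.py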