forked from harshitbansal05/Image-Colorization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
109 lines (89 loc) · 4.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
from torch import cat

from model import UNet, DNet
import data_loader
from data_loader import *
##############################################################
# Initialise the generator and discriminator with the UNet and
# DNet architectures respectively.
generator = UNet(True)
discriminator = DNet()
##################################################################
# Run the forward and backward passes on the default GPU when one is
# available (falling back to the CPU otherwise), so allocate all the
# generator and discriminator parameters on that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
###################################################################
# Create ADAM optimizers for the generator and the discriminator.
# Create loss criteria for the adversarial (BCE) and L1 losses.
d_optimizer = optim.Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), betas=(0.5, 0.999), lr=0.0002)
d_criterion = nn.BCELoss()
g_criterion_1 = nn.BCELoss()
g_criterion_2 = nn.L1Loss()


def _prepare_batch(lab_images):
    """
    Split a batch of Lab images into rescaled L and ab tensors on ``device``.

    ``lab_images`` is indexed with four axes; channel 0 is taken as the
    luminance and channels 1 onward as the chrominance. NOTE(review): the
    layout is assumed to be (batch, channel, H, W) - confirm in data_loader.
    Returns ``(l_images, c_images)``.
    """
    # Slice 0:1 rather than 0 so the luminance keeps its channel axis;
    # torch.cat cannot join a 3-D tensor with the 4-D chrominance tensor.
    l_images = lab_images[:, 0:1, :, :]
    c_images = lab_images[:, 1:, :, :]
    # (x - 0.5) * 2 maps [0, 1] onto [-1, 1]. NOTE(review): assumes the
    # loader yields values in [0, 1] - confirm against data_loader.
    l_images = (l_images - 0.5) * 2
    c_images = (c_images - 0.5) * 2
    return l_images.to(device), c_images.to(device)


def _discriminator_step(l_images, c_images, fake_images, smooth):
    """
    Perform one discriminator update and return its loss as a float.

    Real pairs (L, ab) are labelled ``1 - smooth`` (one-sided label
    smoothing) and generated pairs (L, fake ab) are labelled 0; the
    greyscale image is the condition in both cases.
    """
    d_optimizer.zero_grad()
    # Targets are built with *_like so they match the discriminator output's
    # shape, dtype and device (BCELoss requires target and input shapes to
    # agree). BCELoss expects probabilities, so DNet is assumed to end in a
    # sigmoid.
    real_out = discriminator(cat([l_images, c_images], 1))
    d_real_loss = d_criterion(real_out, torch.full_like(real_out, 1 - smooth))
    # detach() stops this backward pass at the discriminator input. Without
    # it, d_loss.backward() would free the generator's graph and the
    # generator update that follows would fail.
    fake_out = discriminator(cat([l_images, fake_images.detach()], 1))
    d_fake_loss = d_criterion(fake_out, torch.zeros_like(fake_out))
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()


def _generator_step(l_images, c_images, fake_images, g_lambda):
    """
    Perform one generator update and return its loss as a float.

    The loss is the adversarial loss (fakes labelled as real) plus
    ``g_lambda`` times the L1 distance between fake and target chrominance.
    """
    g_optimizer.zero_grad()
    fake_out = discriminator(cat([l_images, fake_images], 1))
    g_fake_loss = g_criterion_1(fake_out, torch.ones_like(fake_out))
    g_image_distance_loss = g_lambda * g_criterion_2(fake_images, c_images)
    g_loss = g_fake_loss + g_image_distance_loss
    # This also accumulates gradients into the discriminator, but only
    # g_optimizer steps here, and d_optimizer.zero_grad() clears them before
    # the next discriminator update.
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()


def _save_state(module, path):
    """Save ``module.state_dict()`` to ``path``, creating its directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(module.state_dict(), path)


def train_(num_epochs=200, g_lambda=100, smooth=0.1, log_every=10):
    """
    Train the generator and discriminator as a conditional GAN.

    Batches come from ``cielab_loader`` (presumably provided by
    ``from data_loader import *``).

    Args:
        num_epochs: number of passes over the dataset.
        g_lambda: weight of the L1 loss relative to the adversarial loss.
        smooth: label smoothing for real pairs (their target is ``1 - smooth``).
        log_every: print the mean losses every ``log_every`` batches.

    The generator and discriminator state dicts are saved after each epoch.
    """
    for epoch in range(num_epochs):
        # Losses summed over the batches since the last printout. Python
        # floats (via .item()) are accumulated rather than tensors, which
        # would keep every batch's autograd graph alive.
        d_running_loss = 0.0
        g_running_loss = 0.0
        for i, lab_images in enumerate(cielab_loader):
            l_images, c_images = _prepare_batch(lab_images)
            # NOTE(review): fake_images must have the same shape as c_images
            # for the L1 loss - assumed from UNet, confirm in model.py.
            fake_images = generator(l_images)
            d_running_loss += _discriminator_step(l_images, c_images, fake_images, smooth)
            g_running_loss += _generator_step(l_images, c_images, fake_images, g_lambda)
            # Report only once a full window of log_every batches is summed.
            if (i + 1) % log_every == 0:
                print('[%d, %5d] d_loss: %.5f g_loss: %.5f' %
                      (epoch + 1, i + 1, d_running_loss / log_every,
                       g_running_loss / log_every))
                d_running_loss = 0.0
                g_running_loss = 0.0
        # save the generator and discriminator state after each epoch.
        _save_state(generator, 'home/cifar10_train_generator')
        _save_state(discriminator, 'home/cifar10_train_discriminator')
    print('Finished Training')


# Start training only when run as a script, and only after train_ is defined.
if __name__ == "__main__":
    train_()