-
Notifications
You must be signed in to change notification settings - Fork 22
/
misc.py
31 lines (27 loc) · 1 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch import nn
from torch import optim
from torch.autograd import Variable
import torch
# weight initialisation is provided
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1 or classname.find("Linear") !=-1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def one_hot(labelTensor,batchSize,n_l,use_cuda=False):
# torch.ones(20*10).view(20,10)
oneHot = - torch.ones(batchSize*n_l).view(batchSize,n_l)
for i,j in enumerate(labelTensor):
oneHot[i,j] = 1
if use_cuda:
return Variable(oneHot).cuda()
else:
return Variable(oneHot)
# total variation loss which is effective in removing gost artifacts
def TV_LOSS(imgTensor,img_size=128):
x = (imgTensor[:,:,1:,:]-imgTensor[:,:,:img_size-1,:])**2
y = (imgTensor[:,:,:,1:]-imgTensor[:,:,:,:img_size-1])**2
out = (x.mean(dim=2)+y.mean(dim=3)).mean()
return out