
Commit

Added ppl
rosinality committed Dec 21, 2019
1 parent 5f60e58 commit 6d3e784
Showing 16 changed files with 1,047 additions and 21 deletions.
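The commit message ("Added ppl") refers to perceptual path length evaluation; the evaluation script itself is among the changed files not shown in this excerpt, which only includes the vendored LPIPS helpers. As rough orientation, below is a minimal sketch of how PPL is commonly computed (following the StyleGAN paper) on top of the PerceptualLoss wrapper added in lpips/__init__.py below. The generator interface g, the latent dimension, and the plain lerp in Z are assumptions for illustration, not this repository's actual implementation.

# Hypothetical PPL sketch: sample nearby latent pairs, generate both images,
# and measure their LPIPS distance scaled by 1 / eps^2. `g` is an assumed
# generator returning Nx3xHxW images in [-1, +1]; average the returned
# distances over many batches to estimate the metric.
import torch

from lpips import PerceptualLoss

percept = PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

def ppl_batch(g, batch=8, latent_dim=512, eps=1e-4, device='cuda'):
    z0 = torch.randn(batch, latent_dim, device=device)
    z1 = torch.randn(batch, latent_dim, device=device)
    t = torch.rand(batch, 1, device=device)

    # Interpolate at t and at t + eps (plain lerp here; StyleGAN uses slerp in Z).
    img_a = g(torch.lerp(z0, z1, t))
    img_b = g(torch.lerp(z0, z1, t + eps))

    dist = percept(img_a, img_b).view(-1) / (eps ** 2)
    return dist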
24 changes: 24 additions & 0 deletions LICENSE-LPIPS
@@ -0,0 +1,24 @@
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

2 changes: 1 addition & 1 deletion README.md
@@ -24,4 +24,4 @@ train.py supports Weights & Biases logging. If you want to use it, add --wandb a

![Sample with truncation](sample.png)

- At 40,000 iterations. (trained on 1.28M images)
+ At 110,000 iterations. (trained on 3.52M images)
160 changes: 160 additions & 0 deletions lpips/__init__.py
@@ -0,0 +1,160 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable

from lpips import dist_model

class PerceptualLoss(torch.nn.Module):
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized'%self.model.name())
print('...Done')

def forward(self, pred, target, normalize=False):
"""
Pred and target are Variables.
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
If normalize is False, assumes the images are already between [-1,+1]
Inputs pred and target are Nx3xHxW
Output pytorch Variable N long
"""

if normalize:
target = 2 * target - 1
pred = 2 * pred - 1

return self.model.forward(target, pred)

def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)

def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)

def psnr(p0, p1, peak=255.):
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))

def dssim(p0, p1, range=255.):
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.

def rgb2lab(in_img,mean_cent=False):
from skimage import color
img_lab = color.rgb2lab(in_img)
if(mean_cent):
img_lab[:,:,0] = img_lab[:,:,0]-50
return img_lab

def tensor2np(tensor_obj):
# convert a tensor (taking the first image in the batch) into an HxWxC numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))

def np2tensor(np_obj):
# convert an HxWxC numpy array into a 1xCxHxW tensor
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
# image tensor to lab tensor
from skimage import color

img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if(mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
if(to_norm and not mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
img_lab = img_lab/100.

return np2tensor(img_lab)

def tensorlab2tensor(lab_tensor,return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")

lab = tensor2np(lab_tensor)*100.
lab[:,:,0] = lab[:,:,0]+50

rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
if(return_inbnd):
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1.*np.isclose(lab_back,lab,atol=2.)
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
return (im2tensor(rgb_back),mask)
else:
return im2tensor(rgb_back)

def rgb2lab(input):
from skimage import color
return color.rgb2lab(input / 255.)

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]

def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))

# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]

# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
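As a quick orientation to the helpers above, here is a minimal usage sketch (illustrative only, with made-up tensor values); the shapes and the normalize flag follow the forward() docstring.

import torch
from lpips import PerceptualLoss, im2tensor, tensor2im

percept = PerceptualLoss(model='net-lin', net='alex', use_gpu=False)

# Nx3xHxW images in [0, 1]; normalize=True rescales them to [-1, +1] internally.
pred = torch.rand(4, 3, 256, 256)
target = torch.rand(4, 3, 256, 256)
dist = percept(pred, target, normalize=True)  # one LPIPS distance per image

# Round trip between a tensor in [-1, +1] and an HxWx3 uint8 image.
img = tensor2im(pred * 2 - 1)   # takes the first image in the batch
back = im2tensor(img)           # 1x3xHxW float tensor in [-1, +1]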
58 changes: 58 additions & 0 deletions lpips/base_model.py
@@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
def __init__(self):
pass;

def name(self):
return 'BaseModel'

def initialize(self, use_gpu=True, gpu_ids=[0]):
self.use_gpu = use_gpu
self.gpu_ids = gpu_ids

def forward(self):
pass

def get_image_paths(self):
pass

def optimize_parameters(self):
pass

def get_current_visuals(self):
return self.input

def get_current_errors(self):
return {}

def save(self, label):
pass

# helper saving function that can be used by subclasses
def save_network(self, network, path, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(path, save_filename)
torch.save(network.state_dict(), save_path)

# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
print('Loading network from %s'%save_path)
network.load_state_dict(torch.load(save_path))

def update_learning_rate(self):
pass

def get_image_paths(self):
return self.image_paths

def save_done(self, flag=False):
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
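BaseModel itself only stubs out the training interface; the save_network/load_network helpers are meant to be called from subclasses. A hypothetical subclass sketch follows (the network, label, and directory names are made up for illustration).

import os
import torch
from lpips.base_model import BaseModel

class ToyModel(BaseModel):
    def initialize(self, use_gpu=True, gpu_ids=[0]):
        super().initialize(use_gpu=use_gpu, gpu_ids=gpu_ids)
        self.save_dir = './checkpoints'   # load_network resolves paths against save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.net = torch.nn.Linear(10, 1)

    def save(self, label):
        # writes '<label>_net_toy.pth' into save_dir
        self.save_network(self.net, self.save_dir, 'toy', label)

    def load(self, label):
        self.load_network(self.net, 'toy', label)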
