🐛 Fix FID metric bug (#6)
* ⬇️ Downgrade `scipy==1.11.1`

* ♻️ Format metric lpips code

* ♻️ Format metric fid code
zero-nnkn authored Sep 16, 2023
1 parent a9f41aa commit 6f419cb
Showing 6 changed files with 225 additions and 286 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ cupy-cuda11x>=11.0.0
# Monitor/Metric
tensorboardX>=2.6
tqdm>=4.65.0
scipy==1.11.1
thop==0.1.1.post2209072238

# Other
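The commit message does not spell out the failure mode behind the scipy pin, but FID implementations typically depend on scipy through the matrix-square-root step of the Fréchet distance, which is the usual point of sensitivity to the installed scipy version. The sketch below is illustrative only and is not this repository's FID code; `frechet_distance` and its arguments (feature means `mu1`/`mu2` and covariances `sigma1`/`sigma2`) are assumed names following the standard formulation.

# Illustrative sketch only -- not code from this repository.
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """mu*/sigma* are assumed feature means and covariances of two image sets."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; disp=False returns
    # (sqrtm, error_estimate) instead of printing a warning.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Retry with a small diagonal offset for numerical stability.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))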
179 changes: 116 additions & 63 deletions utils/metrics/lpips/lpips.py
@@ -1,27 +1,44 @@

import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from .utils import normalize_tensor, im2tensor, load_image
from .pretrained_networks import vgg16, alexnet, squeezenet
from .pretrained_networks import alexnet, squeezenet, vgg16


def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)


def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2,3],keepdim=keepdim)
return in_tens.mean([2, 3], keepdim=keepdim)

def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
in_H, in_W = in_tens.shape[2], in_tens.shape[3]

def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
# in_H, in_W = in_tens.shape[2], in_tens.shape[3]
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)


# Learned perceptual metric
class LPIPS(nn.Module):
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
""" Initializes a perceptual loss torch.nn.Module
def __init__(
self,
pretrained=True,
net='alex',
version='0.1',
lpips=True,
spatial=False,
pnet_rand=False,
pnet_tune=False,
use_dropout=True,
model_path=None,
eval_mode=True,
verbose=False,
):
"""Initializes a perceptual loss torch.nn.Module
Parameters (default listed first)
---------------------------------
@@ -39,9 +56,9 @@ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spati
['alex','vgg','squeeze'] are the base/trunk networks available
version : str
['v0.1'] is the default and latest
['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
['v0.0'] contained normalization bug; corresponds to https://arxiv.org/abs/1801.03924v1
model_path : 'str'
[None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
[None] is default and loads the pretrained from paper https://arxiv.org/abs/1801.03924v1
The following parameters should only be changed if training the network
@@ -56,115 +73,147 @@ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spati
[False] for no dropout when training linear layers
"""

super(LPIPS, self).__init__()
if(verbose):
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
super().__init__()
if verbose:
print(
'Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'
% ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')
)

self.pnet_type = net
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips # false means baseline of just averaging all layers
self.lpips = lpips # false means baseline of just averaging all layers
self.version = version
self.scaling_layer = ScalingLayer()

if(self.pnet_type in ['vgg','vgg16']):
if self.pnet_type in ['vgg', 'vgg16']:
net_type = vgg16
self.chns = [64,128,256,512,512]
elif(self.pnet_type=='alex'):
self.chns = [64, 128, 256, 512, 512]
elif self.pnet_type == 'alex':
net_type = alexnet
self.chns = [64,192,384,256,256]
elif(self.pnet_type=='squeeze'):
self.chns = [64, 192, 384, 256, 256]
elif self.pnet_type == 'squeeze':
net_type = squeezenet
self.chns = [64,128,256,384,384,512,512]
self.chns = [64, 128, 256, 384, 384, 512, 512]
self.L = len(self.chns)

self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

if(lpips):
if lpips:
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
if self.pnet_type == 'squeeze': # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins+=[self.lin5,self.lin6]
self.lins += [self.lin5, self.lin6]
self.lins = nn.ModuleList(self.lins)

if(pretrained):
if(model_path is None):
if pretrained:
if model_path is None:
import inspect
import os
model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))

if(verbose):
print('Loading model from: %s'%model_path)
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
model_path = os.path.abspath(
os.path.join(
inspect.getfile(self.__init__),
'..',
f'weights/v{version}/{net}.pth',
)
)

if verbose:
print('Loading model from: %s' % model_path)
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

if(eval_mode):
if eval_mode:
self.eval()

def forward(self, in0, in1, retPerLayer=False, normalize=False):
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1

# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
in0_input, in1_input = (
(self.scaling_layer(in0), self.scaling_layer(in1))
if self.version == '0.1'
else (in0, in1)
)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}

for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

if(self.lpips):
if(self.spatial):
res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
if self.lpips:
if self.spatial:
res = [
upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)
]
else:
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
res = [
spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)
]
else:
if(self.spatial):
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
if self.spatial:
res = [
upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
for kk in range(self.L)
]
else:
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
res = [
spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
for kk in range(self.L)
]

val = 0
for l in range(self.L):
val += res[l]
if(retPerLayer):
for lidx in range(self.L):
val += res[lidx]

if retPerLayer:
return (val, res)
else:
return val


class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
super().__init__()
self.register_buffer('shift', torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])

def forward(self, inp):
return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
'''A single linear layer which does a 1x1 conv'''

layers = [nn.Dropout(),] if(use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super().__init__()

layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)


def calculate_lpips_given_paths(paths, device, backbone='vgg', version='0.1'):
loss_fn = LPIPS(net=backbone, version=version).to(device)
@@ -180,16 +229,20 @@ def calculate_lpips_given_paths(paths, device, backbone='vgg', version='0.1'):
# Load images
# img0 = im2tensor(load_image(os.path.join(paths[0], f))).to(device) # RGB image from [-1,1]
# img1 = im2tensor(load_image(os.path.join(paths[1], f))).to(device)
img0 = (np.array(Image.open(os.path.join(paths[0], f))).transpose(2, 0, 1).astype(np.float32) / 255) * 2 - 1
img1 = (np.array(Image.open(os.path.join(paths[1], f))).transpose(2, 0, 1).astype(np.float32) / 255) * 2 - 1
img0 = (
np.array(Image.open(os.path.join(paths[0], f))).transpose(2, 0, 1).astype(np.float32)
/ 255
) * 2 - 1
img1 = (
np.array(Image.open(os.path.join(paths[1], f))).transpose(2, 0, 1).astype(np.float32)
/ 255
) * 2 - 1
img0 = torch.from_numpy(img0).unsqueeze(0).to(device)
img1 = torch.from_numpy(img1).unsqueeze(0).to(device)

# Compute distance
lpips_avg += loss_fn(img0, img1).item()

lpips_avg /= len(files)

return lpips_avg
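
For reference, a minimal usage sketch of the reformatted LPIPS module; the directory paths, tensor shapes, and backbone choice below are placeholders, not values from this commit, and the import path assumes `utils` is importable as a package from the project root.

# Hypothetical usage sketch; paths and shapes are placeholders.
import torch
from utils.metrics.lpips.lpips import LPIPS, calculate_lpips_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pairwise distance between two image batches scaled to [0, 1];
# normalize=True rescales them to the [-1, 1] range LPIPS expects.
loss_fn = LPIPS(net='vgg', version='0.1').to(device)
img0 = torch.rand(1, 3, 256, 256, device=device)
img1 = torch.rand(1, 3, 256, 256, device=device)
dist = loss_fn(img0, img1, normalize=True)

# Directory-level average over files with matching names in both folders.
avg = calculate_lpips_given_paths(['path/to/real', 'path/to/fake'], device, backbone='vgg')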


