diff --git a/requirements.txt b/requirements.txt index d3fb7ae..754accf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ cupy-cuda11x>=11.0.0 # Monitor/Metric tensorboardX>=2.6 tqdm>=4.65.0 +scipy==1.11.1 thop==0.1.1.post2209072238 # Other diff --git a/utils/metrics/lpips/lpips.py b/utils/metrics/lpips/lpips.py index 6b2d227..0a529f2 100755 --- a/utils/metrics/lpips/lpips.py +++ b/utils/metrics/lpips/lpips.py @@ -1,4 +1,3 @@ - import os import numpy as np @@ -6,22 +5,40 @@ import torch.nn as nn from PIL import Image -from .utils import normalize_tensor, im2tensor, load_image -from .pretrained_networks import vgg16, alexnet, squeezenet +from .pretrained_networks import alexnet, squeezenet, vgg16 + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + def spatial_average(in_tens, keepdim=True): - return in_tens.mean([2,3],keepdim=keepdim) + return in_tens.mean([2, 3], keepdim=keepdim) -def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W - in_H, in_W = in_tens.shape[2], in_tens.shape[3] + +def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W + # in_H, in_W = in_tens.shape[2], in_tens.shape[3] return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) # Learned perceptual metric class LPIPS(nn.Module): - def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, - pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): - """ Initializes a perceptual loss torch.nn.Module + def __init__( + self, + pretrained=True, + net='alex', + version='0.1', + lpips=True, + spatial=False, + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + model_path=None, + eval_mode=True, + verbose=False, + ): + """Initializes a perceptual loss torch.nn.Module Parameters (default listed first) --------------------------------- @@ -39,9 +56,9 @@ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spati ['alex','vgg','squeeze'] are the base/trunk networks available version : str ['v0.1'] is the default and latest - ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) + ['v0.0'] contained normalization bug; corresponds to https://arxiv.org/abs/1801.03924v1 model_path : 'str' - [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 + [None] is default and loads the pretrained from paper https://arxiv.org/abs/1801.03924v1 The following parameters should only be changed if training the network @@ -56,88 +73,111 @@ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spati [False] for no dropout when training linear layers """ - super(LPIPS, self).__init__() - if(verbose): - print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% - ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) + super().__init__() + if verbose: + print( + 'Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' + % ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off') + ) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial - self.lpips = lpips # false means baseline of just averaging all layers + self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = 
ScalingLayer() - if(self.pnet_type in ['vgg','vgg16']): + if self.pnet_type in ['vgg', 'vgg16']: net_type = vgg16 - self.chns = [64,128,256,512,512] - elif(self.pnet_type=='alex'): + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == 'alex': net_type = alexnet - self.chns = [64,192,384,256,256] - elif(self.pnet_type=='squeeze'): + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == 'squeeze': net_type = squeezenet - self.chns = [64,128,256,384,384,512,512] + self.chns = [64, 128, 256, 384, 384, 512, 512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) - if(lpips): + if lpips: self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] - if(self.pnet_type=='squeeze'): # 7 layers for squeezenet + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == 'squeeze': # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) - self.lins+=[self.lin5,self.lin6] + self.lins += [self.lin5, self.lin6] self.lins = nn.ModuleList(self.lins) - if(pretrained): - if(model_path is None): + if pretrained: + if model_path is None: import inspect import os - model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) - if(verbose): - print('Loading model from: %s'%model_path) - self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) + model_path = os.path.abspath( + os.path.join( + inspect.getfile(self.__init__), + '..', + f'weights/v{version}/{net}.pth', + ) + ) + + if verbose: + print('Loading model from: %s' % model_path) + self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) - if(eval_mode): + if eval_mode: self.eval() def forward(self, in0, in1, retPerLayer=False, normalize=False): - if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] - in0 = 2 * in0 - 1 - in1 = 2 * in1 - 1 + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled - in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) + in0_input, in1_input = ( + (self.scaling_layer(in0), self.scaling_layer(in1)) + if self.version == '0.1' + else (in0, in1) + ) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) - diffs[kk] = (feats0[kk]-feats1[kk])**2 + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - if(self.lpips): - if(self.spatial): - res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] + if self.lpips: + if self.spatial: + res = [ + upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L) + ] else: - res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + res = [ + spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L) + ] else: - 
if(self.spatial): - res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] + if self.spatial: + res = [ + upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) + for kk in range(self.L) + ] else: - res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] + res = [ + spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) + for kk in range(self.L) + ] val = 0 - for l in range(self.L): - val += res[l] - - if(retPerLayer): + for lidx in range(self.L): + val += res[lidx] + + if retPerLayer: return (val, res) else: return val @@ -145,26 +185,35 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False): class ScalingLayer(nn.Module): def __init__(self): - super(ScalingLayer, self).__init__() - self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) - self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) + super().__init__() + self.register_buffer('shift', torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): - ''' A single linear layer which does a 1x1 conv ''' - def __init__(self, chn_in, chn_out=1, use_dropout=False): - super(NetLinLayer, self).__init__() + '''A single linear layer which does a 1x1 conv''' - layers = [nn.Dropout(),] if(use_dropout) else [] - layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super().__init__() + + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) - + def calculate_lpips_given_paths(paths, device, backbone='vgg', version='0.1'): loss_fn = LPIPS(net=backbone, version=version).to(device) @@ -180,16 +229,20 @@ def calculate_lpips_given_paths(paths, device, backbone='vgg', version='0.1'): # Load images # img0 = im2tensor(load_image(os.path.join(paths[0], f))).to(device) # RGB image from [-1,1] # img1 = im2tensor(load_image(os.path.join(paths[1], f))).to(device) - img0 = (np.array(Image.open(os.path.join(paths[0], f))).transpose(2, 0, 1).astype(np.float32) / 255) * 2 - 1 - img1 = (np.array(Image.open(os.path.join(paths[1], f))).transpose(2, 0, 1).astype(np.float32) / 255) * 2 - 1 + img0 = ( + np.array(Image.open(os.path.join(paths[0], f))).transpose(2, 0, 1).astype(np.float32) + / 255 + ) * 2 - 1 + img1 = ( + np.array(Image.open(os.path.join(paths[1], f))).transpose(2, 0, 1).astype(np.float32) + / 255 + ) * 2 - 1 img0 = torch.from_numpy(img0).unsqueeze(0).to(device) img1 = torch.from_numpy(img1).unsqueeze(0).to(device) # Compute distance lpips_avg += loss_fn(img0, img1).item() - + lpips_avg /= len(files) return lpips_avg - - diff --git a/utils/metrics/lpips/pretrained_networks.py b/utils/metrics/lpips/pretrained_networks.py index 2d07f9a..ecbb829 100644 --- a/utils/metrics/lpips/pretrained_networks.py +++ b/utils/metrics/lpips/pretrained_networks.py @@ -1,10 +1,12 @@ from collections import namedtuple + import torch from torchvision import models as tv + class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): - super(squeezenet, self).__init__() + super().__init__() pretrained_features = 
tv.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() @@ -16,7 +18,7 @@ def __init__(self, requires_grad=False, pretrained=True): self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) - for x in range(2,5): + for x in range(2, 5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) @@ -47,15 +49,17 @@ def forward(self, X): h_relu6 = h h = self.slice7(h) h_relu7 = h - vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) - out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + vgg_outputs = namedtuple( + "SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'] + ) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): - super(alexnet, self).__init__() + super().__init__() alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() @@ -88,14 +92,17 @@ def forward(self, X): h_relu4 = h h = self.slice5(h) h_relu5 = h - alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + alexnet_outputs = namedtuple( + "AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'] + ) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out + class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): - super(vgg16, self).__init__() + super().__init__() vgg_pretrained_features = tv.vgg16(weights=tv.VGG16_Weights.IMAGENET1K_V1).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() @@ -128,25 +135,26 @@ def forward(self, X): h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] + ) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out - class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): - super(resnet, self).__init__() - if(num==18): + super().__init__() + if num == 18: self.net = tv.resnet18(pretrained=pretrained) - elif(num==34): + elif num == 34: self.net = tv.resnet34(pretrained=pretrained) - elif(num==50): + elif num == 50: self.net = tv.resnet50(pretrained=pretrained) - elif(num==101): + elif num == 101: self.net = tv.resnet101(pretrained=pretrained) - elif(num==152): + elif num == 152: self.net = tv.resnet152(pretrained=pretrained) self.N_slices = 5 @@ -174,7 +182,7 @@ def forward(self, X): h = self.layer4(h) h_conv5 = h - outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out diff --git a/utils/metrics/lpips/utils.py b/utils/metrics/lpips/utils.py deleted file mode 100644 index 46290ad..0000000 --- a/utils/metrics/lpips/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -import warnings -import numpy as np -import torch -from skimage import color -# from skimage.measure import compare_ssim -from matplotlib import pyplot as plt - - -def normalize_tensor(in_feat,eps=1e-10): - 
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) - return in_feat/(norm_factor+eps) - -def l2(p0, p1, range=255.): - return .5*np.mean((p0 / range - p1 / range)**2) - -def psnr(p0, p1, peak=255.): - return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) - -# def dssim(p0, p1, range=255.): -# return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. - -def tensor2np(tensor_obj): - # change dimension of a tensor object into a numpy array - return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) - -def np2tensor(np_obj): - # change dimenion of np array into tensor array - return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) - -def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): - # image tensor to lab tensor - img = tensor2im(image_tensor) - img_lab = color.rgb2lab(img) - if(mc_only): - img_lab[:,:,0] = img_lab[:,:,0]-50 - if(to_norm and not mc_only): - img_lab[:,:,0] = img_lab[:,:,0]-50 - img_lab = img_lab/100. - - return np2tensor(img_lab) - -def tensorlab2tensor(lab_tensor,return_inbnd=False): - warnings.filterwarnings("ignore") - - lab = tensor2np(lab_tensor)*100. - lab[:,:,0] = lab[:,:,0]+50 - - rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) - if(return_inbnd): - # convert back to lab, see if we match - lab_back = color.rgb2lab(rgb_back.astype('uint8')) - mask = 1.*np.isclose(lab_back,lab,atol=2.) - mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) - return (im2tensor(rgb_back),mask) - else: - return im2tensor(rgb_back) - -def load_image(path): - # if(path[-3:] == 'dng'): - # import rawpy - # with rawpy.imread(path) as raw: - # img = raw.postprocess() - # elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'): - # import cv2 - # return cv2.imread(path)[:,:,::-1] - # else: - # import matplotlib.pyplot as plt - # img = (255*plt.imread(path)[:,:,:3]).astype('uint8') - img = (255*plt.imread(path)[:,:,:3]).astype('uint8') - - return img - -def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): - image_numpy = image_tensor[0].cpu().float().numpy() - image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor - return image_numpy.astype(imtype) - -def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): - return torch.Tensor((image / factor - cent) - [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) - -def tensor2vec(vector_tensor): - return vector_tensor.data.cpu().numpy()[:, :, 0, 0] - - -def voc_ap(rec, prec, use_07_metric=False): - """ ap = voc_ap(rec, prec, [use_07_metric]) - Compute VOC AP given precision and recall. - If use_07_metric is true, uses the - VOC 07 11 point method (default:False). - """ - if use_07_metric: - # 11 point metric - ap = 0. - for t in np.arange(0., 1.1, 0.1): - if np.sum(rec >= t) == 0: - p = 0 - else: - p = np.max(prec[rec >= t]) - ap = ap + p / 11. 
- else: - # correct AP calculation - # first append sentinel values at the end - mrec = np.concatenate(([0.], rec, [1.])) - mpre = np.concatenate(([0.], prec, [0.])) - - # compute the precision envelope - for i in range(mpre.size - 1, 0, -1): - mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) - - # to calculate area under PR curve, look for points - # where X axis (recall) changes value - i = np.where(mrec[1:] != mrec[:-1])[0] - - # and sum (\Delta recall) * prec - ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) - return ap \ No newline at end of file diff --git a/utils/metrics/pytorch_fid/fid_score.py b/utils/metrics/pytorch_fid/fid_score.py index a0f9de3..9e7b709 100644 --- a/utils/metrics/pytorch_fid/fid_score.py +++ b/utils/metrics/pytorch_fid/fid_score.py @@ -42,30 +42,38 @@ from scipy import linalg from torch.nn.functional import adaptive_avg_pool2d - from .inception import InceptionV3 parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) -parser.add_argument('--batch-size', type=int, default=50, - help='Batch size to use') -parser.add_argument('--num-workers', type=int, - help=('Number of processes to use for data loading. ' - 'Defaults to `min(8, num_cpus)`')) -parser.add_argument('--device', type=str, default=None, - help='Device to use. Like cuda, cuda:0 or cpu') -parser.add_argument('--dims', type=int, default=2048, - choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), - help=('Dimensionality of Inception features to use. ' - 'By default, uses pool3 features')) -parser.add_argument('--save-stats', action='store_true', - help=('Generate an npz archive from a directory of samples. ' - 'The first path is used as input and the second as output.')) -parser.add_argument('path', type=str, nargs=2, - help=('Paths to the generated images or ' - 'to .npz statistic files')) - -IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', - 'tif', 'tiff', 'webp'} +parser.add_argument('--batch-size', type=int, default=50, help='Batch size to use') +parser.add_argument( + '--num-workers', + type=int, + help=('Number of processes to use for data loading. ' 'Defaults to `min(8, num_cpus)`'), +) +parser.add_argument( + '--device', type=str, default=None, help='Device to use. Like cuda, cuda:0 or cpu' +) +parser.add_argument( + '--dims', + type=int, + default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=('Dimensionality of Inception features to use. ' 'By default, uses pool3 features'), +) +parser.add_argument( + '--save-stats', + action='store_true', + help=( + 'Generate an npz archive from a directory of samples. ' + 'The first path is used as input and the second as output.' + ), +) +parser.add_argument( + 'path', type=str, nargs=2, help=('Paths to the generated images or ' 'to .npz statistic files') +) + +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'} class ImagePathDataset(torch.utils.data.Dataset): @@ -84,8 +92,7 @@ def __getitem__(self, i): return img -def get_activations(files, model, batch_size=50, dims=2048, device='cpu', - num_workers=1): +def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=1): """Calculates the activations of the pool_3 layer for all images. Params: @@ -107,16 +114,15 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', """ model.eval() if batch_size > len(files): - print(('Warning: batch size is bigger than the data size. ' - 'Setting batch size to data size')) + print( + 'Warning: batch size is bigger than the data size. 
' 'Setting batch size to data size' + ) batch_size = len(files) dataset = ImagePathDataset(files, transforms=TF.ToTensor()) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - num_workers=num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers + ) pred_arr = np.empty((len(files), dims)) @@ -134,7 +140,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', pred = pred.squeeze(3).squeeze(2).cpu().numpy() - pred_arr[start_idx:start_idx + pred.shape[0]] = pred + pred_arr[start_idx : start_idx + pred.shape[0]] = pred start_idx = start_idx + pred.shape[0] @@ -169,18 +175,17 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) - assert mu1.shape == mu2.shape, \ - 'Training and test mean vectors have different lengths' - assert sigma1.shape == sigma2.shape, \ - 'Training and test covariances have different dimensions' + assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' diff = mu1 - mu2 # Product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): - msg = ('fid calculation produces singular product; ' - 'adding %s to diagonal of cov estimates') % eps + msg = ( + 'fid calculation produces singular product; ' 'adding %s to diagonal of cov estimates' + ) % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) @@ -189,17 +194,17 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) - raise ValueError('Imaginary component {}'.format(m)) + raise ValueError(f'Imaginary component {m}') covmean = covmean.real tr_covmean = np.trace(covmean) - return (diff.dot(diff) + np.trace(sigma1) - + np.trace(sigma2) - 2 * tr_covmean) + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean -def calculate_activation_statistics(files, model, batch_size=50, dims=2048, - device='cpu', num_workers=1): +def calculate_activation_statistics( + files, model, batch_size=50, dims=2048, device='cpu', num_workers=1 +): """Calculation of the statistics used by the FID. 
Params: -- files : List of image files paths @@ -223,17 +228,14 @@ def calculate_activation_statistics(files, model, batch_size=50, dims=2048, return mu, sigma -def compute_statistics_of_path(path, model, batch_size, dims, device, - num_workers=1): +def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1): if path.endswith('.npz'): with np.load(path) as f: m, s = f['mu'][:], f['sigma'][:] else: path = pathlib.Path(path) - files = sorted([file for ext in IMAGE_EXTENSIONS - for file in path.glob('*.{}'.format(ext))]) - m, s = calculate_activation_statistics(files, model, batch_size, - dims, device, num_workers) + files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob(f'*.{ext}')]) + m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers) return m, s @@ -248,10 +250,8 @@ def calculate_fid_given_paths(paths, batch_size, device, dims=2048, num_workers= model = InceptionV3([block_idx]).to(device) - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device, num_workers) - m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, - dims, device, num_workers) + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers) + m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims, device, num_workers) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value @@ -278,12 +278,9 @@ def main(): else: num_workers = args.num_workers - - fid_value = calculate_fid_given_paths(args.path, - args.batch_size, - device, - args.dims, - num_workers) + fid_value = calculate_fid_given_paths( + args.path, args.batch_size, device, args.dims, num_workers + ) print('FID: ', fid_value) diff --git a/utils/metrics/pytorch_fid/inception.py b/utils/metrics/pytorch_fid/inception.py index cc56870..8aec4e0 100644 --- a/utils/metrics/pytorch_fid/inception.py +++ b/utils/metrics/pytorch_fid/inception.py @@ -22,18 +22,20 @@ class InceptionV3(nn.Module): # Maps feature dimensionality to their output blocks indices BLOCK_INDEX_BY_DIM = { - 64: 0, # First max pooling features + 64: 0, # First max pooling features 192: 1, # Second max pooling featurs 768: 2, # Pre-aux classifier features - 2048: 3 # Final average pooling features + 2048: 3, # Final average pooling features } - def __init__(self, - output_blocks=(DEFAULT_BLOCK_INDEX,), - resize_input=True, - normalize_input=True, - requires_grad=False, - use_fid_inception=True): + def __init__( + self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True, + ): """Build pretrained InceptionV3 Parameters @@ -64,15 +66,14 @@ def __init__(self, strongly advised to set this parameter to true to get comparable results. 
""" - super(InceptionV3, self).__init__() + super().__init__() self.resize_input = resize_input self.normalize_input = normalize_input self.output_blocks = sorted(output_blocks) self.last_needed_block = max(output_blocks) - assert self.last_needed_block <= 3, \ - 'Last possible output block index is 3' + assert self.last_needed_block <= 3, 'Last possible output block index is 3' self.blocks = nn.ModuleList() @@ -86,7 +87,7 @@ def __init__(self, inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) + nn.MaxPool2d(kernel_size=3, stride=2), ] self.blocks.append(nn.Sequential(*block0)) @@ -95,7 +96,7 @@ def __init__(self, block1 = [ inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) + nn.MaxPool2d(kernel_size=3, stride=2), ] self.blocks.append(nn.Sequential(*block1)) @@ -119,7 +120,7 @@ def __init__(self, inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, - nn.AdaptiveAvgPool2d(output_size=(1, 1)) + nn.AdaptiveAvgPool2d(output_size=(1, 1)), ] self.blocks.append(nn.Sequential(*block3)) @@ -144,10 +145,7 @@ def forward(self, inp): x = inp if self.resize_input: - x = F.interpolate(x, - size=(299, 299), - mode='bilinear', - align_corners=False) + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) if self.normalize_input: x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) @@ -203,9 +201,7 @@ def fid_inception_v3(): This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model. """ - inception = _inception_v3(num_classes=1008, - aux_logits=False, - weights=None) + inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None) inception.Mixed_5b = FIDInceptionA(192, pool_features=32) inception.Mixed_5c = FIDInceptionA(256, pool_features=64) inception.Mixed_5d = FIDInceptionA(288, pool_features=64) @@ -223,8 +219,9 @@ def fid_inception_v3(): class FIDInceptionA(torchvision.models.inception.InceptionA): """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): - super(FIDInceptionA, self).__init__(in_channels, pool_features) + super().__init__(in_channels, pool_features) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -238,8 +235,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] @@ -248,8 +244,9 @@ def forward(self, x): class FIDInceptionC(torchvision.models.inception.InceptionC): """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): - super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + super().__init__(in_channels, channels_7x7) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -266,8 +263,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] @@ -276,8 
+272,9 @@ def forward(self, x): class FIDInceptionE_1(torchvision.models.inception.InceptionE): """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): - super(FIDInceptionE_1, self).__init__(in_channels) + super().__init__(in_channels) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -299,8 +296,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] @@ -309,8 +305,9 @@ def forward(self, x): class FIDInceptionE_2(torchvision.models.inception.InceptionE): """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): - super(FIDInceptionE_2, self).__init__(in_channels) + super().__init__(in_channels) def forward(self, x): branch1x1 = self.branch1x1(x) @@ -338,4 +335,4 @@ def forward(self, x): branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] - return torch.cat(outputs, 1) \ No newline at end of file + return torch.cat(outputs, 1)
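
Reviewer note: the two public entry points touched by this diff keep their existing signatures, so callers should not need changes. A minimal usage sketch follows; the directory names and the `utils.metrics.*` import paths are assumptions for illustration, not part of the patch.

```python
# Minimal sketch of calling the refactored metric entry points.
# Assumptions (not in the patch): the package imports as utils.metrics.*, and
# both directories contain identically named RGB images, the layout that
# calculate_lpips_given_paths iterates over.
import torch

from utils.metrics.lpips.lpips import calculate_lpips_given_paths
from utils.metrics.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
paths = ['results/generated', 'results/reference']  # hypothetical directories

# LPIPS: each same-named image pair is loaded with PIL, rescaled to [-1, 1],
# scored, and the per-pair distances are averaged.
lpips_value = calculate_lpips_given_paths(paths, device, backbone='vgg', version='0.1')

# FID: per-directory Inception pool3 statistics (mu, sigma), then the Frechet distance.
fid_value = calculate_fid_given_paths(paths, batch_size=50, device=device, dims=2048, num_workers=4)

print(f'LPIPS: {lpips_value:.4f}  FID: {fid_value:.2f}')
```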
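
The `normalize` flag documented in `LPIPS.forward` can also be exercised directly on tensors already in [0, 1]. A hedged sketch, with illustrative shapes, assuming the bundled `weights/v0.1/*.pth` files are present since the constructor loads them by default:

```python
# Direct use of the LPIPS module; normalize=True rescales [0, 1] inputs to the
# [-1, 1] range the trunk networks expect (see the comment in forward()).
import torch
from utils.metrics.lpips.lpips import LPIPS

loss_fn = LPIPS(net='alex', version='0.1')   # loads the bundled linear-layer weights

img0 = torch.rand(1, 3, 64, 64)              # illustrative NCHW inputs in [0, 1]
img1 = torch.rand(1, 3, 64, 64)

with torch.no_grad():
    dist = loss_fn(img0, img1, normalize=True)   # shape (1, 1, 1, 1)
print(dist.item())
```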
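
For context on `calculate_frechet_distance` (only reformatted here): it returns the squared Frechet distance between two Gaussians fitted to the activations, d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)). A toy self-check of that formula, assuming the SciPy dependency this diff pins in requirements.txt:

```python
# Toy restatement of the statistic computed by calculate_frechet_distance:
#   d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
# The production function additionally guards against near-singular products
# (eps offset on the diagonal) and small imaginary parts from sqrtm.
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
acts1 = rng.normal(size=(1000, 8))            # stand-ins for Inception features
acts2 = rng.normal(loc=0.5, size=(1000, 8))

mu1, sigma1 = acts1.mean(axis=0), np.cov(acts1, rowvar=False)
mu2, sigma2 = acts2.mean(axis=0), np.cov(acts2, rowvar=False)

covmean = linalg.sqrtm(sigma1 @ sigma2).real
diff = mu1 - mu2
fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
print(round(float(fid), 4))
```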