Skip to content

Commit

Permalink
Added FID calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Dec 20, 2019
1 parent 7069936 commit 5f60e58
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 12 deletions.
106 changes: 106 additions & 0 deletions calc_inception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import argparse
import pickle
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception3
import numpy as np
from tqdm import tqdm

from dataset import MultiResolutionDataset


class Inception3Feature(Inception3):
    """Inception v3 whose forward pass yields pooled convolutional features.

    Only ``forward`` is overridden: it runs the convolutional stem and the
    ``Mixed_*`` blocks, average-pools the final feature map and returns one
    flat vector per image. No classifier layer is applied in this forward.
    """

    def forward(self, x):
        """Return a ``(batch, channels)`` feature tensor for image batch ``x``.

        Dimensions 2 and 3 of ``x`` are treated as the spatial size; anything
        that is not already 299 x 299 is bilinearly resized first. The shape
        notes below repeat the original annotations (height x width x
        channels of the tensor *entering* each stage for a 299 x 299 input).
        """
        already_sized = x.shape[2] == 299 and x.shape[3] == 299
        if not already_sized:
            x = F.interpolate(
                x, size=(299, 299), mode='bilinear', align_corners=True
            )

        # Stem, part 1: 299 x 299 x 3 -> 149 x 149 x 32 -> 147 x 147 x 32,
        # then a stride-2 max pool on the 147 x 147 x 64 map.
        for layer in (self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Stem, part 2: 73 x 73 x 64 -> 73 x 73 x 80,
        # then a stride-2 max pool on the 71 x 71 x 192 map.
        for layer in (self.Conv2d_3b_1x1, self.Conv2d_4a_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Mixed_5*: 35 x 35 maps (192 -> 256 -> 288 channels).
        for layer in (self.Mixed_5b, self.Mixed_5c, self.Mixed_5d):
            x = layer(x)

        # Mixed_6*: reduces 35 x 35 x 288 to 17 x 17 x 768 and stays there.
        for layer in (
            self.Mixed_6a,
            self.Mixed_6b,
            self.Mixed_6c,
            self.Mixed_6d,
            self.Mixed_6e,
        ):
            x = layer(x)

        # Mixed_7*: reduces 17 x 17 x 768 to 8 x 8 (1280 -> 2048 channels).
        for layer in (self.Mixed_7a, self.Mixed_7b, self.Mixed_7c):
            x = layer(x)

        # Collapse the 8 x 8 map to 1 x 1, then drop the spatial axes.
        pooled = F.avg_pool2d(x, kernel_size=8)
        n_images, n_channels = pooled.shape[0], pooled.shape[1]

        return pooled.view(n_images, n_channels)


def load_patched_inception_v3():
    """Build an ``Inception3Feature`` that carries the pretrained weights.

    The stock ``inception_v3(pretrained=True)`` model is instantiated only to
    obtain its ``state_dict``, which is then loaded into a freshly constructed
    ``Inception3Feature``.

    Returns:
        The ``Inception3Feature`` instance with the pretrained parameters.
    """
    pretrained_state = inception_v3(pretrained=True).state_dict()

    feature_net = Inception3Feature()
    feature_net.load_state_dict(pretrained_state)

    return feature_net


@torch.no_grad()
def extract_features(loader, inception, device):
    """Run ``inception`` over every batch of ``loader`` and stack the outputs.

    Args:
        loader: iterable yielding image batches (each must support ``.to``).
        inception: callable mapping a batch on ``device`` to a feature tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        All per-batch features, moved to the CPU and concatenated along
        dimension 0. Gradients are disabled for the whole call.
    """
    per_batch = [inception(img.to(device)).to('cpu') for img in tqdm(loader)]

    return torch.cat(per_batch, 0)


if __name__ == '__main__':
    # Script entry point: extract features for a whole dataset with the
    # patched Inception network and pickle their mean and covariance.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(
        description='Calculate Inception v3 features for datasets'
    )
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=64, type=int, help='batch size')
    parser.add_argument('path', metavar='PATH', help='path to datset lmdb file')
    args = parser.parse_args()

    # Wrap for multi-GPU use, switch to eval mode, then move to the device.
    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    # ToTensor followed by per-channel (x - 0.5) / 0.5 normalisation.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    transform = transforms.Compose([to_tensor, normalize])

    dset = MultiResolutionDataset(
        args.path, transform=transform, resolution=args.size
    )
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    print(f'extracted {features.shape[0]} features')

    # Statistics over the sample axis: rows are samples, columns are features.
    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    # The output file is named after the dataset path's base name, minus its
    # extension, and is written to the current working directory.
    name, _ = os.path.splitext(os.path.basename(args.path))
    stats = {'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}

    with open(f'inception_{name}.pkl', 'wb') as f:
        pickle.dump(stats, f)
105 changes: 105 additions & 0 deletions fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import argparse
import pickle

import torch
import numpy as np
from scipy import linalg
from tqdm import tqdm

from model import Generator
from calc_inception import load_patched_inception_v3


@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    """Generate ``n_sample`` images and return their Inception features.

    Args:
        generator: callable invoked as
            ``generator([latent], truncation=..., truncation_latent=...)``;
            the first element of its return value is used as the image batch.
        inception: callable mapping an image batch to a feature tensor.
        truncation: forwarded unchanged to ``generator``.
        truncation_latent: forwarded unchanged to ``generator``.
        batch_size: number of latents drawn per generator call.
        n_sample: total number of samples to generate.
        device: device on which the latent noise is created.

    Returns:
        Features of all generated images, moved to the CPU and concatenated
        along dimension 0. Gradients are disabled for the whole call.
    """
    n_batch = n_sample // batch_size
    resid = n_sample - (n_batch * batch_size)

    batch_sizes = [batch_size] * n_batch
    # Append the remainder only when it is non-empty, so that no zero-sized
    # batch is pushed through the generator when n_sample divides evenly.
    # If there are no full batches at all, the remainder is kept as before.
    if resid > 0 or not batch_sizes:
        batch_sizes.append(resid)

    # Use the generator's own latent width when it exposes one; 512 is the
    # previously hard-coded value and remains the fallback.
    latent_dim = getattr(generator, 'style_dim', 512)

    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, latent_dim, device=device)
        # Call the ``generator`` argument, not a module-level global.
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)
        features.append(feat.to('cpu'))

    features = torch.cat(features, 0)

    return features


def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Compute the Frechet distance between two sets of Gaussian statistics.

    The value returned is
    ``|sample_mean - real_mean|^2 + Tr(sample_cov) + Tr(real_cov)
    - 2 * Tr(sqrtm(sample_cov @ real_cov))``.

    Args:
        sample_mean: mean vector of the generated-sample features.
        sample_cov: covariance matrix of the generated-sample features.
        real_mean: mean vector of the reference features.
        real_cov: covariance matrix of the reference features.
        eps: diagonal offset added to both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The distance as a scalar.

    Raises:
        ValueError: if the matrix square root has a diagonal whose imaginary
            part is not close to zero (``atol=1e-3``).
    """
    # Call sqrtm without ``disp``: the ``disp=False`` tuple return is
    # deprecated in recent SciPy, while the plain call returns just the
    # matrix on old and new releases alike. Older releases may print their
    # own warning here for ill-conditioned input.
    cov_sqrt = linalg.sqrtm(sample_cov @ real_cov)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        # Retry with both covariances nudged away from singularity.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # Small imaginary parts are treated as numerical noise and dropped;
        # a clearly complex diagonal is reported as an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))

            raise ValueError(f'Imaginary component {m}')

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid


if __name__ == '__main__':
    # Script entry point: sample images from a generator checkpoint, extract
    # Inception features and compare them with precomputed reference stats.

    # Use CUDA when present and fall back to the CPU otherwise, the same way
    # calc_inception.py picks its device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()

    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--n_sample', type=int, default=5000)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--inception', type=str, default=None, required=True)
    parser.add_argument('ckpt', metavar='CHECKPOINT')

    args = parser.parse_args()

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    ckpt = torch.load(args.ckpt, map_location=device)

    # The module-level name ``g`` is kept as-is; other code in this file may
    # refer to it.
    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g.eval()

    # A mean latent is only computed when truncation is below 1.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = load_patched_inception_v3().to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f'extracted {features.shape[0]} features')

    # Statistics over the sample axis: rows are samples, columns are features.
    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # NOTE(security): pickle.load can execute arbitrary code; only pass
    # --inception files that come from a trusted source.
    with open(args.inception, 'rb') as f:
        embeds = pickle.load(f)
        real_mean = embeds['mean']
        real_cov = embeds['cov']

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print('fid:', fid)
Binary file added inception_ffhq.pkl
Binary file not shown.
28 changes: 16 additions & 12 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def __init__(
lr_mlp=0.01,
):
super().__init__()

self.style_dim = style_dim

layers = [PixelNorm()]
Expand Down Expand Up @@ -425,24 +425,28 @@ def __init__(
in_channel = out_channel

self.n_latent = log_size * 2 - 2



def mean_latent(self, n_latent):
    """Return the mean of ``self.style`` applied to random latents.

    Draws ``n_latent`` standard-normal vectors of width ``self.style_dim`` on
    the device of ``self.input.input``, passes them through ``self.style`` and
    averages over dimension 0, keeping that dimension (size 1) in the result.
    """
    # Draw the noise once: the diff view showed both the removed and the
    # added form of this assignment, which would sample twice and throw the
    # first draw away.
    latent_in = torch.randn(
        n_latent, self.style_dim, device=self.input.input.device
    )
    latent = self.style(latent_in).mean(0, keepdim=True)

    return latent


def forward(self, styles, return_latents=False, truncation=0, truncation_latent=None):
def forward(
self, styles, return_latents=False, truncation=0, truncation_latent=None
):
styles = [self.style(s) for s in styles]
if truncation > 0:

if truncation < 1:
style_t = []

for style in styles:
style_t.append(style + truncation * (truncation_latent - style))

style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)

styles = style_t

if len(styles) < 2:
Expand Down

0 comments on commit 5f60e58

Please sign in to comment.