forked from rosinality/stylegan2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7069936
commit 5f60e58
Showing
4 changed files
with
227 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse | ||
import pickle | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
from torchvision.models import inception_v3, Inception3 | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from dataset import MultiResolutionDataset | ||
|
||
|
||
class Inception3Feature(Inception3): | ||
def forward(self, x): | ||
if x.shape[2] != 299 or x.shape[3] != 299: | ||
x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) | ||
|
||
x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 | ||
x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 | ||
x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 | ||
x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 | ||
|
||
x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 | ||
x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 | ||
x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 | ||
|
||
x = self.Mixed_5b(x) # 35 x 35 x 192 | ||
x = self.Mixed_5c(x) # 35 x 35 x 256 | ||
x = self.Mixed_5d(x) # 35 x 35 x 288 | ||
|
||
x = self.Mixed_6a(x) # 35 x 35 x 288 | ||
x = self.Mixed_6b(x) # 17 x 17 x 768 | ||
x = self.Mixed_6c(x) # 17 x 17 x 768 | ||
x = self.Mixed_6d(x) # 17 x 17 x 768 | ||
x = self.Mixed_6e(x) # 17 x 17 x 768 | ||
|
||
x = self.Mixed_7a(x) # 17 x 17 x 768 | ||
x = self.Mixed_7b(x) # 8 x 8 x 1280 | ||
x = self.Mixed_7c(x) # 8 x 8 x 2048 | ||
|
||
x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 | ||
|
||
return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 | ||
|
||
|
||
def load_patched_inception_v3(): | ||
inception = inception_v3(pretrained=True) | ||
inception_feat = Inception3Feature() | ||
inception_feat.load_state_dict(inception.state_dict()) | ||
|
||
return inception_feat | ||
|
||
|
||
@torch.no_grad() | ||
def extract_features(loader, inception, device): | ||
pbar = tqdm(loader) | ||
|
||
feature_list = [] | ||
|
||
for img in pbar: | ||
img = img.to(device) | ||
feature = inception(img) | ||
feature_list.append(feature.to('cpu')) | ||
|
||
features = torch.cat(feature_list, 0) | ||
|
||
return features | ||
|
||
|
||
if __name__ == '__main__': | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
parser = argparse.ArgumentParser( | ||
description='Calculate Inception v3 features for datasets' | ||
) | ||
parser.add_argument('--size', type=int, default=256) | ||
parser.add_argument('--batch', default=64, type=int, help='batch size') | ||
parser.add_argument('path', metavar='PATH', help='path to datset lmdb file') | ||
|
||
args = parser.parse_args() | ||
|
||
inception = load_patched_inception_v3() | ||
inception = nn.DataParallel(inception).eval().to(device) | ||
|
||
transform = transforms.Compose( | ||
[transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])] | ||
) | ||
|
||
dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) | ||
loader = DataLoader(dset, batch_size=args.batch, num_workers=4) | ||
|
||
features = extract_features(loader, inception, device).numpy() | ||
|
||
print(f'extracted {features.shape[0]} features') | ||
|
||
mean = np.mean(features, 0) | ||
cov = np.cov(features, rowvar=False) | ||
|
||
name = os.path.splitext(os.path.basename(args.path))[0] | ||
|
||
with open(f'inception_{name}.pkl', 'wb') as f: | ||
pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import argparse | ||
import pickle | ||
|
||
import torch | ||
import numpy as np | ||
from scipy import linalg | ||
from tqdm import tqdm | ||
|
||
from model import Generator | ||
from calc_inception import load_patched_inception_v3 | ||
|
||
|
||
@torch.no_grad() | ||
def extract_feature_from_samples( | ||
generator, inception, truncation, truncation_latent, batch_size, n_sample, device | ||
): | ||
n_batch = n_sample // batch_size | ||
resid = n_sample - (n_batch * batch_size) | ||
batch_sizes = [batch_size] * n_batch + [resid] | ||
features = [] | ||
|
||
for batch in tqdm(batch_sizes): | ||
latent = torch.randn(batch, 512, device=device) | ||
img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) | ||
feat = inception(img) | ||
features.append(feat.to('cpu')) | ||
|
||
features = torch.cat(features, 0) | ||
|
||
return features | ||
|
||
|
||
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): | ||
cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) | ||
|
||
if not np.isfinite(cov_sqrt).all(): | ||
print('product of cov matrices is singular') | ||
offset = np.eye(sample_cov.shape[0]) * eps | ||
cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) | ||
|
||
if np.iscomplexobj(cov_sqrt): | ||
if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): | ||
m = np.max(np.abs(cov_sqrt.imag)) | ||
|
||
raise ValueError(f'Imaginary component {m}') | ||
|
||
cov_sqrt = cov_sqrt.real | ||
|
||
mean_diff = sample_mean - real_mean | ||
mean_norm = mean_diff @ mean_diff | ||
|
||
trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) | ||
|
||
fid = mean_norm + trace | ||
|
||
return fid | ||
|
||
|
||
if __name__ == '__main__': | ||
device = 'cuda' | ||
|
||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--truncation', type=float, default=1) | ||
parser.add_argument('--truncation_mean', type=int, default=4096) | ||
parser.add_argument('--batch', type=int, default=64) | ||
parser.add_argument('--n_sample', type=int, default=5000) | ||
parser.add_argument('--size', type=int, default=256) | ||
parser.add_argument('--inception', type=str, default=None, required=True) | ||
parser.add_argument('ckpt', metavar='CHECKPOINT') | ||
|
||
args = parser.parse_args() | ||
|
||
ckpt = torch.load(args.ckpt) | ||
|
||
g = Generator(args.size, 512, 8).to(device) | ||
g.load_state_dict(ckpt['g_ema']) | ||
g.eval() | ||
|
||
if args.truncation < 1: | ||
with torch.no_grad(): | ||
mean_latent = g.mean_latent(args.truncation_mean) | ||
|
||
else: | ||
mean_latent = None | ||
|
||
inception = load_patched_inception_v3().to(device) | ||
inception.eval() | ||
|
||
features = extract_feature_from_samples( | ||
g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device | ||
).numpy() | ||
print(f'extracted {features.shape[0]} features') | ||
|
||
sample_mean = np.mean(features, 0) | ||
sample_cov = np.cov(features, rowvar=False) | ||
|
||
with open(args.inception, 'rb') as f: | ||
embeds = pickle.load(f) | ||
real_mean = embeds['mean'] | ||
real_cov = embeds['cov'] | ||
|
||
fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) | ||
|
||
print('fid:', fid) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters