diff --git a/README.md b/README.md
index 838a934..98199a7 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ python -m var_sep.main
 Preprocessing scripts are located in the `var_sep/preprocessing` folder for the WaveEq, WaveEq-100 and Moving MNIST datasets:
 - `var_sep.preprocessing.mnist.make_test_set` creates the Moving MNIST testing set;
+- `var_sep.preprocessing.chairs.gen_chairs` creates the 64x64 images used by the model from the original dataset, which can be downloaded at [https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar](https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar);
 - `var_sep.preprocessing.wave.gen_wave` generates the WaveEq dataset;
 - `var_sep.preprocessing.wave.gen_pixels` chooses pixels to draw from the WaveEq dataset to create the WaveEq-100 dataset.
@@ -45,6 +46,7 @@ which lists options and hyperparameters to train our model.
 Evaluation scripts on testing sets are located in the `var_sep/test` folder:
 - `var_sep.test.mnist.test` evaluates the prediction PSNR and SSIM of the model on Moving MNIST;
 - `var_sep.test.mnist.test_disentanglement` evaluates the disentanglement PSNR and SSIM of the model by swapping contents and digits on Moving MNIST;
+- `var_sep.test.chairs.test_disentanglement` evaluates the disentanglement PSNR and SSIM of the model by swapping contents and chairs on 3D Warehouse Chairs;
 - `var_sep.test.sst.test` computes the prediction MSE of the model after 6 and 10 prediction steps on SST;
 - `var_sep.test.wave.test` computes the prediction MSE of the model after 40 prediction steps on WaveEq and WaveEq-100;
 Please refer to the corresponding help messages for further information.
diff --git a/var_sep/data/chairs.py b/var_sep/data/chairs.py
new file mode 100644
index 0000000..ec2767f
--- /dev/null
+++ b/var_sep/data/chairs.py
@@ -0,0 +1,72 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+import torch
+
+from PIL import Image
+
+
+class Chairs(object):
+    # Each chair in the rendered dataset comes with 62 views.
+    max_length = 62
+
+    def __init__(self, train, data_root, nt_cond, seq_len=15, image_size=64):
+        self.train = train
+        self.nt_cond = nt_cond
+        assert seq_len <= self.max_length
+        self.seq_len = seq_len
+        assert image_size == 64
+        self.image_size = image_size
+        self.data_root = os.path.join(data_root, 'rendered_chairs')
+        self.sequences = sorted(os.listdir(self.data_root))
+        self.sequences.remove('all_chair_names.mat')
+        # Shuffle chairs with a fixed seed, then split them 85% / 15% between
+        # the training and testing sets.
+        rng = np.random.RandomState(42)
+        rng.shuffle(self.sequences)
+        if self.train:
+            self.start_idx = 0
+            self.stop_idx = int(len(self.sequences) * 0.85)
+        else:
+            self.start_idx = int(len(self.sequences) * 0.85)
+            self.stop_idx = len(self.sequences)
+
+    def get_sequence(self, index, chosen_idx=None, chosen_id_st=None):
+        # Decompose the flat index into a chair index and a starting view.
+        index, idx = divmod(index, self.stop_idx - self.start_idx)
+        if chosen_idx is not None:
+            idx = chosen_idx
+        obj_dir = self.sequences[self.start_idx + idx]
+        dname = os.path.join(self.data_root, obj_dir)
+        index, id_st = divmod(index, self.max_length)
+        if chosen_id_st is not None:
+            id_st = chosen_id_st
+        assert index == 0
+        sequence = []
+        # View indices wrap around the available renders.
+        for i in range(id_st, id_st + self.seq_len):
+            fname = os.path.join(dname, 'renders', f'{i % self.max_length}.png')
+            sequence.append(np.array(Image.open(fname)))
+        sequence = np.array(sequence)
+        return sequence
+
+    def __getitem__(self, index):
+        sequence = torch.tensor(self.get_sequence(index) / 255).permute(0, 3, 1, 2).float()
+        return sequence[:self.nt_cond], sequence[self.nt_cond:]
+
+    def __len__(self):
+        return self.max_length * (self.stop_idx - self.start_idx)
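The flat indexing in `Chairs.get_sequence` is the subtle part of this file: `__len__` enumerates every (chair, starting view) pair, and two `divmod` calls undo that enumeration. A minimal standalone sketch, where `n_chairs` is an illustrative stand-in for `stop_idx - start_idx`:

```python
# Minimal illustration of the index decomposition used by Chairs.get_sequence.
n_chairs = 10    # hypothetical stand-in for stop_idx - start_idx
max_length = 62  # rendered views per chair, as in the Chairs class

def decompose(index):
    # First divmod: which chair does this sample come from?
    index, chair_idx = divmod(index, n_chairs)
    # Second divmod: at which view does the sequence start?
    index, start_view = divmod(index, max_length)
    assert index == 0  # holds for any index < n_chairs * max_length
    return chair_idx, start_view

print(decompose(0))                          # (0, 0)
print(decompose(n_chairs))                   # (0, 1): same chair, next view
print(decompose(n_chairs * max_length - 1))  # (9, 61): last chair, last view
```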
diff --git a/var_sep/data/sst.py b/var_sep/data/sst.py
index 8614dca..8382886 100644
--- a/var_sep/data/sst.py
+++ b/var_sep/data/sst.py
@@ -94,6 +94,6 @@ def __getitem__(self, idx):
                                  self.climato[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
             mu_norm, std_norm = (self.cst[file_id][0][idx_id + 1: idx_id + self.pred_h + 1],
                                  self.cst[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
-            return inputs, target, mu_clim, std_clim, mu_norm, std_norm
+            return inputs, target, mu_clim, std_clim, mu_norm, std_norm, file_id
         else:
             return torch.tensor(inputs, dtype=torch.float), torch.tensor(target, dtype=torch.float)
diff --git a/var_sep/main.py b/var_sep/main.py
index 9fead7e..a453733 100644
--- a/var_sep/main.py
+++ b/var_sep/main.py
@@ -25,6 +25,7 @@
 from torch.utils.data import DataLoader
 
 from var_sep.data.moving_mnist import MovingMNIST
+from var_sep.data.chairs import Chairs
 from var_sep.data.sst import SST
 from var_sep.data.wave_eq import WaveEq, WaveEqPartial
 from var_sep.networks.model import SeparableNetwork
@@ -61,6 +62,10 @@
                                  args.n_object, True)
         last_activation = 'sigmoid'
         shape = [1, 64, 64]
+    elif args.data == 'chairs':
+        train_set = Chairs(True, args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred)
+        last_activation = 'sigmoid'
+        shape = [3, 64, 64]
     elif args.data == "sst":
         train_set = SST(args.data_dir, args.nt_cond, args.nt_pred, True, zones=args.zones)
         shape = [1, 64, 64]
@@ -103,9 +108,9 @@ def worker_init_fn(worker_id):
     Et = get_encoder(args.architecture, shape, args.code_size_t, args.enc_hidden_size, args.nt_cond,
                      args.init_encoder, args.gain_encoder).to(device)
 
-    decoder = get_decoder(args.architecture, shape, args.code_size_t, args.code_size_s, last_activation,
-                          args.dec_hidden_size, args.mixing, args.skipco, args.init_encoder,
-                          args.gain_encoder).to(device)
+    decoder = get_decoder(args.architecture if args.decoder_architecture is None else args.decoder_architecture,
+                          shape, args.code_size_t, args.code_size_s, last_activation, args.dec_hidden_size,
+                          args.mixing, args.skipco, args.init_encoder, args.gain_encoder).to(device)
 
     t_resnet = get_resnet(args.code_size_t, args.n_blocks, args.res_hidden_size, args.init_resnet,
                           args.gain_resnet).to(device)
diff --git a/var_sep/networks/conv.py b/var_sep/networks/conv.py
index bb8fa15..724becd 100644
--- a/var_sep/networks/conv.py
+++ b/var_sep/networks/conv.py
@@ -315,3 +315,143 @@ def __init__(self, nc, ny, nf, skip, last_activation, mixing):
             nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False),
         ),
     ])
+
+
+# The following implementation of ResNet18 is taken from DrNet (https://github.com/edenton/drnet-py)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet18(nn.Module):
+
+    def __init__(self, pose_dim, nc=3, out_f=None):
+        block = BasicBlock
+        layers = [2, 2, 2, 2]
+        self.inplanes = 64
+        super(ResNet18, self).__init__()
+        self.conv1 = nn.Conv2d(nc, 64, kernel_size=5, stride=2, padding=3)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.conv_out = nn.Conv2d(512, pose_dim, kernel_size=3)
+        self.bn_out = nn.BatchNorm2d(pose_dim)
+        self.out_function = activation_factory(out_f)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, return_skip=False):
+        # return_skip is accepted for interface compatibility with the other
+        # encoders; ResNet18 does not compute skip connections.
+        # Flatten the condition frames into channels: (batch, time, c, h, w) -> (batch, time * c, h, w).
+        x = x.view(x.size(0), -1, x.size(3), x.size(4))
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.conv_out(x)
+        x = self.out_function(x)
+
+        x = x.view(len(x), -1)
+
+        return x
diff --git a/var_sep/networks/factory.py b/var_sep/networks/factory.py
index 5daf4c2..38eea20 100644
--- a/var_sep/networks/factory.py
+++ b/var_sep/networks/factory.py
@@ -15,7 +15,7 @@
 
 import numpy as np
 
-from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder
+from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder, ResNet18
 from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder
 from var_sep.networks.resnet import MLPResnet
 from var_sep.networks.utils import init_net
@@ -27,6 +27,8 @@ def get_encoder(nn_type, shape, output_size, hidden_size, nt_cond, init_type, in
         encoder = DCGAN64Encoder(nc * nt_cond, output_size, hidden_size)
     elif nn_type == 'vgg':
         encoder = VGG64Encoder(nc * nt_cond, output_size, hidden_size)
+    elif nn_type == 'resnet':
+        encoder = ResNet18(output_size, nc * nt_cond)
     elif nn_type in ['mlp', 'large_mlp']:
         input_size = nt_cond * np.prod(np.array(shape))
         encoder = MLPEncoder(input_size, hidden_size, output_size, 3)
diff --git a/var_sep/options.py b/var_sep/options.py
index a174011..ea1d840 100644
--- a/var_sep/options.py
+++ b/var_sep/options.py
@@ -16,8 +16,9 @@
 
 import argparse
 
-DATASETS = ['wave', 'wave_partial', 'sst', 'mnist']
-ARCH_TYPES = ['dcgan', 'vgg', 'mlp', 'large_mlp']
+DATASETS = ['wave', 'wave_partial', 'sst', 'mnist', 'chairs']
+ARCH_TYPES = ['dcgan', 'vgg', 'resnet', 'mlp', 'large_mlp']
+DECODER_ARCH_TYPES = ['dcgan', 'vgg', 'mlp', 'large_mlp']
 INITIALIZATIONS = ['orthogonal', 'kaiming', 'normal']
 MIXING = ['concat', 'mul']
@@ -54,6 +55,8 @@
                       help='Multiplier of the prediction loss.')
 config_p.add_argument('--architecture', type=str, metavar='ARCH', default='dcgan', choices=ARCH_TYPES,
                       help='Encoder and decoder architecture.')
+config_p.add_argument('--decoder_architecture', type=str, metavar='ARCH', default=None, choices=DECODER_ARCH_TYPES,
+                      help='If provided, overrides the architecture choice for the decoder.')
 config_p.add_argument('--skipco', action='store_true',
                       help='Whether to use skip connections from encoders to decoders.')
 config_p.add_argument('--res_hidden_size', type=int, metavar='SIZE', default=512,
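The new `--decoder_architecture` flag exists mainly so that the ResNet-18 encoder, which has no decoder counterpart (`DECODER_ARCH_TYPES` omits `'resnet'`), can be paired with one of the existing decoders. A minimal sketch of the override logic added to `var_sep/main.py`, using a hypothetical stand-in for the parsed options:

```python
# Hypothetical parsed options: ResNet-18 encoder, DCGAN decoder.
class Args:
    architecture = 'resnet'          # encoder (and default decoder) choice
    decoder_architecture = 'dcgan'   # None would fall back to --architecture

args = Args()
decoder_arch = (args.architecture if args.decoder_architecture is None
                else args.decoder_architecture)
assert decoder_arch == 'dcgan'  # 'resnet' never reaches the decoder factory
```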
diff --git a/var_sep/preprocessing/chairs/gen_chairs.py b/var_sep/preprocessing/chairs/gen_chairs.py
new file mode 100644
index 0000000..ec30f0d
--- /dev/null
+++ b/var_sep/preprocessing/chairs/gen_chairs.py
@@ -0,0 +1,47 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+import tqdm
+
+from PIL import Image
+
+
+def generate(data_dir, image_size):
+    data_dir = os.path.join(data_dir, 'rendered_chairs')
+    sequence_folders = os.listdir(data_dir)
+    sequence_folders.remove('all_chair_names.mat')
+    for sequence_folder in tqdm.tqdm(sequence_folders, ncols=0):
+        sequence_dir = os.path.join(data_dir, sequence_folder, 'renders')
+        for i, image_file in enumerate(sorted(os.listdir(sequence_dir))):
+            # Crop the central 400x400 region of each render before downscaling it.
+            image = Image.open(os.path.join(sequence_dir, image_file)).crop((100, 100, 500, 500))
+            image = image.resize((image_size, image_size), resample=Image.LANCZOS)
+            image.save(os.path.join(sequence_dir, f'{i}.png'))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='3D Warehouse chairs preprocessing',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                        help='Folder where the original dataset is stored.')
+    parser.add_argument('--image_size', type=int, metavar='SIZE', default=64,
+                        help='Width and height of the processed images.')
+    args = parser.parse_args()
+
+    generate(args.data_dir, args.image_size)
diff --git a/var_sep/test/chairs/test_disentanglement.py b/var_sep/test/chairs/test_disentanglement.py
new file mode 100644
index 0000000..cbcb5e6
--- /dev/null
+++ b/var_sep/test/chairs/test_disentanglement.py
@@ -0,0 +1,201 @@
+# Code heavily modified from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights below.
+
+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+import random
+import torch
+import itertools
+
+import numpy as np
+import torch.nn.functional as F
+
+from collections import defaultdict
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from var_sep.data.chairs import Chairs
+from var_sep.utils.helper import DotDict, load_json
+from var_sep.utils.ssim import ssim_loss
+from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder, ResNet18
+from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder
+from var_sep.networks.model import SeparableNetwork
+
+
+def _ssim_wrapper(pred, gt):
+    bsz, nt_pred = pred.shape[0], pred.shape[1]
+    img_shape = pred.shape[2:]
+    ssim = ssim_loss(pred.reshape(bsz * nt_pred, *img_shape), gt.reshape(bsz * nt_pred, *img_shape), max_val=1., reduction='none')
+    return ssim.mean(dim=[2, 3]).view(bsz, nt_pred, img_shape[0])
+
+
+class SwapDataset(Chairs):
+
+    def __init__(self, train, data_root, nt_cond, seq_len=20, image_size=64):
+        super(SwapDataset, self).__init__(train, data_root, nt_cond, seq_len=seq_len, image_size=image_size)
+
+    def __getitem__(self, index):
+        # Both sequences show the same randomly drawn chair; the first one also starts at a random view.
+        idx_content = np.random.randint(self.stop_idx - self.start_idx)
+        id_st_content = np.random.randint(self.max_length - self.seq_len)
+        sequence = torch.tensor(self.get_sequence(index, chosen_idx=idx_content,
+                                                  chosen_id_st=id_st_content) / 255).permute(0, 3, 1, 2).float()
+        sequence_swap = torch.tensor(self.get_sequence(index,
+                                                       chosen_idx=idx_content) / 255).permute(0, 3, 1, 2).float()
+        return (sequence[:self.nt_cond], sequence[self.nt_cond:],
+                sequence_swap[:self.nt_cond].unsqueeze(0), sequence_swap[self.nt_cond:].unsqueeze(0))
+
+
+def load_dataset(args, train=False):
+    return Chairs(train, args.data_dir, args.nt_cond, seq_len=args.nt_cond + args.nt_pred)
+
+
+def build_model(args):
+    Es = torch.load(os.path.join(args.xp_dir, 'ov_Es.pt'), map_location=args.device).to(args.device)
+    Et = torch.load(os.path.join(args.xp_dir, 'ov_Et.pt'), map_location=args.device).to(args.device)
+    t_resnet = torch.load(os.path.join(args.xp_dir, 't_resnet.pt'), map_location=args.device).to(args.device)
+    decoder = torch.load(os.path.join(args.xp_dir, 'decoder.pt'), map_location=args.device).to(args.device)
+    sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco)
+    sep_net.eval()
+    return sep_net
+
+
+def main(args):
+    ##################################################################################################################
+    # Setup
+    ##################################################################################################################
+    # -- Device handling (CPU, GPU)
+    if args.device is None:
+        device = torch.device('cpu')
+    else:
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
+        device = torch.device('cuda:0')
+        torch.cuda.set_device(0)
+    # Seed
+    random.seed(args.test_seed)
+    np.random.seed(args.test_seed)
+    torch.manual_seed(args.test_seed)
+    # Load XP config
+    xp_config = load_json(os.path.join(args.xp_dir, 'params.json'))
+    xp_config.device = device
+    xp_config.data_dir = args.data_dir
+    xp_config.xp_dir = args.xp_dir
+    xp_config.nt_pred = args.nt_pred
+    xp_config.n_object = 1
+
+    ##################################################################################################################
+    # Load test data
+    ##################################################################################################################
+    print('Loading data...')
+    test_dataset = load_dataset(xp_config, train=False)
+    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True)
+    swap_dataset = SwapDataset(False, args.data_dir, xp_config.nt_cond, seq_len=xp_config.nt_cond + args.nt_pred)
+    swap_loader = DataLoader(swap_dataset, batch_size=args.batch_size, pin_memory=True)
+    nc = 3
+    size = 64
+
+    ##################################################################################################################
+    # Load model
+    ##################################################################################################################
+    print('Loading model...')
+    sep_net = build_model(xp_config)
+
+    ##################################################################################################################
+    # Eval
+    ##################################################################################################################
+    print('Generating samples...')
+    torch.set_grad_enabled(False)
+    swap_iterator = iter(swap_loader)
+    nt_test = xp_config.nt_cond + args.nt_pred
+    gt_swap = []
+    content_swap = []
+    cond_swap = []
+    target_swap = []
+    results = defaultdict(list)
+    # Evaluation is done by batch
+    for batch in tqdm(test_loader, ncols=80, desc='evaluation'):
+        # Data
+        x_cond, x_target, _, x_gt_swap = next(swap_iterator)
+        x_gt_swap = x_gt_swap.to(device)
+        x_cond = x_cond.to(device)
+
+        # Extraction of S
+        _, _, s_codes, _ = sep_net.get_forecast(x_cond, nt_test)
+
+        # Content swap
+        x_swap_cond, x_swap_target = batch
+        x_swap_cond = x_swap_cond.to(device)
+        x_swap_target = x_swap_target.to(device)
+        x_swap_cond_byte = x_cond.cpu().mul(255).byte()
+        x_swap_target_byte = x_swap_target.cpu().mul(255).byte()
+        cond_swap.append(x_swap_cond_byte.permute(0, 1, 3, 4, 2))
+        target_swap.append(x_swap_target_byte.permute(0, 1, 3, 4, 2))
+        x_swap_pred = sep_net.get_forecast(x_swap_cond, nt_test, init_s_code=s_codes[:, 0])[0]
+        x_swap_pred = x_swap_pred[:, xp_config.nt_cond:]
+        content_swap.append(x_swap_pred.cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+        gt_swap.append(x_gt_swap[:, 0].cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+
+        # Pixelwise quantitative eval
+        x_gt_swap = x_gt_swap.view(-1, xp_config.n_object, args.nt_pred, nc, size, size).to(device)
+        metrics_batch = {'mse': [], 'psnr': [], 'ssim': []}
+        for j, reordering in enumerate(itertools.permutations(range(xp_config.n_object))):
+            mse = torch.mean(F.mse_loss(x_swap_pred, x_gt_swap[:, j], reduction='none'), dim=[3, 4])
+            metrics_batch['mse'].append(mse.mean(2).mean(1).cpu())
+            metrics_batch['psnr'].append(10 * torch.log10(1 / mse).mean(2).mean(1).cpu())
+            metrics_batch['ssim'].append(_ssim_wrapper(x_swap_pred, x_gt_swap[:, j]).mean(2).mean(1).cpu())
+
+        # Compute metrics for best samples and register
+        results['mse'].append(torch.min(torch.stack(metrics_batch['mse']), 0)[0])
+        results['psnr'].append(torch.max(torch.stack(metrics_batch['psnr']), 0)[0])
+        results['ssim'].append(torch.max(torch.stack(metrics_batch['ssim']), 0)[0])
+
+    ##################################################################################################################
+    # Print results
+    ##################################################################################################################
+    print('\n')
+    print('Results:')
+    for name in results.keys():
+        res = torch.cat(results[name]).numpy()
+        results[name] = res
+        print(name, res.mean(), '+/-', 1.960 * res.std() / np.sqrt(len(res)))
+
+    ##################################################################################################################
+    # Save samples
+    ##################################################################################################################
+    np.savez_compressed(os.path.join(args.xp_dir, 'results_swap.npz'), **results)
+    np.savez_compressed(os.path.join(args.xp_dir, 'content_swap_gt.npz'), gt_swap=torch.cat(gt_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'content_swap_test.npz'), content_swap=torch.cat(content_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'cond_swap_test.npz'), cond_swap=torch.cat(cond_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'target_swap_test.npz'), target_swap=torch.cat(target_swap).numpy())
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (3D Warehouse Chairs content swap testing)")
+    p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the dataset is saved.')
+    p.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the model configuration file and checkpoints are saved.')
+    p.add_argument('--batch_size', type=int, metavar='BATCH', default=16,
+                   help='Batch size used to compute metrics.')
+    p.add_argument('--nt_pred', type=int, metavar='PRED', required=True,
+                   help='Total number of frames to predict.')
+    p.add_argument('--device', type=int, metavar='DEVICE', default=None,
+                   help='GPU where the model should be placed when testing (if None, put it on the CPU).')
+    p.add_argument('--test_seed', type=int, metavar='SEED', default=1,
+                   help='Manual seed.')
+    args = DotDict(vars(p.parse_args()))
+    main(args)
diff --git a/var_sep/test/mnist/test_disentanglement.py b/var_sep/test/mnist/test_disentanglement.py
index c7829cc..304629b 100644
--- a/var_sep/test/mnist/test_disentanglement.py
+++ b/var_sep/test/mnist/test_disentanglement.py
@@ -223,7 +223,6 @@
     np.savez_compressed(os.path.join(args.xp_dir, 'target_swap_test.npz'), target_swap=torch.cat(target_swap).numpy())
 
-
 if __name__ == '__main__':
     p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (Moving MNIST content swap testing)")
     p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
                    help='Directory where the dataset is saved.')
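In the disentanglement tests, metrics are computed against the best-matching object ordering: the prediction obtained after a content swap is compared to every permutation of the ground-truth objects, and the most favorable score is kept. For chairs this loop is a no-op since `n_object` is fixed to 1, but it matters on Moving MNIST, where several digits are swapped. A self-contained sketch of the idea, with random tensors standing in for predictions and ground truth:

```python
import itertools
import torch
import torch.nn.functional as F

n_object, nt_pred, nc, size = 2, 5, 1, 64
x_pred = torch.rand(8, nt_pred, nc, size, size)          # swapped prediction
x_gt = torch.rand(8, n_object, nt_pred, nc, size, size)  # ground truth per ordering

mse_per_perm = []
for j, _ in enumerate(itertools.permutations(range(n_object))):
    # Per-sample MSE against the j-th object ordering.
    mse = F.mse_loss(x_pred, x_gt[:, j], reduction='none').mean(dim=[2, 3, 4])
    mse_per_perm.append(mse.mean(1))
best_mse = torch.stack(mse_per_perm).min(0)[0]  # keep the most favorable ordering
```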
diff --git a/var_sep/test/sst/test.py b/var_sep/test/sst/test.py
index ecc7faf..945a493 100644
--- a/var_sep/test/sst/test.py
+++ b/var_sep/test/sst/test.py
@@ -26,6 +26,22 @@
 from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder
 from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder
 from var_sep.networks.model import SeparableNetwork
+from var_sep.utils.ssim import ssim_loss
+
+
+def _ssim_wrapper(pred, gt):
+    bsz, nt_pred = pred.shape[0], pred.shape[1]
+    img_shape = pred.shape[2:]
+    ssim = ssim_loss(pred.reshape(bsz * nt_pred, *img_shape), gt.reshape(bsz * nt_pred, *img_shape), max_val=1., reduction='none')
+    return ssim.mean(dim=[2, 3]).view(bsz, nt_pred, img_shape[0])
+
+
+def get_min_max(test_set):
+    mins, maxs = {}, {}
+    for zone in test_set.zones:
+        mins[zone] = test_set.data[zone].min()
+        maxs[zone] = test_set.data[zone].max()
+    return mins, maxs
 
 
 def load_dataset(args, train=False, zones=range(17, 21)):
@@ -43,9 +59,11 @@ def build_model(args):
 
 
 def compute_mse(args, test_set, sep_net):
+    mins, maxs = get_min_max(test_set)
     all_mse = []
+    all_ssim = []
     torch.set_grad_enabled(False)
-    for cond, target, mu_clim, std_clim, mu_norm, std_norm in tqdm(test_set):
+    for cond, target, mu_clim, std_clim, mu_norm, std_norm, file_id in tqdm(test_set):
         cond, target = cond.unsqueeze(0).to(args.device), target.unsqueeze(0).to(args.device)
         if args.offset:
             forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, target.size(1) + args.nt_cond)
@@ -59,17 +77,24 @@
             forecasts = (forecasts * std_norm) + mu_norm
             target = (target * std_norm) + mu_norm
 
-        # Original space
+        # Original space for MSE
         mu_clim, std_clim = (torch.tensor(mu_clim, dtype=torch.float).to(args.device),
                              torch.tensor(std_clim, dtype=torch.float).to(args.device))
         forecasts = (forecasts * std_clim) + mu_clim
         target = (target * std_clim) + mu_clim
 
         mse = (forecasts - target).pow(2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
 
+        # Normalize by min and max per zone for SSIM
+        min_, max_ = mins[file_id], maxs[file_id]
+        forecasts = (forecasts - min_) / (max_ - min_)
+        target = (target - min_) / (max_ - min_)
+        ssim = _ssim_wrapper(forecasts, target)
+
         all_mse.append(mse.cpu().numpy())
+        all_ssim.append(ssim.cpu().numpy())
 
-    return all_mse
+    return all_mse, all_ssim
 
 
 def main(args):
@@ -88,10 +113,13 @@ def main(args):
     test_set = load_dataset(xp_config, train=False)
     sep_net = build_model(xp_config)
 
-    all_mse = compute_mse(xp_config, test_set, sep_net)
+    all_mse, all_ssim = compute_mse(xp_config, test_set, sep_net)
     mse_array = np.concatenate(all_mse, axis=0)
+    ssim_array = np.concatenate(all_ssim, axis=0)
     print(f'MSE at t+10: {np.mean(mse_array.mean(axis=0)[:10])}')
     print(f'MSE at t+6: {np.mean(mse_array.mean(axis=0)[:6])}')
+    print(f'SSIM at t+10: {np.mean(ssim_array.mean(axis=0)[:10])}')
+    print(f'SSIM at t+6: {np.mean(ssim_array.mean(axis=0)[:6])}')
 
 
 if __name__ == '__main__':
diff --git a/var_sep/train.py b/var_sep/train.py
index 3e1b92a..b24941e 100644
--- a/var_sep/train.py
+++ b/var_sep/train.py
@@ -33,8 +33,8 @@
 
 def zero_order_loss(s_code_old, s_code_new, skipco):
     if skipco:
-        s_code_old = s_code_old[0]
-        s_code_new = s_code_new[0]
+        s_code_old = torch.cat([s_code_old[0].flatten()] + [x.flatten() for x in s_code_old[1]])
+        s_code_new = torch.cat([s_code_new[0].flatten()] + [x.flatten() for x in s_code_new[1]])
     return (s_code_old - s_code_new).pow(2).mean()
 
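Finally, the `zero_order_loss` change extends the S-invariance penalty to skip-connection activations: instead of penalizing only the first element of the `(code, skips)` pair, both the spatial code and the skip activations are flattened into a single vector before the squared difference is taken. A self-contained restatement of the new branch (tensor shapes below are illustrative):

```python
import torch

def zero_order_loss_skipco(s_code_old, s_code_new):
    # Each argument mimics the (code, skip_list) pair returned by the encoders
    # when skip connections are enabled.
    old = torch.cat([s_code_old[0].flatten()] + [x.flatten() for x in s_code_old[1]])
    new = torch.cat([s_code_new[0].flatten()] + [x.flatten() for x in s_code_new[1]])
    return (old - new).pow(2).mean()

code = torch.randn(8, 64)
skips = [torch.randn(8, 32, 16, 16), torch.randn(8, 16, 32, 32)]
loss = zero_order_loss_skipco((code, skips), (code.clone(), [s.clone() for s in skips]))
assert loss.item() == 0.0  # identical S codes and skips incur no penalty
```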