Commit
Updated code for the new version of the preprint
White-Link committed Oct 7, 2020
1 parent 63215fb commit 87f82e8
Showing 12 changed files with 503 additions and 15 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -25,6 +25,7 @@ python -m var_sep.main

Preprocessing scripts are located in the `var_sep/preprocessing` folder for the WaveEq, WaveEq-100, Moving MNIST and 3D Warehouse Chairs datasets:
- `var_sep.preprocessing.mnist.make_test_set` creates the Moving MNIST testing set;
- `var_sep.preprocessing.chairs.gen_chairs` creates the 64x64 images used by the model from the original dataset, which can be downloaded at [https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar](https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar);
- `var_sep.preprocessing.wave.gen_wave` generates the WaveEq dataset;
- `var_sep.preprocessing.wave.gen_pixels` chooses pixels to draw from the WaveEq dataset to create the WaveEq-100 dataset.

@@ -45,6 +46,7 @@ which lists options and hyperparameters to train our model.
Evaluation scripts on testing sets are located in the `var_sep/test` folder.
- `var_sep.test.mnist.test` evaluates the prediction PSNR and SSIM of the model on Moving MNIST;
- `var_sep.test.mnist.test_disentanglement` evaluates the disentanglement PSNR and SSIM of the model by swapping contents and digits on Moving MNIST;
- `var_sep.test.chairs.test_disentanglement` evaluates the disentanglement PSNR and SSIM of the model by swapping contents and chairs on 3D Warehouse Chairs;
- `var_sep.test.sst.test` computes the prediction MSE of the model after 6 and 10 prediction steps on SST;
- `var_sep.test.wave.test` computes the prediction MSE of the model after 40 prediction steps on WaveEq and WaveEq-100.
Please refer to the corresponding help messages for further information (see the example below).
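
For instance, the options of the new chairs evaluation script can be listed with the following command (a sketch; the actual flags are documented in the script's help message):

python -m var_sep.test.chairs.test_disentanglement --help
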
67 changes: 67 additions & 0 deletions var_sep/data/chairs.py
@@ -0,0 +1,67 @@
# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import numpy as np
import torch

from PIL import Image


class Chairs(object):
max_length = 62

def __init__(self, train, data_root, nt_cond, seq_len=15, image_size=64):
self.train = train
self.nt_cond = nt_cond
assert seq_len <= self.max_length
self.seq_len = seq_len
assert image_size == 64
self.image_size = image_size
self.data_root = os.path.join(data_root, 'rendered_chairs')
self.sequences = sorted(os.listdir(self.data_root))
self.sequences.remove('all_chair_names.mat')
rng = np.random.RandomState(42)
rng.shuffle(self.sequences)
if self.train:
self.start_idx = 0
self.stop_idx = int(len(self.sequences) * 0.85)
else:
self.start_idx = int(len(self.sequences) * 0.85)
self.stop_idx = len(self.sequences)

def get_sequence(self, index, chosen_idx=None, chosen_id_st=None):
index, idx = divmod(index, self.stop_idx - self.start_idx)
if chosen_idx is not None:
idx = chosen_idx
obj_dir = self.sequences[self.start_idx + idx]
dname = os.path.join(self.data_root, obj_dir)
index, id_st = divmod(index, self.max_length)
if chosen_id_st is not None:
id_st = chosen_id_st
assert index == 0
sequence = []
for i in range(id_st, id_st + self.seq_len):
fname = os.path.join(dname, 'renders', f'{i % self.max_length}.png')
sequence.append(np.array(Image.open(fname)))
sequence = np.array(sequence)
return sequence

def __getitem__(self, index):
sequence = torch.tensor(self.get_sequence(index) / 255).permute(0, 3, 1, 2).float()
return sequence[:self.nt_cond], sequence[self.nt_cond:]

def __len__(self):
return (self.max_length) * (self.stop_idx - self.start_idx)
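
The new dataset class can be exercised on its own. A minimal usage sketch (the data path is a placeholder, it assumes the `rendered_chairs` folder has been preprocessed with `var_sep.preprocessing.chairs.gen_chairs`, and `nt_cond=5` is only illustrative):

from torch.utils.data import DataLoader

from var_sep.data.chairs import Chairs

# Training split; sequences are 15 frames long by default.
train_set = Chairs(train=True, data_root='/path/to/data', nt_cond=5)
cond, target = train_set[0]
# Expected shapes: (5, 3, 64, 64) conditioning frames, (10, 3, 64, 64) targets.
print(cond.shape, target.shape)

# Standard PyTorch batching for training.
loader = DataLoader(train_set, batch_size=16, shuffle=True)
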
2 changes: 1 addition & 1 deletion var_sep/data/sst.py
@@ -94,6 +94,6 @@ def __getitem__(self, idx):
self.climato[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
mu_norm, std_norm = (self.cst[file_id][0][idx_id + 1: idx_id + self.pred_h + 1],
self.cst[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
return inputs, target, mu_clim, std_clim, mu_norm, std_norm
return inputs, target, mu_clim, std_clim, mu_norm, std_norm, file_id
else:
return torch.tensor(inputs, dtype=torch.float), torch.tensor(target, dtype=torch.float)
11 changes: 8 additions & 3 deletions var_sep/main.py
@@ -25,6 +25,7 @@
from torch.utils.data import DataLoader

from var_sep.data.moving_mnist import MovingMNIST
from var_sep.data.chairs import Chairs
from var_sep.data.sst import SST
from var_sep.data.wave_eq import WaveEq, WaveEqPartial
from var_sep.networks.model import SeparableNetwork
@@ -61,6 +62,10 @@
args.n_object, True)
last_activation = 'sigmoid'
shape = [1, 64, 64]
elif args.data == 'chairs':
train_set = Chairs(True, args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred)
last_activation = 'sigmoid'
shape = [3, 64, 64]
elif args.data == "sst":
train_set = SST(args.data_dir, args.nt_cond, args.nt_pred, True, zones=args.zones)
shape = [1, 64, 64]
@@ -103,9 +108,9 @@ def worker_init_fn(worker_id):
Et = get_encoder(args.architecture, shape, args.code_size_t, args.enc_hidden_size, args.nt_cond,
args.init_encoder, args.gain_encoder).to(device)

decoder = get_decoder(args.architecture, shape, args.code_size_t, args.code_size_s, last_activation,
args.dec_hidden_size, args.mixing, args.skipco, args.init_encoder,
args.gain_encoder).to(device)
decoder = get_decoder(args.architecture if args.decoder_architecture is None else args.decoder_architecture,
shape, args.code_size_t, args.code_size_s, last_activation, args.dec_hidden_size,
args.mixing, args.skipco, args.init_encoder, args.gain_encoder).to(device)

t_resnet = get_resnet(args.code_size_t, args.n_blocks, args.res_hidden_size, args.init_resnet,
args.gain_resnet).to(device)
137 changes: 137 additions & 0 deletions var_sep/networks/conv.py
@@ -315,3 +315,140 @@ def __init__(self, nc, ny, nf, skip, last_activation, mixing):
nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False),
),
])


# The following implementation of ResNet18 is taken from DrNet (https://github.com/edenton/drnet-py)


def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1)


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual
out = self.relu(out)

return out


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual
out = self.relu(out)

return out


class ResNet18(nn.Module):

def __init__(self, pose_dim, nc=3, out_f=None):
block = BasicBlock
layers = [2, 2, 2, 2, 2]
self.inplanes = 64
super(ResNet18, self).__init__()
self.conv1 = nn.Conv2d(nc, 64, kernel_size=5, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.conv_out = nn.Conv2d(512, pose_dim, kernel_size=3)
self.bn_out = nn.BatchNorm2d(pose_dim)
self.out_function = activation_factory(out_f)

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride),
nn.BatchNorm2d(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x, return_skip=False):
x = x.view(x.size(0), -1, x.size(3), x.size(4))

x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.conv_out(x)
x = self.out_function(x)

x = x.view(len(x), -1)

return x
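
As a sanity check of the new encoder, the flattening at the top of `forward` maps a 5D batch of conditioning frames to a flat code. A sketch (batch size, number of frames and code size below are arbitrary choices, not values from the paper):

import torch

from var_sep.networks.conv import ResNet18

# 8 sequences of 5 RGB 64x64 frames; time and channels are flattened
# together inside forward, hence nc = 3 * 5.
encoder = ResNet18(pose_dim=128, nc=3 * 5)
x = torch.randn(8, 5, 3, 64, 64)
z = encoder(x)
print(z.shape)  # expected: torch.Size([8, 128])
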
4 changes: 3 additions & 1 deletion var_sep/networks/factory.py
@@ -15,7 +15,7 @@

import numpy as np

from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder
from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder, ResNet18
from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder
from var_sep.networks.resnet import MLPResnet
from var_sep.networks.utils import init_net
@@ -27,6 +27,8 @@ def get_encoder(nn_type, shape, output_size, hidden_size, nt_cond, init_type, in
encoder = DCGAN64Encoder(nc * nt_cond, output_size, hidden_size)
elif nn_type == 'vgg':
encoder = VGG64Encoder(nc * nt_cond, output_size, hidden_size)
elif nn_type == 'resnet':
encoder = ResNet18(output_size, nc * nt_cond)
elif nn_type in ['mlp', 'large_mlp']:
input_size = nt_cond * np.prod(np.array(shape))
encoder = MLPEncoder(input_size, hidden_size, output_size, 3)
7 changes: 5 additions & 2 deletions var_sep/options.py
@@ -16,8 +16,9 @@
import argparse


DATASETS = ['wave', 'wave_partial', 'sst', 'mnist']
ARCH_TYPES = ['dcgan', 'vgg', 'mlp', 'large_mlp']
DATASETS = ['wave', 'wave_partial', 'sst', 'mnist', 'chairs']
ARCH_TYPES = ['dcgan', 'vgg', 'resnet', 'mlp', 'large_mlp']
DECODER_ARCH_TYPES = ['dcgan', 'vgg', 'mlp', 'large_mlp']
INITIALIZATIONS = ['orthogonal', 'kaiming', 'normal']
MIXING = ['concat', 'mul']

@@ -54,6 +55,8 @@
help='Multiplier of the prediction loss.')
config_p.add_argument('--architecture', type=str, metavar='ARCH', default='dcgan', choices=ARCH_TYPES,
help='Encoder and decoder architecture.')
config_p.add_argument('--decoder_architecture', type=str, metavar='ARCH', default=None, choices=DECODER_ARCH_TYPES,
help='If not None, overwrite the decoder architecture choice.')
config_p.add_argument('--skipco', action='store_true',
help='Whether to use skip connections from encoders to decoders.')
config_p.add_argument('--res_hidden_size', type=int, metavar='SIZE', default=512,
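
As an illustration of how the new options combine, a training run on the chairs dataset with a ResNet encoder and a DCGAN decoder might look as follows (a sketch: `--data` and `--data_dir` are inferred from the argument names used in `main.py`, the path is a placeholder, and any further required flags listed by `python -m var_sep.main --help` still apply):

python -m var_sep.main --data chairs --data_dir /path/to/data --architecture resnet --decoder_architecture dcgan
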
46 changes: 46 additions & 0 deletions var_sep/preprocessing/chairs/gen_chairs.py
@@ -0,0 +1,46 @@
# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os
import tqdm

from PIL import Image


def generate(data_dir, image_size):
data_dir = os.path.join(data_dir, 'rendered_chairs')
sequence_folders = os.listdir(data_dir)
sequence_folders.remove('all_chair_names.mat')
for sequence_folder in tqdm.tqdm(sequence_folders, ncols=0):
sequence_dir = os.path.join(data_dir, sequence_folder, 'renders')
for i, image_file in enumerate(sorted(os.listdir(sequence_dir))):
image = Image.open(os.path.join(sequence_dir, image_file)).crop((100, 100, 500, 500)).resize((image_size, image_size),
resample=Image.LANCZOS)
image.save(os.path.join(sequence_dir, f'{i}.png'))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='3D Warehouse chairs preprocessing.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--data_dir', type=str, metavar='DIR', required=True,
help='Folder where videos from the original dataset are stored.')
parser.add_argument('--image_size', type=int, metavar='SIZE', default=64,
help='Width and height of resulting processed videos.')
args = parser.parse_args()

generate(args.data_dir, args.image_size)
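
A typical invocation of this script, assuming the archive linked in the README has been extracted under `/path/to/chairs` (the path is a placeholder):

python -m var_sep.preprocessing.chairs.gen_chairs --data_dir /path/to/chairs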