diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..70db2f3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+# PDE-Driven Spatiotemporal Disentanglement
+
+Official implementation of the paper *PDE-Driven Spatiotemporal Disentanglement* (Jérémie Donà,* Jean-Yves Franceschi,* Sylvain Lamprier, Patrick Gallinari).
+
+
+## Requirements
+
+All models were trained with Python 3.8.1 and PyTorch 1.4.0 using CUDA 10.1. The `requirements.txt` file lists Python package dependencies.
+
+We obtained all our models using mixed-precision training with Nvidia's [Apex](https://nvidia.github.io/apex/) (v0.1), which accelerates training on the most recent Nvidia GPU architectures. This optimization can be enabled using the command-line options.
+
+
+## Execution
+
+All scripts should be executed as modules from the root of this folder. For example, the training script can be launched with:
+```bash
+python -m var_sep.main
+```
+
+
+## Datasets
+
+Preprocessing scripts are located in the `var_sep/preprocessing` folder for the WaveEq, WaveEq-100 and Moving MNIST datasets:
+- `var_sep.preprocessing.mnist.make_test_set` creates the Moving MNIST testing set;
+- `var_sep.preprocessing.wave.gen_wave` generates the WaveEq dataset;
+- `var_sep.preprocessing.wave.gen_pixels` chooses pixels to draw from the WaveEq dataset to create the WaveEq-100 dataset.
+
+Regarding SST, we refer the reader to the article in which it was introduced ([https://openreview.net/forum?id=By4HsfWAZ](https://openreview.net/forum?id=By4HsfWAZ)) and its authors, as, to date, we do not own the preprocessing script.
+
+
+## Training
+
+Please refer to the help message of `main.py`:
+```bash
+python -m var_sep.main --help
+```
+which lists options and hyperparameters to train our model (see also the example at the end of this file).
+
+
+## Testing
+
+Evaluation scripts on testing sets are located in the `var_sep/test` folder:
+- `var_sep.test.mnist.test` evaluates the prediction PSNR and SSIM of the model on Moving MNIST;
+- `var_sep.test.mnist.test_disentanglement` evaluates the disentanglement PSNR and SSIM of the model by swapping contents and digits on Moving MNIST;
+- `var_sep.test.sst.test` computes the prediction MSE of the model after 6 and 10 prediction steps on SST;
+- `var_sep.test.wave.test` computes the prediction MSE of the model after 40 prediction steps on WaveEq and WaveEq-100.
+
+Please refer to the corresponding help messages for further information.
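+
+For reference, a hypothetical training invocation is shown below; the exact flag spellings (e.g., `--data`, `--data_dir`, `--xp_dir`) are assumptions and should be checked against `python -m var_sep.main --help`:
+```bash
+python -m var_sep.main --data mnist --data_dir /path/to/datasets --xp_dir /path/to/experiments
+```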
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..54ab1fd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+# Requirements
+
+numpy==1.18.1
+netCDF4==1.5.3
+pyyaml==5.3
+scikit-image==0.16.2
+scipy==1.4.1
+torch==1.4.0
+torchvision==0.5.0
+torchdiffeq==0.0.1
+tqdm==4.43.0
diff --git a/var_sep/data/moving_mnist.py b/var_sep/data/moving_mnist.py
new file mode 100644
index 0000000..2c4669e
--- /dev/null
+++ b/var_sep/data/moving_mnist.py
@@ -0,0 +1,340 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Code adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights below.
+
+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import torch
+
+import numpy as np
+
+from torch.utils.data import Dataset
+from torchvision import datasets
+
+
+class MovingMNIST(Dataset):
+    """
+    Updated stochastic and deterministic MovingMNIST dataset, inspired by
+    https://github.com/edenton/svg/blob/master/data/moving_mnist.py.
+    See the paper for more information.
+    Attributes
+    ----------
+    data : list
+        When testing, list of testing videos represented as uint8 NumPy arrays (length, width, height).
+        When training, list of digit images to use when generating videos from the dataset.
+    frame_size : int
+        Width and height of the video frames.
+    nt_cond : int
+        Number of conditioning frames.
+    seq_len : int
+        Number of frames to produce.
+    max_speed : int
+        Maximum speed of moving digits in the videos.
+    deterministic : bool
+        Whether to use the deterministic version of the dataset rather than the stochastic one.
+    num_digits : int
+        Number of digits in each video.
+    train : bool
+        Whether to use the training or testing dataset.
+    eps : float
+        Precision parameter to compute intersections between trajectories and frame borders.
+    """
+    eps = 1e-8
+
+    def __init__(self, data, nx, nt_cond, seq_len, max_speed, deterministic, num_digits, train):
+        """
+        Parameters
+        ----------
+        data : list
+            When testing, list of testing videos represented as uint8 NumPy arrays (length, width, height).
+            When training, list of digit images to use when generating videos from the dataset.
+        nx : int
+            Width and height of the video frames.
+        nt_cond : int
+            Number of conditioning frames.
+        seq_len : int
+            Number of frames to produce.
+        max_speed : int
+            Maximum speed of moving digits in the videos.
+        deterministic : bool
+            Whether to use the deterministic version of the dataset rather than the stochastic one.
+        num_digits : int
+            Number of digits in each video.
+        train : bool
+            Whether to use the training or testing dataset.
+        """
+        super(MovingMNIST, self).__init__()
+        self.data = data
+        self.frame_size = nx
+        self.nt_cond = nt_cond
+        self.seq_len = seq_len
+        self.max_speed = max_speed
+        self.deterministic = deterministic
+        self.num_digits = num_digits
+        self.train = train
+
+    def __len__(self):
+        if self.train:
+            # Arbitrary number, chosen as a trade-off for efficiency:
+            # if too low, it hurts batching and multi-threaded data loading;
+            # if too high, it hurts shuffling and sampling.
+            return 200000
+        return len(self.data)
+
+    def __getitem__(self, index):
+        if not self.train:
+            # When testing, pick the selected video (from the precomputed testing set)
+            return self.data[index][:self.nt_cond] / 255, self.data[index][self.nt_cond:] / 255
+        # When training, generate videos on the fly
+        x = np.zeros((self.seq_len, 1, self.frame_size, self.frame_size), dtype=np.float32)
+        # Generate the trajectories of each digit independently
+        for n in range(self.num_digits):
+            img = self.data[np.random.randint(len(self.data))]  # Random digit
+            trajectory = self._compute_trajectory(*img.shape)  # Generate digit trajectory
+            for t in range(self.seq_len):
+                sx, sy, _, _ = trajectory[t]
+                # Add the digit to the video at its current position
+                x[t, 0, sx:sx + img.shape[0], sy:sy + img.shape[1]] += img
+        # In case of overlap, clip video values back to [0, 255]
+        x[x > 255] = 255
+        x = x / 255
+        return torch.tensor(x[:self.nt_cond]), torch.tensor(x[self.nt_cond:])
+
+    def _compute_trajectory(self, nx, ny, init_cond=None):
+        """
+        Creates a trajectory.
+        Parameters
+        ----------
+        nx : int
+            Width of digit image.
+        ny : int
+            Height of digit image.
+        init_cond : tuple
+            Optional initial condition for the generated trajectory. It is a tuple of integers (posx, posy, dx, dy)
+            where posx and posy are the initial coordinates, and dx and dy form the initial speed vector.
+        Returns
+        -------
+        list
+            List of tuples (posx, posy, dx, dy) describing the evolution of the position and speed of the moving
+            object. Positions refer to the lower left corner of the object.
+ """ + x = [] # Trajectory + x_max = self.frame_size - nx # Maximum x coordinate allowed + y_max = self.frame_size - ny # Maximum y coordinate allowed + # Process or create the initial position and speed + if init_cond is None: + sx = np.random.randint(0, x_max + 1) + sy = np.random.randint(0, y_max + 1) + dx = np.random.randint(-self.max_speed, self.max_speed + 1) + dy = np.random.randint(-self.max_speed, self.max_speed + 1) + else: + sx, sy, dx, dy = init_cond + # Create the trajectory + for t in range(self.seq_len): + # After the movement of a timestep is applied, update the position and speed to take into account + # collisions with frame borders + sx, sy, dx, dy = self._process_collision(sx, sy, dx, dy, x_min=0, x_max=x_max, y_min=0, y_max=y_max) + # Add rounded position and speed to the trajectory + x.append([int(round(sx)), int(round(sy)), dx, dy]) + # Keep computing the trajectory with exact positions + sy += dy + sx += dx + return x + + def _process_collision(self, sx, sy, dx, dy, x_min, x_max, y_min, y_max): + """ + Takes as input current object coordinate and speed that might be over the frame borders after the movement of + the last timestep, and updates them to take into account the object collision with frame borders. + Parameters + ---------- + sx : float + Current object x coordinate, prior to checking whether it collided with a frame border. + sy : float + Current object y coordinate, prior to checking whether it collided with a frame border. + dx : int + Current object x speed, prior to checking whether it collided with a frame border. + dy : int + Current object y speed, prior to checking whether it collided with a frame border. + x_min : int + Minimum x coordinate allowed. + x_max : int + Maximum x coordinate allowed. + y_min : int + Minimum y coordinate allowed. + y_max : int + Maximum y coordinate allowed. 
+ """ + # Check collision on all four edges + left_edge = (sx < x_min - self.eps) + upper_edge = (sy < y_min - self.eps) + right_edge = (sx > x_max + self.eps) + bottom_edge = (sy > y_max + self.eps) + # Continue processing as long as a collision is detected + while (left_edge or right_edge or upper_edge or bottom_edge): + # Retroactively compute the collision coordinates, using the current out-of-frame position and speed + # These coordinates are stored in cx and cy + if dx == 0: # x is onstant + cx, cy = (sx, y_min) if upper_edge else (sx, y_max) + elif dy == 0: # y is constant + cx, cy = (x_min, sy) if left_edge else (x_max, sy) + else: + a = dy / dx + b = sy - a * sx + # Searches for the first intersection with frame borders + if left_edge: + left_edge, n = self._get_intersection_x(a, b, x_min, (y_min, y_max)) + if left_edge: + cx, cy = n + if right_edge: + right_edge, n = self._get_intersection_x(a, b, x_max, (y_min, y_max)) + if right_edge: + cx, cy = n + if upper_edge: + upper_edge, n = self._get_intersection_y(a, b, y_min, (x_min, x_max)) + if upper_edge: + cx, cy = n + if bottom_edge: + bottom_edge, n = self._get_intersection_y(a, b, y_max, (x_min, x_max)) + if bottom_edge: + cx, cy = n + # Displacement coefficient to get new coordinates after the bounce, taking into account the time left + # (after all previous displacements) in the timestep to move the object + p = ((sx - cx) / dx) if (dx != 0) else ((sy - cy) / dy) + # In the stochastic case, randomly choose a new speed vector + if not self.deterministic: + dx = np.random.randint(-self.max_speed, self.max_speed + 1) + dy = np.random.randint(-self.max_speed, self.max_speed + 1) + # Reverse speed vector elements depending on the detected collision + if left_edge: + dx = abs(dx) + if right_edge: + dx = -abs(dx) + if upper_edge: + dy = abs(dy) + if bottom_edge: + dy = -abs(dy) + # Compute the remaining displacement to be done during the timestep after the bounce + sx = cx + dx * p + sy = cy + dy * p + # Check again collisions + left_edge = (sx < x_min - self.eps) + upper_edge = (sy < y_min - self.eps) + right_edge = (sx > x_max + self.eps) + bottom_edge = (sy > y_max + self.eps) + # Return updated speed and coordinates + return sx, sy, dx, dy + + def _get_intersection_x(self, a, b, x_lim, by): + """ + Computes the intersection point of trajectory with the upper or lower border of the frame. + Parameters + ---------- + a : float + dy / dx. + b : float + sy - a * sx. + x_lim : int + x coordinate of the border of the frame to test the intersection with. + by : tuple + Tuple of integers representing the frame limits on the y coordinate. + Returns + ------- + bool + Whether the intersection point lies within the frame limits. + tuple + Couple of float coordinates representing the intersection point. + """ + y_inter = a * x_lim + b + if (y_inter >= by[0] - self.eps) and (y_inter <= by[1] + self.eps): + return True, (x_lim, y_inter) + return False, (x_lim, y_inter) + + def _get_intersection_y(self, a, b, y_lim, bx): + """ + Computes the intersection point of trajectory with the left or right border of the frame. + Parameters + ---------- + a : float + dy / dx. + b : float + sy - a * sx. + y_lim : int + y coordinate of the border of the frame to test the intersection with. + bx : tuple + Tuple of integers representing the frame limits on the x coordinate. + Returns + ------- + bool + Whether the intersection point lies within the frame limits. + tuple + Couple of float coordinates representing the intersection point. 
+ """ + x_inter = (y_lim - b) / a + if (x_inter >= bx[0] - self.eps) and (x_inter <= bx[1] + self.eps): + return True, (x_inter, y_lim) + return False, (x_inter, y_lim) + + @classmethod + def make_dataset(cls, data_dir, nx, nt_cond, seq_len, max_speed, deterministic, num_digits, train): + """ + Creates a dataset from the directory where the dataset is saved. + Parameters + ---------- + data_dir : str + Path to the dataset. + nx : int + Width and height of the video frames. + nt_cond : int + Number of conditioning frames. + seq_len : int + Number of frames to produce. + max_speed : int + Maximum speed of moving digits in the videos. + deterministic : bool + Whether to use the deterministic version of the dataset rather that the stochastic one. + num_digits : int + Number of digits in each video. + train : bool + Whether to use the training or testing dataset. + """ + if train: + # When training, only register training MNIST digits + digits = datasets.MNIST(data_dir, train=train, download=True) + data = [np.array(img, dtype=np.uint8) for i, (img, label) in enumerate(digits)] + else: + # When testining, loads the precomputed videos + prefix = '' if deterministic else 's' + dataset = np.load(os.path.join(data_dir, f'{prefix}mmnist_test_{num_digits}digits_{nx}.npz'), + allow_pickle=True) + sequences = dataset['sequences'] + data = [sequences[:, i].astype(np.single) for i in range(sequences.shape[1])] + # Create and return the dataset + return cls(data, nx, nt_cond, seq_len, max_speed, deterministic, num_digits, train) diff --git a/var_sep/data/sst.py b/var_sep/data/sst.py new file mode 100644 index 0000000..8614dca --- /dev/null +++ b/var_sep/data/sst.py @@ -0,0 +1,99 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import os
+import torch
+
+from torch.utils.data import Dataset
+
+from netCDF4 import Dataset as netCDFDataset
+
+
+def extract_data(fp, variables):
+    loaded_file = netCDFDataset(fp, 'r')
+    data_dict = {}
+    for var in variables:
+        data_dict[var] = loaded_file.variables[var][:].data
+    return data_dict
+
+
+class SST(Dataset):
+    var_names = ['thetao', 'daily_mean', 'daily_std']
+
+    def __init__(self, data_dir, nt_cond, nt_pred, train, zones=range(1, 30), eval=False):
+        super(SST, self).__init__()
+
+        self.data_dir = data_dir
+        self.pred_h = nt_pred
+        self.zones = list(zones)
+        self.lb = nt_cond
+        self.zone_size = 64
+
+        self.data = {}
+        self.cst = {}
+        self.climato = {}
+
+        self.train = train
+        self.eval = eval
+
+        self._normalize()
+
+        self.first = 0 if self.train else int(0.8 * self.len_)
+
+        # Retrieve length
+        if self.train:
+            self.len_ = int(0.8 * self.len_)
+        else:
+            self.len_ = self.len_ - int(0.8 * self.len_)
+
+        self.len_ = self.len_ - self.pred_h - self.lb - 1
+        self._total_len = len(self.zones) * self.len_
+
+    def _normalize(self):
+        for zone in self.zones:
+            zdata = extract_data(os.path.join(self.data_dir, f'data_{zone}.nc'), variables=self.var_names)
+            self.len_ = len(zdata["thetao"])
+
+            climate_mean, climate_std = zdata['daily_mean'].reshape(-1, 1, 1), zdata['daily_std'].reshape(-1, 1, 1)
+            zdata["thetao"] = (zdata["thetao"] - climate_mean) / climate_std
+            self.climato[zone] = (climate_mean, climate_std)
+
+            mean = zdata["thetao"].mean(axis=(1, 2)).reshape(-1, 1, 1)
+            std = zdata["thetao"].std(axis=(1, 2)).reshape(-1, 1, 1)
+            zdata["thetao"] = (zdata["thetao"] - mean) / std
+            self.cst[zone] = (mean, std)
+
+            self.data[zone] = zdata["thetao"]
+
+    def __len__(self):
+        return self._total_len
+
+    def __getitem__(self, idx):
+        file_id = self.zones[idx // self.len_]
+        idx_id = (idx % self.len_) + self.lb + 1 + self.first
+        inputs = self.data[file_id][idx_id - self.lb + 1: idx_id + 1].reshape(self.lb, 1, self.zone_size,
+                                                                              self.zone_size)
+        target = self.data[file_id][idx_id + 1: idx_id + self.pred_h + 1].reshape(self.pred_h, 1, self.zone_size,
+                                                                                  self.zone_size)
+
+        if self.eval:
+            inputs, target = torch.tensor(inputs, dtype=torch.float), torch.tensor(target, dtype=torch.float)
+            mu_clim, std_clim = (self.climato[file_id][0][idx_id + 1: idx_id + self.pred_h + 1],
+                                 self.climato[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
+            mu_norm, std_norm = (self.cst[file_id][0][idx_id + 1: idx_id + self.pred_h + 1],
+                                 self.cst[file_id][1][idx_id + 1: idx_id + self.pred_h + 1])
+            return inputs, target, mu_clim, std_clim, mu_norm, std_norm
+        else:
+            return torch.tensor(inputs, dtype=torch.float), torch.tensor(target, dtype=torch.float)
diff --git a/var_sep/data/wave_eq.py b/var_sep/data/wave_eq.py
new file mode 100644
index 0000000..b778117
--- /dev/null
+++ b/var_sep/data/wave_eq.py
@@ -0,0 +1,90 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
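+
+
+# Usage sketch (hypothetical path and values): `data_dir` is expected to contain a
+# `data/` subfolder of simulation files, as produced by var_sep.preprocessing.wave.gen_wave.
+#     train_set = WaveEq('/path/to/wave', nt_cond=5, seq_len=25, train=True, downsample=2)
+#     cond, target = train_set[0]  # (nt_cond, 1, H, W) and (seq_len - nt_cond, 1, H, W)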
+
+
+import os
+import re
+import torch
+
+import numpy as np
+
+from torch.utils.data import Dataset
+
+
+def extract_id(string):
+    return int(re.findall(r'\d+', string)[0])
+
+
+class WaveEq(Dataset):
+    def __init__(self, data_dir, nt_cond, seq_len, train, downsample):
+        super(WaveEq, self).__init__()
+
+        self.nt_cond = nt_cond
+        self.seq_len = seq_len
+
+        base_path = os.path.join(data_dir, 'data')
+        files = os.listdir(base_path)
+        files = [os.path.join(base_path, f) for f in files]
+
+        self.train = train
+        max_seq = int(0.8 * len(files))
+
+        if train:
+            files = [file for file in files if extract_id(file) < max_seq]
+        else:
+            files = [file for file in files if extract_id(file) >= max_seq]
+
+        self.size = len(files)
+        self.all_data = []
+
+        self.downsample = downsample
+
+        for file in files:
+            data_dict = torch.load(file)
+            data = data_dict.get('simul')
+            max_, min_ = data.max(), data.min()
+            data = (data - min_) / (max_ - min_)
+            data = data[::self.downsample]
+            self.nt = len(data)
+            self.all_data.append(data)
+
+        self.full_seq_len = data[0].size(0)
+
+    def __len__(self):
+        return self.size * (self.full_seq_len - self.seq_len + 1)
+
+    def __getitem__(self, idx):
+        # Extract a subsequence of length seq_len
+        idx_seq = idx // (self.nt + 1 - self.seq_len)  # seq is of size nt + 1
+        idx_in_seq = idx % (self.nt + 1 - self.seq_len)
+        full_state = self.all_data[idx_seq][idx_in_seq: idx_in_seq + self.seq_len].unsqueeze(1)
+        return full_state[:self.nt_cond], full_state[self.nt_cond: self.seq_len]
+
+
+class WaveEqPartial(WaveEq):
+
+    def __init__(self, data_dir, nt_cond, seq_len, train, downsample, n_pixels):
+        super(WaveEqPartial, self).__init__(data_dir, nt_cond, seq_len, train, downsample)
+
+        data_dir = os.path.join(data_dir, 'pixels')
+        pixels = np.load(os.path.join(data_dir, 'pixels.npz'), allow_pickle=True)
+        self.rand_w = pixels['rand_w']
+        self.rand_h = pixels['rand_h']
+        self.n_wave_points = n_pixels
+
+    def __getitem__(self, idx):
+        cond, target = super().__getitem__(idx)
+        cond = cond[:, :, self.rand_w[:self.n_wave_points], self.rand_h[:self.n_wave_points]]
+        target = target[:, :, self.rand_w[:self.n_wave_points], self.rand_h[:self.n_wave_points]]
+        return cond, target
diff --git a/var_sep/main.py b/var_sep/main.py
new file mode 100644
index 0000000..9fead7e
--- /dev/null
+++ b/var_sep/main.py
@@ -0,0 +1,125 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
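+
+# Training entry point; run as a module from the repository root (see the README):
+#     python -m var_sep.main --help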
+ + +import json +import os +import torch + +import numpy as np +import torch.backends.cudnn as cudnn +import torch.optim.lr_scheduler as lr_scheduler + +from torch import optim +from torch.utils.data import DataLoader + +from var_sep.data.moving_mnist import MovingMNIST +from var_sep.data.sst import SST +from var_sep.data.wave_eq import WaveEq, WaveEqPartial +from var_sep.networks.model import SeparableNetwork +from var_sep.networks.factory import get_encoder, get_decoder, get_resnet +from var_sep.networks.utils import ConstantS +from var_sep.options import parser +from var_sep.train import train + + +if __name__ == "__main__": + + # Arguments + args = parser.parse_args() + + # CPU / GPU + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + if args.device is None: + device = torch.device('cpu') + else: + cudnn.benchmark = True + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) + device = torch.device("cuda:0") + + # Seed + seed = np.random.randint(0, 10000) + torch.manual_seed(seed) + + # ######## + # DATASETS + # ######## + last_activation = None + if args.data == 'mnist': + train_set = MovingMNIST.make_dataset(args.data_dir, 64, args.nt_cond, args.nt_cond + args.nt_pred, 4, True, + args.n_object, True) + last_activation = 'sigmoid' + shape = [1, 64, 64] + elif args.data == "sst": + train_set = SST(args.data_dir, args.nt_cond, args.nt_pred, True, zones=args.zones) + shape = [1, 64, 64] + elif args.data == "wave": + train_set = WaveEq(args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred, True, args.downsample) + last_activation = 'sigmoid' + shape = [1, 64, 64] + elif args.data == "wave_partial": + assert args.architecture not in ['dcgan', 'vgg'] + train_set = WaveEqPartial(args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred, True, args.downsample, + args.n_wave_points) + last_activation = 'sigmoid' + shape = [1, args.n_wave_points] + + # Save params + with open(os.path.join(args.xp_dir, 'params.json'), 'w') as f: + json.dump(vars(args), f, indent=4, sort_keys=True) + + # ########### + # DATA LOADER + # ########### + def worker_init_fn(worker_id): + np.random.seed((torch.randint(100000, []).item() + worker_id)) + train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=False, shuffle=True, + num_workers=args.num_workers, worker_init_fn=worker_init_fn) + + # ######## + # NETWORKS + # ######## + if not args.no_s: + Es = get_encoder(args.architecture, shape, args.code_size_s, args.enc_hidden_size, args.nt_cond, + args.init_encoder, args.gain_encoder).to(device) + else: + # Es is constant and equal to one + assert not args.skipco + args.code_size_s = args.code_size_t + args.mixing = 'mul' + Es = ConstantS(return_value=1, code_size=args.code_size_s).to(device) + + Et = get_encoder(args.architecture, shape, args.code_size_t, args.enc_hidden_size, args.nt_cond, + args.init_encoder, args.gain_encoder).to(device) + + decoder = get_decoder(args.architecture, shape, args.code_size_t, args.code_size_s, last_activation, + args.dec_hidden_size, args.mixing, args.skipco, args.init_encoder, + args.gain_encoder).to(device) + + t_resnet = get_resnet(args.code_size_t, args.n_blocks, args.res_hidden_size, args.init_resnet, + args.gain_resnet).to(device) + + sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco) + + # ######### + # OPTIMIZER + # ######### + optimizer = optim.Adam(sep_net.parameters(), lr=args.lr, betas=(0.9, args.beta2)) + if args.scheduler is not None: + scheduler = lr_scheduler.MultiStepLR(optimizer, args.scheduler_milestones, 
+                                             gamma=args.scheduler_decay)
+    else:
+        scheduler = None
+
+    train(args.xp_dir, train_loader, device, sep_net, optimizer, scheduler, args.apex_amp, args.epochs, args.lamb_ae,
+          args.lamb_s, args.lamb_t, args.lamb_pred, args.offset, args.nt_cond, args.nt_pred, args.no_s, args.skipco)
diff --git a/var_sep/networks/conv.py b/var_sep/networks/conv.py
new file mode 100644
index 0000000..b2fda13
--- /dev/null
+++ b/var_sep/networks/conv.py
@@ -0,0 +1,317 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# The following code is adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights
+# below.
+
+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+import torch.nn as nn
+
+from var_sep.networks.utils import activation_factory
+
+
+def make_conv_block(conv, activation, bn=True):
+    """
+    Supplements a convolutional block with activation functions and batch normalization.
+    Parameters
+    ----------
+    conv : torch.nn.Module
+        Convolutional block.
+    activation : str
+        'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function after
+        the convolution.
+    bn : bool
+        Whether to add batch normalization after the convolution (i.e., before the activation).
+    """
+    out_channels = conv.out_channels
+    modules = [conv]
+    if bn:
+        modules.append(nn.BatchNorm2d(out_channels))
+    if activation != 'none':
+        modules.append(activation_factory(activation))
+    return nn.Sequential(*modules)
+
+
+class BaseEncoder(nn.Module):
+    """
+    Module implementing the encoders' forward method.
+    Attributes
+    ----------
+    nh : int
+        Number of dimensions of the output flat vector.
+    """
+    def __init__(self, nh):
+        """
+        Parameters
+        ----------
+        nh : int
+            Number of dimensions of the output flat vector.
+        """
+        super(BaseEncoder, self).__init__()
+        self.nh = nh
+
+    def forward(self, x, return_skip=False):
+        """
+        Parameters
+        ----------
+        x : torch.Tensor
+            Encoder input.
+        return_skip : bool
+            Whether to extract and return, besides the network output, skip connections.
+        """
+        x = x.view(x.size(0), -1, x.size(3), x.size(4))
+        skips = []
+        h = x
+        for layer in self.conv:
+            h = layer(h)
+            skips.append(h)
+        h = self.last_conv(h).view(-1, self.nh)
+        if return_skip:
+            return h, skips[::-1]
+        return h
+
+
+class DCGAN64Encoder(BaseEncoder):
+    """
+    Module implementing the DCGAN encoder.
+ """ + def __init__(self, nc, nh, nf): + """ + Parameters + ---------- + nc : int + Number of channels in the input data. + nh : int + Number of dimensions of the output flat vector. + nf : int + Number of filters per channel of the first convolution. + """ + super(DCGAN64Encoder, self).__init__(nh) + self.conv = nn.ModuleList([ + make_conv_block(nn.Conv2d(nc, nf, 4, 2, 1, bias=False), activation='leaky_relu', bn=False), + make_conv_block(nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 2, nf * 4, 4, 2, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False), activation='leaky_relu') + ]) + self.last_conv = make_conv_block(nn.Conv2d(nf * 8, nh, 4, 1, 0, bias=False), activation='tanh') + + +class VGG64Encoder(BaseEncoder): + """ + Module implementing the VGG encoder. + """ + def __init__(self, nc, nh, nf): + """ + Parameters + ---------- + nc : int + Number of channels in the input data. + nh : int + Number of dimensions of the output flat vector. + nf : int + Number of filters per channel of the first convolution. + """ + super(VGG64Encoder, self).__init__(nh) + self.conv = nn.ModuleList([ + nn.Sequential( + make_conv_block(nn.Conv2d(nc, nf, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), activation='leaky_relu'), + ), + nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + make_conv_block(nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), + ), + nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + make_conv_block(nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + ), + nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + make_conv_block(nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False), activation='leaky_relu'), + ) + ]) + self.last_conv = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + make_conv_block(nn.Conv2d(nf * 8, nh, 4, 1, 0, bias=False), activation='tanh') + ) + + +class BaseDecoder(nn.Module): + """ + Module implementing the decoders forward method. + + Attributes + ---------- + ny : int + Number of dimensions of the output flat vector. + skip : bool + Whether to include skip connections into the decoder. + mixing : str + 'mul' or 'concat'. Whether to multiply both inputs, or concatenate them. + """ + def __init__(self, ny, skip, last_activation, mixing): + """ + Parameters + ---------- + ny : int + Number of dimensions of the output flat vector. + skip : bool + Whether to include skip connections into the decoder. + last_activation : str + 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function after + the last convolution. + mixing : str + 'mul' or 'concat'. Whether to multiply both inputs, or concatenate them. 
+ """ + super(BaseDecoder, self).__init__() + self.ny = ny + self.skip = skip + self.mixing = mixing + + self.last_activation = activation_factory(last_activation) + + def forward(self, z1, z2, skip=None): + """ + Parameters + ---------- + z1 : torch.Tensor + First decoder input (S). + z2 : torch.Tensor + Second decoder input (S). + skip : list + List of tensors representing skip connections. + """ + assert skip is None and not self.skip or self.skip and skip is not None + + if self.mixing == 'concat': + z = torch.cat([z1, z2], dim=1) + else: + z = z1 * z2 + + h = self.first_upconv(z.view(*z.shape, 1, 1)) + for i, layer in enumerate(self.conv): + if skip is not None: + h = torch.cat([h, skip[i]], 1) + h = layer(h) + return self.last_activation(h) + + +class DCGAN64Decoder(BaseDecoder): + """ + Module implementing the DCGAN decoder. + """ + def __init__(self, nc, ny, nf, skip, last_activation, mixing): + """ + Parameters + ---------- + nc : int + Number of channels in the output shape. + ny : int + Number of dimensions of the input flat vector. + nf : int + Number of filters per channel of the first convolution of the mirror encoder architecture. + skip : list + List of tensors representing skip connections. + last_activation : str + 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function after + the last convolution. + mixing : str + 'mul' or 'concat'. Whether to multiply both inputs, or concatenate them. + """ + super(DCGAN64Decoder, self).__init__(ny, skip, last_activation, mixing) + # decoder + coef = 2 if skip else 1 + self.first_upconv = make_conv_block(nn.ConvTranspose2d(ny, nf * 8, 4, 1, 0, bias=False), activation='leaky_relu') + self.conv = nn.ModuleList([ + make_conv_block(nn.ConvTranspose2d(nf * 8 * coef, nf * 4, 4, 2, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.ConvTranspose2d(nf * 4 * coef, nf * 2, 4, 2, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.ConvTranspose2d(nf * 2 * coef, nf, 4, 2, 1, bias=False), activation='leaky_relu'), + nn.ConvTranspose2d(nf * coef, nc, 4, 2, 1, bias=False), + ]) + + +class VGG64Decoder(BaseDecoder): + """ + Module implementing the VGG decoder. + """ + def __init__(self, nc, ny, nf, skip, last_activation, mixing): + """ + Parameters + ---------- + nc : int + Number of channels in the output shape. + ny : int + Number of dimensions of the input flat vector. + nf : int + Number of filters per channel of the first convolution of the mirror encoder architecture. + skip : list + List of tensors representing skip connections. + last_activation : str + 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function after + the last convolution. + mixing : str + 'mul' or 'concat'. Whether to multiply both inputs, or concatenate them. 
+ """ + super(VGG64Decoder, self).__init__(ny, skip, last_activation, mixing) + # decoder + coef = 2 if skip else 1 + self.first_upconv = nn.Sequential( + make_conv_block(nn.ConvTranspose2d(ny, nf * 8, 4, 1, 0, bias=False), activation='leaky_relu'), + nn.Upsample(scale_factor=2, mode='nearest'), + ) + self.conv = nn.ModuleList([ + nn.Sequential( + make_conv_block(nn.Conv2d(nf * 8 * coef, nf * 8, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 8, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + nn.Upsample(scale_factor=2, mode='nearest'), + ), + nn.Sequential( + make_conv_block(nn.Conv2d(nf * 4 * coef, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 4, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), + nn.Upsample(scale_factor=2, mode='nearest'), + ), + nn.Sequential( + make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), + make_conv_block(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=False), activation='leaky_relu'), + nn.Upsample(scale_factor=2, mode='nearest'), + ), + nn.Sequential( + make_conv_block(nn.Conv2d(nf * coef, nf, 3, 1, 1, bias=False), activation='leaky_relu'), + nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False), + ), + ]) diff --git a/var_sep/networks/factory.py b/var_sep/networks/factory.py new file mode 100644 index 0000000..5daf4c2 --- /dev/null +++ b/var_sep/networks/factory.py @@ -0,0 +1,67 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import numpy as np + +from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder +from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder +from var_sep.networks.resnet import MLPResnet +from var_sep.networks.utils import init_net + + +def get_encoder(nn_type, shape, output_size, hidden_size, nt_cond, init_type, init_gain): + nc = shape[0] + if nn_type == 'dcgan': + encoder = DCGAN64Encoder(nc * nt_cond, output_size, hidden_size) + elif nn_type == 'vgg': + encoder = VGG64Encoder(nc * nt_cond, output_size, hidden_size) + elif nn_type in ['mlp', 'large_mlp']: + input_size = nt_cond * np.prod(np.array(shape)) + encoder = MLPEncoder(input_size, hidden_size, output_size, 3) + + init_net(encoder, init_type=init_type, init_gain=init_gain) + + return encoder + + +def get_decoder(nn_type, shape, code_size_t, code_size_s, last_activation, hidden_size, mixing, skipco, init_type, + init_gain): + assert not skipco or nn_type in ['dcgan', 'vgg'] + + if mixing == 'mul': + assert code_size_t == code_size_s + input_size = code_size_t + else: + input_size = code_size_t + code_size_s + + nc = shape[0] + if nn_type == 'dcgan': + decoder = DCGAN64Decoder(nc, input_size, hidden_size, skipco, last_activation, mixing) + elif nn_type == 'vgg': + decoder = VGG64Decoder(nc, input_size, hidden_size, skipco, last_activation, mixing) + elif nn_type == 'mlp': + decoder = MLPDecoder(input_size, hidden_size, shape, 3, last_activation, mixing) + elif nn_type == 'large_mlp': + decoder = MLPDecoder(input_size, hidden_size, shape, 4, last_activation, mixing) + + init_net(decoder, init_type=init_type, init_gain=init_gain) + + return decoder + + +def get_resnet(latent_size, n_blocks, hidden_size, init_type, gain_res): + resnet = MLPResnet(latent_size, n_blocks, hidden_size) + init_net(resnet, init_type=init_type, init_gain=gain_res) + return resnet diff --git a/var_sep/networks/mlp.py b/var_sep/networks/mlp.py new file mode 100644 index 0000000..e5f927a --- /dev/null +++ b/var_sep/networks/mlp.py @@ -0,0 +1,76 @@ +# The following code is taken from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights +# below. + +# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn + +from var_sep.networks.utils import activation_factory + + +def make_lin_block(ninp, nout, activation): + """ + Creates a linear block formed by an activation function and a linear operation. + Parameters + ---------- + ninp : int + Input dimension. + nout : int + Output dimension. + activation : str + 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function before + the linear operation. + """ + modules = [] + if activation != 'none': + modules.append(activation_factory(activation)) + modules.append(nn.Linear(ninp, nout)) + return nn.Sequential(*modules) + + +class MLP(nn.Module): + """ + Module implementing an MLP. 
+ """ + def __init__(self, ninp, nhid, nout, nlayers, activation='relu'): + """ + Parameters + ---------- + ninp : int + Input dimension. + nhid : int + Number of dimensions in intermediary layers. + nout : int + Output dimension. + nlayers : int + Number of layers in the MLP. + activation : str + 'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Adds the corresponding activation function before the + linear operation. + """ + super().__init__() + assert nhid == 0 or nlayers > 1 + modules = [ + make_lin_block( + ninp=ninp if il == 0 else nhid, + nout=nout if il == nlayers - 1 else nhid, + activation=activation if il > 0 else 'none', + ) for il in range(nlayers) + ] + self.module = nn.Sequential(*modules) + + def forward(self, x): + return self.module(x) diff --git a/var_sep/networks/mlp_encdec.py b/var_sep/networks/mlp_encdec.py new file mode 100644 index 0000000..38219f9 --- /dev/null +++ b/var_sep/networks/mlp_encdec.py @@ -0,0 +1,50 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +import numpy as np +import torch.nn as nn + +from var_sep.networks.mlp import MLP +from var_sep.networks.utils import activation_factory + + +class MLPEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size, nlayers): + super(MLPEncoder, self).__init__() + self.mlp = MLP(input_size, hidden_size, output_size, nlayers) + + def forward(self, x, return_skip=False): + x = x.view(len(x), -1) + return self.mlp(x) + + +class MLPDecoder(nn.Module): + def __init__(self, latent_size, hidden_size, output_shape, nlayers, last_activation, mixing): + super(MLPDecoder, self).__init__() + self.output_shape = output_shape + self.mixing = mixing + self.mlp = MLP(latent_size, hidden_size, np.prod(np.array(output_shape)), nlayers) + self.last_activation = activation_factory(last_activation) + + def forward(self, z1, z2, skip=None): + if self.mixing == 'concat': + z = torch.cat([z1, z2], dim=1) + else: + z = z1 * z2 + x = self.mlp(z) + x = self.last_activation(x) + return x.view([-1] + self.output_shape) diff --git a/var_sep/networks/model.py b/var_sep/networks/model.py new file mode 100644 index 0000000..678efde --- /dev/null +++ b/var_sep/networks/model.py @@ -0,0 +1,92 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn + + +class SeparableNetwork(nn.Module): + + def __init__(self, Es, Et, t_resnet, decoder, nt_cond, skipco): + super(SeparableNetwork, self).__init__() + + assert isinstance(Es, nn.Module) + assert isinstance(Et, nn.Module) + assert isinstance(t_resnet, nn.Module) + assert isinstance(decoder, nn.Module) + + # Networks + self.Es = Es + self.Et = Et + self.decoder = decoder + self.t_resnet = t_resnet + + # Attributes + self.nt_cond = nt_cond + self.skipco = skipco + + # Gradient-enabling parameter + self.__grad = True + + @property + def grad(self): + return self.__grad + + @grad.setter + def grad(self, grad): + assert isinstance(grad, bool) + self.__grad = grad + + def get_forecast(self, cond, n_forecast, init_t_code=None, init_s_code=None): + s_codes = [] + t_codes = [] + forecasts = [] + t_residuals = [] + + if init_s_code is None: + s_code = self.Es(cond, return_skip=self.skipco) + else: + s_code = init_s_code + if self.skipco: + s_code, s_skipco = s_code + else: + s_skipco = None + + if init_t_code is None: + t_code = self.Et(cond) + else: + t_code = init_t_code + + s_codes.append(s_code) + t_codes.append(t_code) + + # Decode first frame + forecast = self.decoder(s_code, t_code, skip=s_skipco) + forecasts.append(forecast) + + # Forward prediction + for t in range(1, n_forecast): + t_code, t_res = self.t_resnet(t_code) + t_codes.append(t_code) + t_residuals.append(t_res) + forecast = self.decoder(s_code, t_code, skip=s_skipco) + forecasts.append(forecast) + + # Stack predictions + forecasts = torch.cat([x.unsqueeze(1) for x in forecasts], dim=1) + t_codes = torch.cat([x.unsqueeze(1) for x in t_codes], dim=1) + s_codes = torch.cat([x.unsqueeze(1) for x in s_codes], dim=1) + + return forecasts, t_codes, s_codes, t_residuals diff --git a/var_sep/networks/resnet.py b/var_sep/networks/resnet.py new file mode 100644 index 0000000..96f6d86 --- /dev/null +++ b/var_sep/networks/resnet.py @@ -0,0 +1,52 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
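+
+# Temporal ResNet: each block returns both x + MLP(x) and the residual MLP(x), so that
+# the training objective can penalize the residuals. Sketch (illustrative sizes):
+#     resnet = MLPResnet(input_size=128, n_blocks=1, hidden_size=512)
+#     t_next, residuals = resnet(t_code)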
+
+
+import torch.nn as nn
+
+from var_sep.networks.mlp import MLP
+
+
+class MLPResBlock(nn.Module):
+    def __init__(self, input_size, hidden_size):
+        super(MLPResBlock, self).__init__()
+        self.mlp = MLP(input_size, hidden_size, input_size, 3)
+
+    def forward(self, x):
+        residual = self.mlp(x)
+        return x + residual, residual
+
+
+class MLPResnet(nn.Module):
+    def __init__(self, input_size, n_blocks, hidden_size):
+        super(MLPResnet, self).__init__()
+
+        self.in_size = input_size
+        self.n_blocks = n_blocks
+
+        blocks = []
+        for i in range(self.n_blocks):
+            blocks += [MLPResBlock(input_size, hidden_size)]
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x, return_res=True):
+        residuals = []
+        for j in range(self.n_blocks):
+            x, res = self.blocks[j].forward(x)
+            residuals.append(res)
+
+        if return_res:
+            return x, residuals
+        else:
+            return x
diff --git a/var_sep/networks/utils.py b/var_sep/networks/utils.py
new file mode 100644
index 0000000..77e4a4c
--- /dev/null
+++ b/var_sep/networks/utils.py
@@ -0,0 +1,109 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+import torch.nn as nn
+
+
+class ConstantS(nn.Module):
+    def __init__(self, return_value=1, code_size=1):
+        super(ConstantS, self).__init__()
+        self.code_size = code_size
+        self.return_value = return_value
+
+    def forward(self, x, return_skip=False):
+        bt_size = len(x)
+        return torch.ones(bt_size, self.code_size).to(x) * self.return_value
+
+
+# The following code is adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights
+# below.

+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def activation_factory(name):
+    """
+    Returns the activation layer corresponding to the input activation name.
+    Parameters
+    ----------
+    name : str
+        'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Name of the activation function to instantiate.
+ """ + if name == 'relu': + return nn.ReLU(inplace=True) + if name == 'leaky_relu': + return nn.LeakyReLU(0.2, inplace=True) + if name == 'elu': + return nn.ELU(inplace=True) + if name == 'sigmoid': + return nn.Sigmoid() + if name == 'tanh': + return nn.Tanh() + if name is None or name == "identity": + return nn.Identity() + + raise ValueError(f'Activation function `{name}` not yet implemented') + + +def init_net(net, init_type='normal', init_gain=0.02): + """ + Initializes the input module with the given parameters. + Parameters + ---------- + m : torch.nn.Module + Module to initialize. + init_type : str + 'normal', 'xavier', 'kaiming', or 'orthogonal'. Orthogonal initialization types for convolutions and linear + operations. + init_gain : float + Gain to use for the initialization. + """ + def init_func(m): # Define the initialization function + classname = m.__class__.__name__ + if classname in ('Conv2d', 'ConvTranspose2d', 'Linear'): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif classname == 'BatchNorm2d': + if m.weight is not None: + nn.init.normal_(m.weight.data, 1.0, init_gain) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + net.apply(init_func) # Iterate the initialization function on the modules diff --git a/var_sep/options.py b/var_sep/options.py new file mode 100644 index 0000000..a174011 --- /dev/null +++ b/var_sep/options.py @@ -0,0 +1,115 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import argparse
+
+
+DATASETS = ['wave', 'wave_partial', 'sst', 'mnist']
+ARCH_TYPES = ['dcgan', 'vgg', 'mlp', 'large_mlp']
+INITIALIZATIONS = ['orthogonal', 'kaiming', 'normal']
+MIXING = ['concat', 'mul']
+
+
+parser = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (training)",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+parser.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                    help='Directory where models will be saved.')
+
+distr_p = parser.add_argument_group(title='Distributed',
+                                    description='Options for training on GPUs and distributed dataset loading.')
+distr_p.add_argument('--apex_amp', action='store_true',
+                     help='Whether to use Nvidia\'s Apex mixed-precision training.')
+distr_p.add_argument('--device', type=int, metavar='DEVICE', default=0,
+                     help='If not None, indicates the index of the GPU to use.')
+distr_p.add_argument('--num_workers', type=int, metavar='NB', default=4,
+                     help='Number of child processes for data loading.')
+
+config_p = parser.add_argument_group(title='Experiment Configuration',
+                                     description='Model and loss parameters.')
+
+config_p.add_argument('--nt_cond', type=int, metavar='COND', default=5,
+                      help='Number of conditioning observations.')
+config_p.add_argument('--nt_pred', type=int, metavar='PRED', default=10,
+                      help='Number of observations to predict.')
+config_p.add_argument('--lamb_ae', type=float, metavar='LAMBDA', default=10,
+                      help='Multiplier of the autoencoding loss.')
+config_p.add_argument('--lamb_s', type=float, metavar='LAMBDA', default=45,
+                      help='Multiplier of the S invariance loss.')
+config_p.add_argument('--lamb_t', type=float, metavar='LAMBDA', default=0.001,
+                      help='Multiplier of the T regularization loss.')
+config_p.add_argument('--lamb_pred', type=float, metavar='LAMBDA', default=45,
+                      help='Multiplier of the prediction loss.')
+config_p.add_argument('--architecture', type=str, metavar='ARCH', default='dcgan', choices=ARCH_TYPES,
+                      help='Encoder and decoder architecture.')
+config_p.add_argument('--skipco', action='store_true',
+                      help='Whether to use skip connections from encoders to decoders.')
+config_p.add_argument('--res_hidden_size', type=int, metavar='SIZE', default=512,
+                      help='Hidden size of MLPs in the residual integrator.')
+config_p.add_argument('--enc_hidden_size', type=int, metavar='SIZE', default=64,
+                      help='Hidden size of MLP encoders, or number of filters in convolutional encoders.')
+config_p.add_argument('--dec_hidden_size', type=int, metavar='SIZE', default=64,
+                      help='Hidden size of MLP decoders, or number of filters in convolutional decoders.')
+config_p.add_argument('--n_blocks', type=int, metavar='BLOCKS', default=1,
+                      help='Number of resblocks in the residual integrator.')
+config_p.add_argument('--code_size_s', type=int, metavar='SIZE', default=128,
+                      help='Number of dimensions in S (without skip connections).')
+config_p.add_argument('--code_size_t', type=int, metavar='SIZE', default=20,
+                      help='Number of dimensions in T.')
+config_p.add_argument('--mixing', type=str, metavar='MIXING', default='concat', choices=MIXING,
+                      help='Whether to concatenate or multiply S and T; in the latter case, their dimensions must ' +
+                           'be equal.')
+config_p.add_argument('--init_encoder', type=str, metavar='INIT', default='orthogonal', choices=INITIALIZATIONS,
+                      help='Initialization type of the encoder and the decoder.')
+config_p.add_argument('--gain_encoder', type=float, metavar='GAIN', default=0.02,
+                      help='Initialization gain of the encoder and the decoder.')
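The four lambda multipliers above weight the terms of the training objective assembled later in var_sep/train.py; schematically:

# Schematic form of the total loss built in var_sep/train.py:
#   total_loss = lamb_ae   * L_autoencoding        (reconstruction from S and a random-time T)
#              + lamb_s    * L_S_invariance        (S codes of old and recent windows must match)
#              + lamb_pred * L_prediction          (MSE of the forecasts)
#              + lamb_t    * L_T_regularization    (0.5 * ||T_0||^2 on the initial T code)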
+config_p.add_argument('--init_resnet', type=str, metavar='INIT', default='orthogonal', choices=INITIALIZATIONS,
+                      help='Initialization type of the linear layers of the MLP blocks in the integrator.')
+config_p.add_argument('--gain_resnet', type=float, metavar='GAIN', default=1.41,
+                      help='Initialization gain of the linear layers of the MLP blocks in the integrator.')
+config_p.add_argument('--no_s', action='store_true',
+                      help='If activated, deactivates the static component.')
+config_p.add_argument('--offset', type=int, metavar='SIZE', default=5,
+                      help='When non-zero and equal to the number of conditioning frames, reconstructs ' +
+                           'conditioning observations, besides forecasting future observations.')
+config_p.add_argument('--batch_size', type=int, metavar='SIZE', default=128,
+                      help='Training batch size.')
+config_p.add_argument('--lr', type=float, metavar='LR', default=4e-4,
+                      help='Learning rate of Adam optimizer.')
+config_p.add_argument('--beta2', type=float, metavar='BETA', default=0.99,
+                      help='Second-order decay parameter of the Adam optimizer.')
+config_p.add_argument('--epochs', type=int, metavar='EPOCH', default=200,
+                      help='Number of epochs to train on.')
+config_p.add_argument('--scheduler', action='store_true',
+                      help='If activated, uses a scheduler dividing the learning rate at given epoch milestones.')
+config_p.add_argument('--scheduler_decay', type=float, metavar='DECAY', default=0.5,
+                      help='Multiplier applied to the learning rate at each scheduler milestone.')
+config_p.add_argument('--scheduler_milestones', type=int, nargs='+', metavar='EPOCHS', default=[300, 400, 500, 600],
+                      help='Scheduler epoch milestones where the learning rate is multiplied by the decay parameter.')
+
+data_p = parser.add_argument_group(title='Dataset',
+                                   description='Chosen dataset and dataset parameters.')
+data_p.add_argument('--data', type=str, metavar='DATASET', default='mnist', choices=DATASETS,
+                    help='Dataset choice.')
+data_p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                    help='Data directory.')
+data_p.add_argument('--downsample', type=int, metavar='DOWNSAMPLE', default=2,
+                    help='Set the sampling rate for the WaveEq dataset.')
+data_p.add_argument('--n_wave_points', type=int, metavar='NUMBER', default=100,
+                    help='Number of random pixels to select for partial WaveEq (WaveEq-100).')
+data_p.add_argument('--zones', type=int, metavar='ZONES', default=list(range(1, 30)), nargs='+',
+                    help='SST zones to train on.')
+data_p.add_argument('--n_object', type=int, metavar='NUMBER', default=2,
+                    help='Number of digits in the Moving MNIST data.')
diff --git a/var_sep/preprocessing/mnist/make_test_set.py b/var_sep/preprocessing/mnist/make_test_set.py
new file mode 100644
index 0000000..e2f22e2
--- /dev/null
+++ b/var_sep/preprocessing/mnist/make_test_set.py
@@ -0,0 +1,100 @@
+# Code adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights below.

+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+
+import numpy as np
+
+from os.path import join
+from tqdm import trange
+from torchvision import datasets
+
+from var_sep.data.moving_mnist import MovingMNIST
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='Moving MNIST testing set generation.',
+        description='Generates the Moving MNIST testing set. Videos and latent space (position, speed) are saved in \
+                     an npz file.',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                        help='Folder where the testing set will be saved.')
+    parser.add_argument('--seq_len', type=int, metavar='LEN', default=100,
+                        help='Number of frames per testing sequence.')
+    parser.add_argument('--seed', type=int, metavar='SEED', default=42,
+                        help='Fixed NumPy seed to produce the same dataset at each run.')
+    parser.add_argument('--digits', type=int, metavar='NUM', default=2,
+                        help='Number of digits to appear in each video.')
+    parser.add_argument('--frame_size', type=int, metavar='SIZE', default=64,
+                        help='Size of generated frames.')
+    parser.add_argument('--max_speed', type=int, metavar='SPEED', default=4,
+                        help='Maximum speed of generated trajectories.')
+    args = parser.parse_args()
+
+    # Fix random seed
+    np.random.seed(args.seed)
+
+    # Load digits and shuffle them
+    digits = datasets.MNIST(args.data_dir, train=False, download=True)
+    digits_idx = np.random.permutation(len(digits))
+    # Random trajectories are made using the dataset code
+    trajectory_sampler = MovingMNIST([], args.frame_size, 0, args.seq_len, args.max_speed, True, args.digits, True)
+    # Register videos, latent space (position, speed), labels of digits and digit images
+    test_videos = []
+    test_latents = []
+    test_labels = []
+    test_objects = []
+    # The size of the testing set is the total number of testing digits in MNIST divided by the number of digits
+    for i in trange(len(digits) // args.digits):
+        x = np.zeros((args.seq_len, 1, args.frame_size, args.frame_size), dtype=np.float32)
+        latents = []
+        labels = []
+        objects = []
+        # Pick the digits randomly chosen for sequence i and generate their trajectories
+        for n in range(args.digits):
+            img, label = digits[digits_idx[i * args.digits + n]]
+            img = np.array(img, dtype=np.uint8)
+            trajectory = trajectory_sampler._compute_trajectory(*img.shape)
+            latents.append(np.array(trajectory))
+            labels.append(label)
+            objects.append(img)
+            for t in range(args.seq_len):
+                sx, sy, _, _ = trajectory[t]
+                x[t, 0, sx:sx + img.shape[0], sy:sy + img.shape[1]] += img
+        x[x > 255] = 255
+        # Register video and other information
+        test_videos.append(x.astype(np.uint8))
+        test_latents.append(np.array(latents))
+        test_labels.append(np.array(labels).astype(np.uint8))
+        test_objects.append(np.array(objects))
+    # Stack computed videos and other information
+    test_videos = np.array(test_videos, dtype=np.uint8).transpose(1, 0, 2, 3, 4)
+    test_latents = np.array(test_latents).transpose(2, 0, 1, 3)
+    test_labels = np.array(test_labels, dtype=np.uint8)
+    test_objects = np.array(test_objects)
+
+    # Save results at the given path
+    fname = f'mmnist_test_{args.digits}digits_{args.frame_size}.npz'
+    print(f'Saving testset at {join(args.data_dir, fname)}')
+    # Create the directory if needed
+    if not os.path.isdir(args.data_dir):
+        os.makedirs(args.data_dir)
+    np.savez_compressed(join(args.data_dir, fname),
+                        sequences=test_videos,
                        latents=test_latents, labels=test_labels, digits=test_objects)
diff --git a/var_sep/preprocessing/wave/gen_pixels.py b/var_sep/preprocessing/wave/gen_pixels.py
new file mode 100644
index 0000000..42d071e
--- /dev/null
+++ b/var_sep/preprocessing/wave/gen_pixels.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+
+import numpy as np
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='Choice of sample pixels for the WaveEq-100 dataset.',
+        description='Generates the pixels used for the WaveEqPartial (WaveEq-100 if 100 pixels are drawn) dataset. \
+                     Chosen coordinates are saved in an npz file, in `rand_w` and `rand_h` fields.',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                        help='Folder where the data will be saved.')
+    parser.add_argument('--number', type=int, metavar='NUM', default=100,
+                        help='Number of pixels to generate.')
+    parser.add_argument('--frame_size', type=int, metavar='SIZE', default=64,
+                        help='Size of frames (bound on pixel values).')
+    parser.add_argument('--seed', type=int, metavar='SEED', default=42,
+                        help='Fixed NumPy seed to produce the same dataset at each run.')
+    args = parser.parse_args()
+
+    # Fix random seed
+    np.random.seed(args.seed)
+
+    # Create the directory if needed
+    data_dir = os.path.join(args.data_dir, 'pixels')
+    if not os.path.isdir(data_dir):
+        os.makedirs(data_dir)
+
+    # Generate pixel coordinates
+    rand_w = np.random.randint(args.frame_size, size=args.number)
+    rand_h = np.random.randint(args.frame_size, size=args.number)
+
+    # Save coordinates
+    np.savez_compressed(os.path.join(data_dir, 'pixels.npz'), rand_w=rand_w, rand_h=rand_h)
diff --git a/var_sep/preprocessing/wave/gen_wave.py b/var_sep/preprocessing/wave/gen_wave.py
new file mode 100644
index 0000000..213e9a7
--- /dev/null
+++ b/var_sep/preprocessing/wave/gen_wave.py
@@ -0,0 +1,165 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
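The saved coordinates are meant to be read back to turn full frames into partial observations; a minimal sketch of that consumption (the loading path mirrors gen_pixels.py above, the frame array is a stand-in):

import os
import numpy as np

data_dir = '...'  # same --data_dir as passed to gen_pixels.py
pixels = np.load(os.path.join(data_dir, 'pixels', 'pixels.npz'))
rand_w, rand_h = pixels['rand_w'], pixels['rand_h']

frames = np.random.rand(300, 64, 64)  # stand-in for a WaveEq sequence
partial = frames[:, rand_w, rand_h]   # (300, 100) observed pixel values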
+
+
+import argparse
+import os
+import torch
+
+import numpy as np
+
+from functools import partial
+from torchdiffeq import odeint
+from tqdm import trange
+
+
+def decreasing_energy_source(t, invT0, f0):
+    return f0 * np.exp(-invT0 * t)
+
+
+def circle_idx(center=(32, 32), r=5):
+    cols, rows = np.meshgrid(range(64), range(64))
+    idx = (cols - center[0])**2 + (rows - center[1])**2 < r**2
+    return idx
+
+
+def derivative(t, y, source, c, nx, ny, dx, dz, circular, order):
+    """
+    The wave equation is defined by w'' = c^2 * Lap(w).
+
+    We generate the data by integrating the first-order system induced by the
+    partial differential equation on the full state (w, w'):
+
+    d/dt (w, w') = (w', c^2 * Lap(w))
+
+    This function implements the mapping:
+    (w, w') --> (w', w'')
+
+    approximating the Laplacian with a finite-difference scheme, so that the
+    system can then be integrated with torchdiffeq.
+    """
+
+    # If a circular source is to be used
+    if circular:
+        circle_mask = circle_idx().astype(float)
+        circle_mask = torch.tensor(circle_mask, dtype=torch.float)
+    else:
+        circle_mask = circle_idx(r=1).astype(float)
+        circle_mask = torch.tensor(circle_mask, dtype=torch.float)
+
+    # State: (2, 64, 64)
+    state = y[0]  # (64, 64)
+    state_diff = y[1]  # (64, 64)
+
+    state_yy = torch.zeros(state_diff.shape)
+    state_xx = torch.zeros(state_diff.shape)
+
+    # Calculate partial derivatives, be careful around the boundaries
+    if order == 3:
+        # 3-point stencil (second-order accurate)
+        for i in range(1, ny - 1):
+            state_yy[:, i] = state[:, i + 1] - 2 * state[:, i] + state[:, i - 1]
+
+        for j in range(1, nx - 1):
+            state_xx[j, :] = state[j - 1, :] - 2 * state[j, :] + state[j + 1, :]
+    elif order == 5:
+        # 5-point stencil (fourth-order accurate)
+        for i in range(2, nx - 2):
+            state_yy[:, i] = (-1 / 12 * state[:, i + 2] + 4 / 3 * state[:, i + 1] - 5 / 2 * state[:, i]
+                              + 4 / 3 * state[:, i - 1] - 1 / 12 * state[:, i - 2])
+        for j in range(2, ny - 2):
+            state_xx[j, :] = (-1 / 12 * state[j + 2, :] + 4 / 3 * state[j + 1, :] - 5 / 2 * state[j, :]
+                              + 4 / 3 * state[j - 1, :] - 1 / 12 * state[j - 2, :])
+
+    lap = (c**2) * (state_yy + state_xx) / dx**2
+
+    if source is not None:
+        lap = source(t.item()) * circle_mask + lap
+
+    derivative = torch.cat([state_diff.unsqueeze(0), lap.unsqueeze(0)], 0)
+
+    return derivative
+
+
+def generate(size, frame_size, seq_len, dt, data_dir):
+    """
+    Generates the WaveEq dataset in folder \'data\' of the given directory as pt files `homogenous_wave${INDEXSEQ}.pt`
+    where $INDEXSEQ is the index of the sequence, each containing the following fields:
+    - `simul`: float tensor of dimensions (length, width, height) representing a sequence;
+    - `c`: velocity coefficient used to create the associated sequence in `simul`.
+
+    Parameters
+    ----------
+    size : int
+        Number of sequences to generate (size of the dataset).
+    frame_size : int
+        Width and height of the sequences.
+    seq_len : int
+        Length of generated sequences.
+    dt : float
+        Step size of ODE solver and time interval between each frame.
+    data_dir : str
+        Directory where the folder `data` will be created.
+ """ + # Create the directory if needed + data_dir = os.path.join(data_dir, 'data') + if not os.path.isdir(data_dir): + os.mkdir(data_dir) + + # Generate all sequences + for i in trange(size): + # Source + source_init_value = np.random.uniform(1, 30) + source = partial(decreasing_energy_source, invT0=20, f0=source_init_value) + + # Null initial condition + initial_condition = torch.zeros(1, frame_size, frame_size) + + # Velocity coefficient + c = np.random.uniform(300, 400) + + # Numerically solving wave equation + dF = partial(derivative, source=source, c=c, nx=frame_size, ny=frame_size, dx=1, dz=1, circular=True, order=5) + t = torch.arange(0, dt * seq_len, dt) + simul = odeint(dF, y0=initial_condition.expand((2, frame_size, frame_size)), t=t, method="rk4")[:, 0] + print(simul.size()) + # Save sequences and velocities coefficients + torch.save({'simul': simul, 'c': c}, os.path.join(data_dir, f'homogenous_wave{i}.pt')) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='WaveEq preprocessing.', + description='Generates the WaveEq dataset in folder \'data\' of the given directory as pt files.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--data_dir', type=str, metavar='DIR', required=True, + help='Folder where the data will be saved.') + parser.add_argument('--seq_len', type=int, metavar='LEN', default=300, + help='Length of generated sequences.') + parser.add_argument('--seed', type=int, metavar='SEED', default=42, + help='Fixed NumPy seed to produce the same dataset at each run.') + parser.add_argument('--frame_size', type=int, metavar='SIZE', default=64, + help='Size of generated frames.') + parser.add_argument('--size', type=int, metavar='SIZE', default=300, + help='Number of sequences to generate (size of the dataset).') + parser.add_argument('--dt', type=int, metavar='SIZE', default=0.001, + help='Step size of ODE solver and time interval between each frame.') + args = parser.parse_args() + + # Fix random seed + np.random.seed(args.seed) + + # Generate dataset + generate(args.size, args.frame_size, args.seq_len, args.dt, args.data_dir) diff --git a/var_sep/test/mnist/test.py b/var_sep/test/mnist/test.py new file mode 100644 index 0000000..6f74d12 --- /dev/null +++ b/var_sep/test/mnist/test.py @@ -0,0 +1,207 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Code adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights below. + +# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import random +import torch + +import numpy as np +import torch.nn.functional as F + +from collections import defaultdict +from torch.utils.data import DataLoader +from tqdm import tqdm + +from var_sep.data.moving_mnist import MovingMNIST +from var_sep.utils.helper import load_json +from var_sep.utils.ssim import ssim_loss +from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder +from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder +from var_sep.networks.model import SeparableNetwork + + +def _ssim_wrapper(pred, gt): + bsz, nt_pred = pred.shape[0], pred.shape[1] + img_shape = pred.shape[2:] + ssim = ssim_loss(pred.reshape(bsz * nt_pred, *img_shape), gt.reshape(bsz * nt_pred, *img_shape), max_val=1., + reduction='none') + return ssim.mean(dim=[2, 3]).view(bsz, nt_pred, img_shape[0]) + + +def load_dataset(args, train=False): + return MovingMNIST.make_dataset(args.data_dir, 64, args.nt_cond, args.nt_cond + args.nt_pred, 4, True, + args.n_object, train) + + +def build_model(args): + Es = torch.load(os.path.join(args.xp_dir, 'ov_Es.pt'), map_location=args.device).to(args.device) + Et = torch.load(os.path.join(args.xp_dir, 'ov_Et.pt'), map_location=args.device).to(args.device) + t_resnet = torch.load(os.path.join(args.xp_dir, 't_resnet.pt'), map_location=args.device).to(args.device) + decoder = torch.load(os.path.join(args.xp_dir, 'decoder.pt'), map_location=args.device).to(args.device) + sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco) + sep_net.eval() + return sep_net + + +def main(args): + ################################################################################################################## + # Setup + ################################################################################################################## + # -- Device handling (CPU, GPU) + if args.device is None: + device = torch.device('cpu') + else: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) + device = torch.device('cuda:0') + torch.cuda.set_device(0) + # Seed + random.seed(args.test_seed) + np.random.seed(args.test_seed) + torch.manual_seed(args.test_seed) + # Load XP config + xp_config = load_json(os.path.join(args.xp_dir, 'params.json')) + xp_config.device = device + xp_config.data_dir = args.data_dir + xp_config.xp_dir = args.xp_dir + xp_config.nt_pred = args.nt_pred + + + ################################################################################################################## + # Load test data + ################################################################################################################## + print('Loading data...') + test_dataset = load_dataset(xp_config, train=False) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True) + train_dataset = load_dataset(xp_config, train=True) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True) + nc = 1 + size = 64 + + ################################################################################################################## + # Load model + 
##################################################################################################################
+    print('Loading model...')
+    sep_net = build_model(xp_config)
+
+    ##################################################################################################################
+    # Eval
+    ##################################################################################################################
+    print('Generating samples...')
+    torch.set_grad_enabled(False)
+    train_iterator = iter(train_loader)
+    nt_test = xp_config.nt_cond + args.nt_pred
+    predictions = []
+    content_swap = []
+    cond_swap = []
+    target_swap = []
+    cond = []
+    gt = []
+    results = defaultdict(list)
+    # Evaluation is done by batch
+    for batch in tqdm(test_loader, ncols=80, desc='evaluation'):
+        # Data
+        x_cond, x_target = batch
+        bsz = len(x_cond)
+        x_cond = x_cond.to(device)
+        x_target = x_target.to(device)
+        cond.append(x_cond.cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+        gt.append(x_target.cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+
+        # Prediction
+        x_pred, _, s_codes, _ = sep_net.get_forecast(x_cond, nt_test)
+        x_pred = x_pred[:, xp_config.nt_cond:]
+
+        # Content swap
+        x_swap_cond, x_swap_target = next(train_iterator)
+        x_swap_cond = x_swap_cond[:bsz].to(device)
+        x_swap_target = x_swap_target[:bsz].to(device)
+        x_swap_cond_byte = x_swap_cond.cpu().mul(255).byte()
+        x_swap_target_byte = x_swap_target.cpu().mul(255).byte()
+        cond_swap.append(x_swap_cond_byte.permute(0, 1, 3, 4, 2))
+        target_swap.append(x_swap_target_byte.permute(0, 1, 3, 4, 2))
+        x_swap_pred = sep_net.get_forecast(x_swap_cond, nt_test, init_s_code=s_codes[:, 0])[0]
+        x_swap_pred = x_swap_pred[:, xp_config.nt_cond:]
+        content_swap.append(x_swap_pred.cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+
+        # Pixelwise quantitative eval
+        x_target = x_target.view(-1, args.nt_pred, nc, size, size)
+        mse = torch.mean(F.mse_loss(x_pred, x_target, reduction='none'), dim=[3, 4])
+        metrics_batch = {
+            'mse': mse.mean(2).mean(1).cpu(),
+            'psnr': 10 * torch.log10(1 / mse).mean(2).mean(1).cpu(),
+            'ssim': _ssim_wrapper(x_pred, x_target).mean(2).mean(1).cpu()
+        }
+        predictions.append(x_pred.cpu().mul(255).byte().permute(0, 1, 3, 4, 2))
+
+        # Register batch metrics
+        for name in metrics_batch.keys():
+            results[name].append(metrics_batch[name])
+
+    ##################################################################################################################
+    # Print results
+    ##################################################################################################################
+    print('\n')
+    print('Results:')
+    for name in results.keys():
+        res = torch.cat(results[name]).numpy()
+        results[name] = res
+        print(name, res.mean(), '+/-', 1.960 * res.std() / np.sqrt(len(res)))
+
+    ##################################################################################################################
+    # Save samples
+    ##################################################################################################################
+    np.savez_compressed(os.path.join(args.xp_dir, 'results.npz'), **results)
+    np.savez_compressed(os.path.join(args.xp_dir, 'predictions.npz'), predictions=torch.cat(predictions).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'gt.npz'), gt=torch.cat(gt).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'cond.npz'), cond=torch.cat(cond).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'content_swap.npz'), content_swap=torch.cat(content_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'cond_swap.npz'), cond_swap=torch.cat(cond_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'target_swap.npz'), target_swap=torch.cat(target_swap).numpy())
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (Moving MNIST testing)")
+    p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the dataset is saved.')
+    p.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the model configuration file and checkpoints are saved.')
+    p.add_argument('--batch_size', type=int, metavar='BATCH', default=16,
+                   help='Batch size used to compute metrics.')
+    p.add_argument('--nt_pred', type=int, metavar='PRED', required=True,
+                   help='Total number of frames to predict.')
+    p.add_argument('--device', type=int, metavar='DEVICE', default=None,
+                   help='GPU where the model should be placed when testing (if None, put it on the CPU).')
+    p.add_argument('--test_seed', type=int, metavar='SEED', default=1,
+                   help='Manual seed.')
+    args = p.parse_args()
+    main(args)
diff --git a/var_sep/test/mnist/test_disentanglement.py b/var_sep/test/mnist/test_disentanglement.py
new file mode 100644
index 0000000..c7829cc
--- /dev/null
+++ b/var_sep/test/mnist/test_disentanglement.py
@@ -0,0 +1,242 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Code heavily modified from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights below.

+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
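Both Moving MNIST testing scripts compute PSNR directly from the per-frame MSE; since the maximum pixel value is 1 here, PSNR reduces to 10 * log10(1 / MSE), e.g.:

import torch

# With max_val = 1, PSNR = 10 * log10(1 / MSE).
mse = torch.tensor([0.01, 0.001])
psnr = 10 * torch.log10(1 / mse)  # tensor([20., 30.]) in dB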
+ + +import argparse +import os +import random +import torch +import math +import itertools + +import numpy as np +import torch.nn.functional as F + +from collections import defaultdict +from torch.utils.data import DataLoader, Dataset +from torchvision import datasets +from tqdm import tqdm + +from var_sep.data.moving_mnist import MovingMNIST +from var_sep.utils.helper import DotDict, load_json +from var_sep.utils.ssim import ssim_loss +from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder +from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder +from var_sep.networks.model import SeparableNetwork + + +def _ssim_wrapper(pred, gt): + bsz, nt_pred = pred.shape[0], pred.shape[1] + img_shape = pred.shape[2:] + ssim = ssim_loss(pred.reshape(bsz * nt_pred, *img_shape), gt.reshape(bsz * nt_pred, *img_shape), max_val=1., reduction='none') + return ssim.mean(dim=[2, 3]).view(bsz, nt_pred, img_shape[0]) + + +class SwapDataset(Dataset): + + def __init__(self, data_dir, seq_len, nt_cond, n_object): + self.seq_len = seq_len + self.n_object = n_object + self.nt_cond = nt_cond + self.frame_size = 64 + self.object_size = 28 + self.digits_permutation = np.random.permutation(10000) + self.trajectories = np.load(os.path.join(data_dir, f'mmnist_test_{n_object}digits_{self.frame_size}.npz'), + allow_pickle=True)['latents'] + self.images = datasets.MNIST(data_dir, train=False, download=True) + + def __len__(self): + return 10000 // self.n_object + + def __getitem__(self, index): + # get trajectory + x_trajectory_reverse = np.zeros((self.seq_len, 1, self.frame_size, self.frame_size), dtype=np.float32) + x_swap = np.zeros((math.factorial(self.n_object), self.seq_len, 1, self.frame_size, self.frame_size), + dtype=np.float32) + img = [self.images[self.digits_permutation[index + i * (10000 // self.n_object)]][0] + for i in range(self.n_object)] + trajectory = self.trajectories[:, index] + trajectory_reverse = self.trajectories[:, len(self) - index - 1] + for t in range(self.seq_len): + for i in range(self.n_object): + sx, sy, _, _ = trajectory_reverse[t, i] + x_trajectory_reverse[t, 0, sx:sx + self.object_size, sy:sy + self.object_size] += img[i] + for j, reordering in enumerate(itertools.permutations(range(self.n_object))): + for i in range(self.n_object): + sx, sy, _, _ = trajectory[t, i] + x_swap[j, t, 0, sx:sx + self.object_size, sy:sy + self.object_size] += img[reordering[i]] + x_trajectory_reverse[x_trajectory_reverse > 255] = 255 + x_swap[x_swap > 255] = 255 + return (torch.tensor(x_trajectory_reverse[:self.nt_cond]) / 255, + torch.tensor(x_trajectory_reverse[self.nt_cond:]) / 255, + torch.tensor(x_swap[:, :self.nt_cond]) / 255, torch.tensor(x_swap[:, self.nt_cond:]) / 255) + + +def load_dataset(args, train=False): + return MovingMNIST.make_dataset(args.data_dir, 64, args.nt_cond, args.nt_cond + args.nt_pred, 4, True, + args.n_object, train) + + +def build_model(args): + Es = torch.load(os.path.join(args.xp_dir, 'ov_Es.pt'), map_location=args.device).to(args.device) + Et = torch.load(os.path.join(args.xp_dir, 'ov_Et.pt'), map_location=args.device).to(args.device) + t_resnet = torch.load(os.path.join(args.xp_dir, 't_resnet.pt'), map_location=args.device).to(args.device) + decoder = torch.load(os.path.join(args.xp_dir, 'decoder.pt'), map_location=args.device).to(args.device) + sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco) + sep_net.eval() + return sep_net + + +def main(args): + 
################################################################################################################## + # Setup + ################################################################################################################## + # -- Device handling (CPU, GPU) + if args.device is None: + device = torch.device('cpu') + else: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) + device = torch.device('cuda:0') + torch.cuda.set_device(0) + # Seed + random.seed(args.test_seed) + np.random.seed(args.test_seed) + torch.manual_seed(args.test_seed) + # Load XP config + xp_config = load_json(os.path.join(args.xp_dir, 'params.json')) + xp_config.device = device + xp_config.data_dir = args.data_dir + xp_config.xp_dir = args.xp_dir + xp_config.nt_pred = args.nt_pred + + ################################################################################################################## + # Load test data + ################################################################################################################## + print('Loading data...') + test_dataset = load_dataset(xp_config, train=False) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True) + swap_dataset = SwapDataset(args.data_dir, xp_config.nt_cond + args.nt_pred, xp_config.nt_cond, xp_config.n_object) + swap_loader = DataLoader(swap_dataset, batch_size=args.batch_size, pin_memory=True) + nc = 1 + size = 64 + + ################################################################################################################## + # Load model + ################################################################################################################## + print('Loading model...') + sep_net = build_model(xp_config) + + ################################################################################################################## + # Eval + ################################################################################################################## + print('Generating samples...') + torch.set_grad_enabled(False) + swap_iterator = iter(swap_loader) + nt_test = xp_config.nt_cond + args.nt_pred + gt_swap = [] + content_swap = [] + cond_swap = [] + target_swap = [] + results = defaultdict(list) + # Evaluation is done by batch + for batch in tqdm(test_loader, ncols=80, desc='evaluation'): + # Data + x_cond, x_target, _, x_gt_swap = next(swap_iterator) + x_gt_swap = x_gt_swap.to(device) + x_cond = x_cond.to(device) + + # Extraction of S + _, _, s_codes, _ = sep_net.get_forecast(x_cond, nt_test) + + # Content swap + x_swap_cond, x_swap_target = batch + x_swap_cond = x_swap_cond.to(device) + x_swap_target = x_swap_target.to(device) + x_swap_cond_byte = x_cond.cpu().mul(255).byte() + x_swap_target_byte = x_swap_target.cpu().mul(255).byte() + cond_swap.append(x_swap_cond_byte.permute(0, 1, 3, 4, 2)) + target_swap.append(x_swap_target_byte.permute(0, 1, 3, 4, 2)) + x_swap_pred = sep_net.get_forecast(x_swap_cond, nt_test, init_s_code=s_codes[:, 0])[0] + x_swap_pred = x_swap_pred[:, xp_config.nt_cond:] + content_swap.append(x_swap_pred.cpu().mul(255).byte().permute(0, 1, 3, 4, 2)) + gt_swap.append(x_gt_swap[:, 0].cpu().mul(255).byte().permute(0, 1, 3, 4, 2)) + + # Pixelwise quantitative eval + x_gt_swap = x_gt_swap.view(-1, xp_config.n_object, args.nt_pred, nc, size, size).to(device) + metrics_batch = {'mse': [], 'psnr': [], 'ssim': []} + for j, reordering in enumerate(itertools.permutations(range(xp_config.n_object))): + mse = torch.mean(F.mse_loss(x_swap_pred, x_gt_swap[:, j], 
reduction='none'), dim=[3, 4])
+            metrics_batch['mse'].append(mse.mean(2).mean(1).cpu())
+            metrics_batch['psnr'].append(10 * torch.log10(1 / mse).mean(2).mean(1).cpu())
+            metrics_batch['ssim'].append(_ssim_wrapper(x_swap_pred, x_gt_swap[:, j]).mean(2).mean(1).cpu())
+
+        # Compute metrics for best samples and register
+        results['mse'].append(torch.min(torch.stack(metrics_batch['mse']), 0)[0])
+        results['psnr'].append(torch.max(torch.stack(metrics_batch['psnr']), 0)[0])
+        results['ssim'].append(torch.max(torch.stack(metrics_batch['ssim']), 0)[0])
+
+    ##################################################################################################################
+    # Print results
+    ##################################################################################################################
+    print('\n')
+    print('Results:')
+    for name in results.keys():
+        res = torch.cat(results[name]).numpy()
+        results[name] = res
+        print(name, res.mean(), '+/-', 1.960 * res.std() / np.sqrt(len(res)))
+
+    ##################################################################################################################
+    # Save samples
+    ##################################################################################################################
+    np.savez_compressed(os.path.join(args.xp_dir, 'results_swap.npz'), **results)
+    np.savez_compressed(os.path.join(args.xp_dir, 'content_swap_gt.npz'), gt_swap=torch.cat(gt_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'content_swap_test.npz'),
+                        content_swap=torch.cat(content_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'cond_swap_test.npz'), cond_swap=torch.cat(cond_swap).numpy())
+    np.savez_compressed(os.path.join(args.xp_dir, 'target_swap_test.npz'),
+                        target_swap=torch.cat(target_swap).numpy())
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (Moving MNIST content swap testing)")
+    p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the dataset is saved.')
+    p.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the model configuration file and checkpoints are saved.')
+    p.add_argument('--batch_size', type=int, metavar='BATCH', default=16,
+                   help='Batch size used to compute metrics.')
+    p.add_argument('--nt_pred', type=int, metavar='PRED', required=True,
+                   help='Total number of frames to predict.')
+    p.add_argument('--device', type=int, metavar='DEVICE', default=None,
+                   help='GPU where the model should be placed when testing (if None, put it on the CPU).')
+    p.add_argument('--test_seed', type=int, metavar='SEED', default=1,
+                   help='Manual seed.')
+    args = DotDict(vars(p.parse_args()))
+    main(args)
diff --git a/var_sep/test/sst/test.py b/var_sep/test/sst/test.py
new file mode 100644
index 0000000..ecc7faf
--- /dev/null
+++ b/var_sep/test/sst/test.py
@@ -0,0 +1,106 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+import torch
+
+import numpy as np
+
+from tqdm import tqdm
+
+from var_sep.data.sst import SST
+from var_sep.utils.helper import DotDict, load_json
+from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder
+from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder
+from var_sep.networks.model import SeparableNetwork
+
+
+def load_dataset(args, train=False, zones=range(17, 21)):
+    return SST(args.data_dir, args.nt_cond, args.nt_pred, train, zones=zones, eval=True)
+
+
+def build_model(args):
+    Es = torch.load(os.path.join(args.xp_dir, 'ov_Es.pt'), map_location=args.device).to(args.device)
+    Et = torch.load(os.path.join(args.xp_dir, 'ov_Et.pt'), map_location=args.device).to(args.device)
+    t_resnet = torch.load(os.path.join(args.xp_dir, 't_resnet.pt'), map_location=args.device).to(args.device)
+    decoder = torch.load(os.path.join(args.xp_dir, 'decoder.pt'), map_location=args.device).to(args.device)
+    sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco)
+    sep_net.eval()
+    return sep_net
+
+
+def compute_mse(args, test_set, sep_net):
+    all_mse = []
+    torch.set_grad_enabled(False)
+    for cond, target, mu_clim, std_clim, mu_norm, std_norm in tqdm(test_set):
+        cond, target = cond.unsqueeze(0).to(args.device), target.unsqueeze(0).to(args.device)
+        if args.offset:
+            forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, target.size(1) + args.nt_cond)
+            forecasts = forecasts[:, args.nt_cond:]
+        else:
+            forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, target.size(1))
+
+        mu_norm, std_norm = (torch.tensor(mu_norm, dtype=torch.float).to(args.device),
+                             torch.tensor(std_norm, dtype=torch.float).to(args.device))
+
+        forecasts = (forecasts * std_norm) + mu_norm
+        target = (target * std_norm) + mu_norm
+
+        # Original space
+        mu_clim, std_clim = (torch.tensor(mu_clim, dtype=torch.float).to(args.device),
+                             torch.tensor(std_clim, dtype=torch.float).to(args.device))
+        forecasts = (forecasts * std_clim) + mu_clim
+        target = (target * std_clim) + mu_clim
+
+        mse = (forecasts - target).pow(2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
+
+        all_mse.append(mse.cpu().numpy())
+
+    return all_mse
+
+
+def main(args):
+    if args.device is None:
+        device = torch.device('cpu')
+    else:
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
+        device = torch.device('cuda:0')
+        torch.cuda.set_device(0)
+    # Load XP config
+    xp_config = load_json(os.path.join(args.xp_dir, 'params.json'))
+    xp_config.device = device
+    xp_config.nt_pred = 10
+    args.nt_pred = 10
+
+    test_set = load_dataset(xp_config, train=False)
+    sep_net = build_model(xp_config)
+
+    all_mse = compute_mse(xp_config, test_set, sep_net)
+    mse_array = np.concatenate(all_mse, axis=0)
+    print(f'Average MSE up to t+10: {np.mean(mse_array.mean(axis=0)[:10])}')
+    print(f'Average MSE up to t+6: {np.mean(mse_array.mean(axis=0)[:6])}')
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (SST testing)")
+    p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the dataset is saved.')
+    p.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the model configuration file and checkpoints are saved.')
+    p.add_argument('--device', type=int, metavar='DEVICE', default=None,
+                   help='GPU where the model should be placed
when testing (if None, put it on the CPU)') + args = DotDict(vars(p.parse_args())) + main(args) diff --git a/var_sep/test/wave/test.py b/var_sep/test/wave/test.py new file mode 100644 index 0000000..847373b --- /dev/null +++ b/var_sep/test/wave/test.py @@ -0,0 +1,106 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import torch + +import numpy as np + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from var_sep.data.wave_eq import WaveEq, WaveEqPartial +from var_sep.utils.helper import DotDict, load_json +from var_sep.networks.conv import DCGAN64Encoder, VGG64Encoder, DCGAN64Decoder, VGG64Decoder +from var_sep.networks.mlp_encdec import MLPEncoder, MLPDecoder +from var_sep.networks.model import SeparableNetwork + + +def load_dataset(args, train=False): + if args.data == 'wave': + return WaveEq(args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred, train, args.downsample) + else: + return WaveEqPartial(args.data_dir, args.nt_cond, args.nt_cond + args.nt_pred, True, args.downsample, + args.n_wave_points) + + +def build_model(args): + Es = torch.load(os.path.join(args.xp_dir, 'ov_Es.pt'), map_location=args.device).to(args.device) + Et = torch.load(os.path.join(args.xp_dir, 'ov_Et.pt'), map_location=args.device).to(args.device) + t_resnet = torch.load(os.path.join(args.xp_dir, 't_resnet.pt'), map_location=args.device).to(args.device) + decoder = torch.load(os.path.join(args.xp_dir, 'decoder.pt'), map_location=args.device).to(args.device) + sep_net = SeparableNetwork(Es, Et, t_resnet, decoder, args.nt_cond, args.skipco) + sep_net.eval() + return sep_net + + +def compute_mse(args, batch_size, test_set, sep_net): + all_mse = [] + loader = DataLoader(test_set, batch_size=batch_size, pin_memory=False, shuffle=False, num_workers=3) + torch.set_grad_enabled(False) + for cond, target in tqdm(loader): + cond, target = cond.to(args.device), target.to(args.device) + if args.offset: + forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, target.size(1) + args.nt_cond) + forecasts = forecasts[:, args.nt_cond:] + else: + forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, target.size(1)) + + forecasts = forecasts.view(target.shape) + + if args.data == 'wave': + mse = (forecasts - target).pow(2).mean(dim=-1).mean(dim=-1).mean(dim=-1) + else: + mse = (forecasts - target).pow(2).mean(dim=-1) + + all_mse.append(mse.data.cpu().numpy()) + + return all_mse + + +def main(args): + if args.device is None: + device = torch.device('cpu') + else: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) + device = torch.device('cuda:0') + torch.cuda.set_device(0) + # Load XP config + xp_config = load_json(os.path.join(args.xp_dir, 'params.json')) + xp_config.device = device + xp_config.nt_pred = 40 + args.nt_pred = 40 + + test_set = load_dataset(xp_config, train=False) + sep_net = build_model(xp_config) + + all_mse = compute_mse(xp_config, 
args.batch_size, test_set, sep_net)
+    mse_array = np.concatenate(all_mse, axis=0)
+    print(f'Average MSE up to t+40: {np.mean(mse_array.mean(axis=0)[:40])}')
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser(prog="PDE-Driven Spatiotemporal Disentanglement (WaveEq testing)")
+    p.add_argument('--data_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the dataset is saved.')
+    p.add_argument('--xp_dir', type=str, metavar='DIR', required=True,
+                   help='Directory where the model configuration file and checkpoints are saved.')
+    p.add_argument('--batch_size', type=int, metavar='BATCH', default=256,
+                   help='Batch size used to compute metrics.')
+    p.add_argument('--device', type=int, metavar='DEVICE', default=None,
+                   help='GPU where the model should be placed when testing (if None, put it on the CPU).')
+    args = DotDict(vars(p.parse_args()))
+    main(args)
diff --git a/var_sep/train.py b/var_sep/train.py
new file mode 100644
index 0000000..3e1b92a
--- /dev/null
+++ b/var_sep/train.py
@@ -0,0 +1,158 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at

+# http://www.apache.org/licenses/LICENSE-2.0

+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+import numpy as np
+import torch.nn.functional as F
+
+from tqdm import tqdm
+
+from var_sep.utils.helper import save
+
+
+# Apex
+try:
+    from apex import amp
+    APX_ = 1
+except Exception:
+    APX_ = 0
+
+
+def zero_order_loss(s_code_old, s_code_new, skipco):
+    if skipco:
+        s_code_old = s_code_old[0]
+        s_code_new = s_code_new[0]
+    return (s_code_old - s_code_new).pow(2).mean()
+
+
+def ae_loss(cond, target, sep_net, nt_cond, offset, skipco):
+    """
+    Autoencoding loss. Two cases are considered:
+
+    if offset == nt_cond:
+        St = S_{t, t+1, ..., t+nt_cond}, Tt = T_{t, t+1, ..., t+nt_cond} are associated to time t, i.e.
+        D(St, Tt) = v_t,
+        which amounts to a backward inference of the first observation of the window.
+
+    if offset == 0:
+        St = S_{t, t+1, ..., t+nt_cond}, Tt = T_{t, t+1, ..., t+nt_cond} are associated to time t + nt_cond, i.e.
+        D(St, Tt) = v_{t + nt_cond},
+        which amounts to estimating how the dynamics has moved from t up to t + nt_cond.
+
+    This function also returns the result of the application of Es on the first and last seen frames.
+ """ + + full_data = torch.cat([cond, target], dim=1) + data_new = full_data[:, -nt_cond:] + data_old = full_data[:, :nt_cond] + + # Encode spatial information + s_code_old = sep_net.Es(data_old, return_skip=skipco) + s_code_new = sep_net.Es(data_new, return_skip=skipco) + + # Encode motion information at a random time + if offset == 0: + t_random = np.random.randint(nt_cond, full_data.size(1)) + else: + t_random = np.random.randint(nt_cond, full_data.size(1) + 1) + t_code_random = sep_net.Et(full_data[:, t_random - nt_cond:t_random]) + + # Decode from S and random T + if skipco: + reconstruction = sep_net.decoder(s_code_old[0], t_code_random, skip=s_code_old[1]) + else: + reconstruction = sep_net.decoder(s_code_old, t_code_random) + + # AE loss + supervision_data = full_data[:, t_random - offset] + loss = F.mse_loss(supervision_data, reconstruction, reduction='mean') + + return loss, s_code_new, s_code_old + + +def train(xp_dir, train_loader, device, sep_net, optimizer, scheduler, apex_amp, epochs, lamb_ae, lamb_s, lamb_t, + lamb_pred, offset, nt_cond, nt_pred, no_s, skipco): + + if apex_amp and not APX_: + raise ImportError('Apex not installed (https://github.com/NVIDIA/apex)') + + if apex_amp: + sep_net, optimizer = amp.initialize(sep_net, optimizer, opt_level='O1') + + if no_s: + lamb_t = 0 + print("No regularization on T as there is no S") + + assert offset == nt_cond or offset == 0 + + try: + for epoch in range(epochs): + + sep_net.train() + for bt, (cond, target) in enumerate(tqdm(train_loader)): + cond, target = cond.to(device), target.to(device) + total_loss = 0 + + optimizer.zero_grad() + + # ########## + # AUTOENCODE + # ########## + ae_loss_value, s_recent, s_old = ae_loss(cond, target, sep_net, nt_cond, offset, skipco) + total_loss += lamb_ae * ae_loss_value + + # ################## + # SPATIAL INVARIANCE + # ################## + spatial_ode_loss = zero_order_loss(s_old, s_recent, skipco) + total_loss += lamb_s * spatial_ode_loss + + # ############# + # FORECAST LOSS + # ############# + full_data = torch.cat([cond, target], dim=1) # Concatenate all frames + forecasts, t_codes, s_codes, t_residuals = sep_net.get_forecast(cond, nt_pred + offset, init_s_code=s_old) + # To make data and target match + if offset == 0: + forecast_offset = nt_cond + else: + forecast_offset = 0 + forecast_loss = F.mse_loss(forecasts, full_data[:, forecast_offset:]) + total_loss += lamb_pred * forecast_loss + + # ################ + # T REGULARIZATION + # ################ + t_reg = 0.5 * torch.sum(t_codes[:, 0].pow(2), dim=1).mean() + total_loss += lamb_t * t_reg + + if apex_amp: + with amp.scale_loss(total_loss, optimizer) as scaled_loss: + scaled_loss.backward() + + else: + total_loss.backward() + + optimizer.step() + + if scheduler is not None: + scheduler.step() + + except KeyboardInterrupt: + pass + + save(xp_dir, sep_net) diff --git a/var_sep/utils/helper.py b/var_sep/utils/helper.py new file mode 100644 index 0000000..78603a1 --- /dev/null +++ b/var_sep/utils/helper.py @@ -0,0 +1,77 @@ +# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/var_sep/utils/helper.py b/var_sep/utils/helper.py
new file mode 100644
index 0000000..78603a1
--- /dev/null
+++ b/var_sep/utils/helper.py
@@ -0,0 +1,77 @@
+# Copyright 2020 Jérémie Donà, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import os
+import torch
+import yaml
+
+
+def save(elem_xp_path, sep_net):
+    # Retry until every submodule has been written; catching Exception (rather
+    # than a bare except) avoids swallowing KeyboardInterrupt and SystemExit.
+    to_save = True
+    while to_save:
+        try:
+            torch.save(sep_net.Et, os.path.join(elem_xp_path, 'ov_Et.pt'))
+            torch.save(sep_net.Es, os.path.join(elem_xp_path, 'ov_Es.pt'))
+            torch.save(sep_net.decoder, os.path.join(elem_xp_path, 'decoder.pt'))
+            torch.save(sep_net.t_resnet, os.path.join(elem_xp_path, 't_resnet.pt'))
+            to_save = False
+        except Exception as e:
+            print(f'Unable to save all files ({e}); retrying')
+
+
+# The following code is adapted from SRVP https://github.com/edouardelasalles/srvp; see license notice and copyrights
+# below.
+
+# Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class DotDict(dict):
+    """
+    Dot notation access to dictionary attributes.
+    """
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def load_yaml(path):
+    """
+    Loads a yaml input file.
+    """
+    with open(path, 'r') as f:
+        opt = yaml.safe_load(f)
+    return DotDict(opt)
+
+
+def load_json(path):
+    """
+    Loads a json input file.
+    """
+    with open(path, 'r') as f:
+        opt = json.load(f)
+    return DotDict(opt)
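
Since the scripts above wrap parsed arguments and configuration files in DotDict, attribute access is just dictionary access; note that, because __getattr__ is bound to dict.get, a missing key silently yields None rather than raising AttributeError. A small usage sketch, assuming var_sep is importable:

from var_sep.utils.helper import DotDict

opt = DotDict({'batch_size': 256, 'device': None})
print(opt.batch_size)   # 256, equivalent to opt['batch_size']
opt.device = 0          # equivalent to opt['device'] = 0
print(opt.unknown_key)  # None: dict.get does not raise on missing keys
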
diff --git a/var_sep/utils/ssim.py b/var_sep/utils/ssim.py
new file mode 100644
index 0000000..daacc6c
--- /dev/null
+++ b/var_sep/utils/ssim.py
@@ -0,0 +1,149 @@
+# Code from PyTorch PR https://github.com/pytorch/pytorch/pull/22289, see license and copyrights below.
+
+# From PyTorch:
+
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+# From Caffe2:
+
+# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+# All contributions by Facebook:
+# Copyright (c) 2016 Facebook Inc.
+
+# All contributions by Google:
+# Copyright (c) 2015 Google Inc.
+# All rights reserved.
+
+# All contributions by Yangqing Jia:
+# Copyright (c) 2015 Yangqing Jia
+# All rights reserved.
+
+# All contributions from Caffe:
+# Copyright(c) 2013, 2014, 2015, the respective contributors
+# All rights reserved.
+
+# All other contributions:
+# Copyright(c) 2015, 2016 the respective contributors
+# All rights reserved.
+
+# Caffe2 uses a copyright model similar to Caffe: each contributor holds
+# copyright over their contributions to Caffe2. The project versioning records
+# all such contribution and copyright details. If a contributor wants to further
+# mark their specific copyright on a particular contribution, they should
+# indicate their copyright solely in the commit message of the change when it is
+# committed.
+
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+
+# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+#    and IDIAP Research Institute nor the names of its contributors may be
+#    used to endorse or promote products derived from this software without
+#    specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+
+import torch
+
+from torch.nn import _reduction as _Reduction
+from torch.nn.functional import conv2d
+
+
+def _fspecial_gaussian(size, channel, sigma):
+    # The softmax of -d^2 / (2 * sigma^2) normalizes exp(-d^2 / (2 * sigma^2))
+    # to sum to one, yielding a 2D Gaussian window of the requested size.
+    coords = torch.tensor([(x - (size - 1.) / 2.) for x in range(size)])
+    coords = -coords ** 2 / (2. * sigma ** 2)
+    grid = coords.view(1, -1) + coords.view(-1, 1)
+    grid = grid.view(1, -1)
+    grid = grid.softmax(-1)
+    kernel = grid.view(1, 1, size, size)
+    kernel = kernel.expand(channel, 1, size, size).contiguous()
+    return kernel
+
+
+def _ssim(input, target, max_val, k1, k2, channel, kernel):
+    c1 = (k1 * max_val) ** 2
+    c2 = (k2 * max_val) ** 2
+
+    mu1 = conv2d(input, kernel, groups=channel)
+    mu2 = conv2d(target, kernel, groups=channel)
+
+    mu1_sq = mu1 ** 2
+    mu2_sq = mu2 ** 2
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = conv2d(input * input, kernel, groups=channel) - mu1_sq
+    sigma2_sq = conv2d(target * target, kernel, groups=channel) - mu2_sq
+    sigma12 = conv2d(input * target, kernel, groups=channel) - mu1_mu2
+
+    v1 = 2 * sigma12 + c2
+    v2 = sigma1_sq + sigma2_sq + c2
+
+    ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
+    return ssim, v1 / v2
+
+
+def ssim_loss(input, target, max_val, filter_size=11, k1=0.01, k2=0.03,
+              sigma=1.5, kernel=None, size_average=None, reduce=None, reduction='mean'):
+    r"""ssim_loss(input, target, max_val, filter_size, k1, k2,
+    sigma, kernel=None, size_average=None, reduce=None, reduction='mean') -> Tensor
+    Measures the structural similarity index (SSIM) error.
+    See :class:`~torch.nn.SSIMLoss` for details.
+    """
+    if input.size() != target.size():
+        raise ValueError('Expected input size ({}) to match target size ({}).'
+                         .format(input.size(), target.size()))
+
+    if size_average is not None or reduce is not None:
+        reduction = _Reduction.legacy_get_string(size_average, reduce)
+
+    dim = input.dim()
+    if dim == 2:
+        input = input.expand(1, 1, input.size(-2), input.size(-1))
+        target = target.expand(1, 1, target.size(-2), target.size(-1))
+    elif dim == 3:
+        input = input.expand(1, input.size(-3), input.size(-2), input.size(-1))
+        target = target.expand(1, target.size(-3), target.size(-2), target.size(-1))
+    elif dim != 4:
+        raise ValueError('Expected 2, 3, or 4 dimensions (got {})'.format(dim))
+
+    _, channel, _, _ = input.size()
+
+    if kernel is None:
+        kernel = _fspecial_gaussian(filter_size, channel, sigma)
+    kernel = kernel.to(device=input.device)
+
+    ret, _ = _ssim(input, target, max_val, k1, k2, channel, kernel)
+
+    if reduction != 'none':
+        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
+    return ret
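
As a usage note, ssim_loss returns a similarity score (1 for identical images) rather than a distance, expects inputs in (N, C, H, W) layout (2D and 3D inputs are promoted automatically), and max_val should match the dynamic range of the data. A minimal runnable example, assuming var_sep is importable (the tensor shapes are illustrative):

import torch

from var_sep.utils.ssim import ssim_loss

x = torch.rand(4, 1, 64, 64)                      # images in [0, 1]
y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)  # slightly corrupted copy
print(ssim_loss(x, x, max_val=1.))                # 1.0: identical images
print(ssim_loss(x, y, max_val=1.))                # < 1.0: degraded similarity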