diff --git a/configs/ilo_single_file.yaml b/configs/ilo_single_file.yaml new file mode 100644 index 00000000..c8a77939 --- /dev/null +++ b/configs/ilo_single_file.yaml @@ -0,0 +1,182 @@ +hydra: + job: + chdir: True # change to output folder + +save: True + +simulation: + object_height: 0.3 + # these distance parameters are typically fixed for a given PSF + scene2mask: 40e-2 + mask2sensor: 4e-3 + # see waveprop.devices + sensor: "rpi_hq" + snr_db: 20 + # Downsampling for PSF + downsample: 8 + + # max val in simulated measured (quantized 8 bits) + max_val: 230 + + image_format: RGB + + flatcam: False # only supported if mask.type is "MURA" or "MLS" + + + + +# COPIED FROM HERE: https://github.com/ebezzam/ilo_lensless/blob/ilo_lensless/configs/config_default.yaml +# TODO PRUNE +# Inputs/Outputs directories and files +files: + original_fp: data/celeba_mini/000019.jpg + psf: data/psf/tape_rgb.png + + #preprocess_dir: files/demo/preprocessed/ + preprocess_dir: data/celeba_mini_aligned/ + #psf_dir: files/demo/psf/ + output_dir: ./output/ + files_ext: '.png' + #pandas_file: ./metrics.pkl + + face_aligner: 'models/ilo/shape_predictor_68_face_landmarks.dat' + +mask: + type: "tape" # "MURA", "MLS", "FZA", "PhaseContour" + + # Coded Aperture (MURA or MLS) + #flatcam_method: 'MLS' + n_bits: 8 # e.g. 8 for MLS, 99 for MURA + + # Phase Contour + noise_period: [16, 16] + refractive_index: 1.2 + phase_mask_iter: 10 + + # Fresnel Zone Aperture + radius: 0.32e-3 + +lensless_imaging: + bool: false + psf_path: data/psf/tape_rgb.png # if simulated false + + # PSF parameters (TODO consistent with simulation) + scene2mask: 40e-2 #m + mask2sensor: 4e-3 #m + object_height: 0.45 #m, TODO: made big to fit entire region + sensor: 'rpi_hq' + psf_size: [3040, 4056] + downsample: 8 + image_format: RGB + + # Simulation parameters + simulated: + bool: False + gaussian_noise: True + snr: 40.0 + max_val: 255 + flatcam: False + +# Preprocessing actions +preprocessing: + align: + bool: False + resize: + bool: True + image_size: [1024, 1024] + grayscale: # TODO : works but PSF grayscale? how to do ? + bool: False + mask: + bool: False + bounding_box: + horizontal: + - 200 + - 400 + vertical: + - 200 + - 400 + noise: + bool: False + mode: gaussian #pepper #s&p #poisson #salt + mean: 0 + var: 0.01 + amount: 0.05 + salt_vs_pepper: 0.5 + +# Tasks to perform +task: + grayscale: + bool: False # TODO + +# Model configuration +# only StyleGAN is supported for now. +model: + type: 'stylegan2' + checkpoint: models/ilo/stylegan2-ffhq-config-f.pt + mapping: models/ilo/gaussian_fit.pt + +# Optimization parameters +opti_params: + seed: 42 + device: cuda + + # batchsizes + batchsize_process : 1 + batchsize_preprocess : 50 + + # Range of layers to optimize (no need to touch it) + start_layer: 0 + end_layer: 8 + + # steps per layer -> define which layer is optimized + # if you want to skip optimization in some layers, just use 0 to the corresponding indices of steps. 
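+  # Example (hypothetical values): with start_layer: 0 and steps: [50, 50, 0, 50],
+  # round 1 optimizes from layer 0, round 2 from layer 1, round 3 performs no
+  # optimization but still computes the intermediate output passed to round 4,
+  # which then optimizes from layer 3. len(steps) may not exceed
+  # end_layer - start_layer + 1 (i.e. at most 9 with the values above).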
+ # steps: [50, 50, 50, 50, 50, 50, 50, 50, 50] # up to 9, but need to extend max_radius_* + steps: [50, 50, 50, 50, 50, 50] + + # learning rate per layer + lr: [0.1, 0.08, 0.06, 0.04, 0.02, 0.01, 0.01, 0.01, 0.01] + + # whether to schedule per layer or in total + lr_same_pace: False + + # project latents to unit ball + project: False + # project: True + + # projections (decent results with false, toilsome to fine-tune) + do_project_latent: False + do_project_noises: False + do_project_gen_out: False + + max_radius_latent: [300, 500, 1000, 2000, 4000, 8000, 8000, 8000, 8000] + max_radius_noises: [300, 2000, 2000, 4000, 6000, 8000, 8000, 8000, 8000] + max_radius_gen_out: [0, 500, 1000, 2000, 4000, 8000, 8000, 8000, 8000] + # note: first value of max_radius_gen_out is not used + + +# Loss parameters +loss_params: + ## weights of different losses + geocross: 0.01 + mse: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + pe: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + dead_zone_linear: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + # tolerance of dead zone linear function + dead_zone_linear_alpha: 0.05 + # LPIPS method + lpips_method: 'default' + + +logs: + # Save of logs + # if true, intermediate frames from optimization are saved. + save_gif: False + # determines how often we save intermediate frames in each steps. Activated only if save_gif=True. + save_every: 50 + + # Forward of generated images + # if true, save it + save_forward: True + # dir of forward save + forward_dir: ./output_forward/ + diff --git a/ilo_requirements.txt b/ilo_requirements.txt new file mode 100644 index 00000000..c1e6b201 --- /dev/null +++ b/ilo_requirements.txt @@ -0,0 +1,2 @@ +dlib>==19.24.2 +ninja==1.10.2.3 \ No newline at end of file diff --git a/lensless/recon/ilo_stylegan2/ilo.py b/lensless/recon/ilo_stylegan2/ilo.py new file mode 100644 index 00000000..eccd9afb --- /dev/null +++ b/lensless/recon/ilo_stylegan2/ilo.py @@ -0,0 +1,482 @@ +""" +TODO : authorship from original ILO +https://github.com/giannisdaras/ilo/blob/master/ilo_stylegan.py +""" + +import os +import numpy as np +import math +from tqdm import tqdm +from PIL import Image +from hydra.utils import to_absolute_path +import torch +from torch import optim +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from torchvision import transforms +from . 
import ( + lpips as lpips, +) # TODO : linking needs to be fixed, or use LPIPS from torchmetrics: https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html +from .stylegan2 import Generator +from .utils import project_onto_l1_ball, zero_padding_tensor +from waveprop.devices import sensor_dict, SensorParam +from lensless.recon.rfft_convolve import RealFFTConvolve2D +from waveprop.simulation import FarFieldSimulator + + +torch.set_printoptions(precision=5) +torch.autograd.set_detect_anomaly(True) + + +# def get_transformation(image_size): +# return transforms.Compose( +# [transforms.Resize(image_size), +# transforms.ToTensor(), +# transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + +def get_transformation(): + return transforms.Compose([transforms.ToTensor()]) + + +# Latent z -> latent w +class MappingProxy(nn.Module): + def __init__(self, gaussian_ft): + super(MappingProxy, self).__init__() + self.mean = gaussian_ft["mean"] + self.std = gaussian_ft["std"] + self.lrelu = torch.nn.LeakyReLU(0.2) + + def forward(self, x): + x = self.lrelu(self.std * x + self.mean) + return x + + +def loss_geocross(latent): + if latent.size()[1:] == (1, 512): + return 0 + else: + num_latents = latent.size()[1] + X = latent.view(-1, 1, num_latents, 512) + Y = latent.view(-1, num_latents, 1, 512) + A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() + B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() + D = 2 * torch.atan2(A, B) + D = ((D.pow(2) * 512).mean((1, 2)) / 8.0).mean() + return D + + +class SphericalOptimizer: + def __init__(self, params): + self.params = params + with torch.no_grad(): + self.radii = { + param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() + for param in params + } + + @torch.no_grad() + def step(self, closure=None): + for param in self.params: + param.data.div_( + (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() + ) + param.mul_(self.radii[param]) + + +class LatentOptimizer(torch.nn.Module): + def __init__(self, config, psf, mask=None): + super().__init__() + + model_checkpoint = to_absolute_path(config["model"]["checkpoint"]) + mapping_checkpoint = to_absolute_path(config["model"]["mapping"]) + if not os.path.exists(model_checkpoint): + print("Model checkpoint does not exist") + + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + + msg = "Do you want to download and use ILO files from SwitchDrive (560MB)?" 
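+            # note: the code below assumes the archive extracts to an ilo/ folder containing
+            # stylegan2-ffhq-config-f.pt (generator weights) and gaussian_fit.pt (the mean/std
+            # statistics loaded into MappingProxy)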
+ + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + + current_path = os.path.dirname(__file__) + + model_dir = os.path.join(current_path, "..", "..", "models") + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + ilo_path = os.path.join(model_dir, "ilo") + if os.path.exists(ilo_path): + print("ILO already exists (no need to download).") + else: + url = "https://drive.switch.ch/index.php/s/hD0JqMJemFJ7FKo/download" + filename = "ilo.zip" + download_and_extract_archive( + url, model_dir, filename=filename, remove_finished=True + ) + + model_checkpoint = os.path.join(model_dir, "ilo", "stylegan2-ffhq-config-f.pt") + mapping_checkpoint = os.path.join(model_dir, "ilo", "gaussian_fit.pt") + + self.device = config["opti_params"]["device"] + + # Load models and pre-trained weights + gen = Generator(1024, 512, 8) + gen.load_state_dict(torch.load(model_checkpoint)["g_ema"], strict=False) + gen.eval() + self.gen = gen.to(self.device) + self.gen.start_layer = config["opti_params"]["start_layer"] + self.gen.end_layer = config["opti_params"]["end_layer"] + self.mpl = MappingProxy(torch.load(mapping_checkpoint, self.device)) + + cuda_ids = [0] + if self.device.startswith("cuda:"): + cuda_ids = self.device.split(":")[-1].split(",") + cuda_ids = [int(cuda_id) for cuda_id in cuda_ids] + + self.percept = lpips.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=self.device.startswith("cuda"), gpu_ids=cuda_ids + ) + + # Transform on each image + self.transform = get_transformation() + + # Task + image_size = np.array(config["preprocessing"]["resize"]["image_size"]) + + # Load PSF + try: + psf_image = psf.copy() + except Exception: + if mask is None: + psf_path = config["lensless_imaging"]["psf_path"] + psf_image = np.array(Image.open(to_absolute_path(psf_path)).convert("RGB")) + else: + psf_image = mask.psf.copy() + + self.lensless_imaging = True + scene2mask = config["lensless_imaging"]["scene2mask"] + mask2sensor = config["lensless_imaging"]["mask2sensor"] + object_height = config["lensless_imaging"]["object_height"] + sensor = config["lensless_imaging"]["sensor"] + sensor_config = sensor_dict[sensor] + mask_type = config["mask"]["type"] + flatcam = config["lensless_imaging"]["simulated"]["flatcam"] + snr = config["lensless_imaging"]["simulated"]["snr"] + max_val = config["lensless_imaging"]["simulated"]["max_val"] + + self.psf_size = np.array(psf_image.shape) + # bring channel dimension to the front + self.psf_size = np.roll(self.psf_size, 1) + # try: + # self.psf_size = np.array(psf_image.shape) + # except Exception: + # self.psf_size = np.array(config["lensless_imaging"]["psf_size"]) + + # Input image at the right size + magnification = mask2sensor / scene2mask + scene_dim = sensor_config[SensorParam.SIZE] / magnification + object_height_pix = int(np.round(object_height / scene_dim[1] * self.psf_size[1])) + scaling = object_height_pix / image_size[1] + self.object_dim = (np.round(image_size * scaling)).astype(int).tolist() + + # Normalize and forward model + psf_image = psf_image / np.linalg.norm(psf_image.ravel()) + # self.psf_image = torch.from_numpy(psf_image).unsqueeze(0).to(self.device) + self.psf_image = ( + torch.from_numpy(psf_image).float().permute(2, 0, 1).unsqueeze(0).to(self.device) + ) + + # Create forward model + # self.forward_model = RealFFTConvolve2D(self.psf_image) + if flatcam and mask_type.upper() not in ["MURA", "MLS"]: + print("Separable assumption only available for coded apertures.") + flatcam = False + if 
flatcam: + self.forward_model = lambda img: mask.simulate(img, snr_db=snr) + # elif mask is not None: + else: + + self.simulator = FarFieldSimulator( + psf=self.psf_image, + object_height=object_height, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + sensor=sensor, + snr_db=None, + quantize=True, + max_val=max_val, + is_torch=True, + device_conv=self.device, + ) + self.forward_model = self.simulator.propagate + + # Opti parameters + self.start_layer = config["opti_params"]["start_layer"] + self.end_layer = config["opti_params"]["end_layer"] + self.steps = config["opti_params"]["steps"] + + self.lr = config["opti_params"]["lr"] + self.lr_same_pace = config["opti_params"]["lr_same_pace"] + + self.project = config["opti_params"]["project"] + self.do_project_latent = config["opti_params"]["do_project_latent"] + self.do_project_noises = config["opti_params"]["do_project_noises"] + self.do_project_gen_out = config["opti_params"]["do_project_gen_out"] + + self.max_radius_latent = config["opti_params"]["max_radius_latent"] + self.max_radius_noises = config["opti_params"]["max_radius_noises"] + self.max_radius_gen_out = config["opti_params"]["max_radius_gen_out"] + + # Loss parmaters + self.geocross = config["loss_params"]["geocross"] + self.mse = config["loss_params"]["mse"] + self.pe = config["loss_params"]["pe"] + self.dead_zone_linear = config["loss_params"]["dead_zone_linear"] + self.dead_zone_linear_alpha = config["loss_params"]["dead_zone_linear_alpha"] + self.lpips_method = config["loss_params"]["lpips_method"] + + # Logs parameters + self.save_gif = config["logs"]["save_gif"] + self.save_every = config["logs"]["save_every"] + self.save_forward = config["logs"]["save_forward"] + + def init_state(self, input_files): + + # Initialize the state of the optimizer, has to be performed before every run + + self.layer_in = None + self.best = None + self.best_forward = None + self.current_step = 0 + + # Load images + input_images = [] + for input_file in input_files: + input_images.append(self.transform(Image.open(input_file).convert("RGB"))) + self.input_images = torch.stack(input_images, 0).to(self.device) + self.batchsize = self.input_images.shape[0] + + # Initialization of latent vector + noises_single = self.gen.make_noise(self.batchsize) + self.noises = [] + for noise in noises_single: + self.noises.append(noise.normal_()) + self.latent_z = torch.randn( + (self.batchsize, 18, 512), dtype=torch.float, requires_grad=True, device=self.device + ) + self.gen_outs = [None] + + def get_lr(self, t, initial_lr, rampdown=0.75, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + def invert_(self, start_layer, noise_list, steps, index): + learning_rate_init = self.lr[index] + print(f"Running round {index + 1} / {len(self.steps)} of ILO.") + + # noise_list containts the indices of nodes that we will be optimizing over + for i in range(len(self.noises)): + if i in noise_list: + self.noises[i].requires_grad = True + else: + self.noises[i].requires_grad = False + with torch.no_grad(): + if start_layer == 0: + var_list = [self.latent_z] + self.noises + else: + self.gen_outs[-1].requires_grad = True + var_list = [self.latent_z] + self.noises + [self.gen_outs[-1]] + prev_gen_out = ( + torch.ones(self.gen_outs[-1].shape, device=self.gen_outs[-1].device) + * self.gen_outs[-1] + ) + prev_latent = ( + torch.ones(self.latent_z.shape, device=self.latent_z.device) * self.latent_z + ) + 
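+                # prev_noises (together with prev_gen_out and prev_latent above) snapshot the
+                # current values; when the do_project_* flags are set, project_onto_l1_ball keeps
+                # each variable within max_radius_* of these anchors after every optimizer step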
prev_noises = [ + torch.ones(noise.shape, device=noise.device) * noise for noise in self.noises + ] + + # set network that we will be optimizing over + self.gen.start_layer = start_layer + self.gen.end_layer = self.end_layer + + # Optimizer + optimizer = optim.Adam(var_list, lr=self.lr[index]) + ps = SphericalOptimizer([self.latent_z] + self.noises) + pbar = tqdm(range(steps)) + self.current_step += steps + + # Loss + mse_loss = 0 + p_loss = 0 + + for i in pbar: + # Update learning rate + if self.lr_same_pace: + total_steps = sum(self.steps) + t = i / total_steps + else: + t = i / steps + + lr = self.get_lr(t, learning_rate_init) + optimizer.param_groups[0]["lr"] = lr + + # Update generated image + latent_w = self.mpl(self.latent_z) + + img_gen, _ = self.gen( + [latent_w], + input_is_latent=True, + noise=self.noises, + layer_in=self.gen_outs[-1], + ) + + # Normalize output of GAN from standardize to [0, 1] per batch + img_gen = torch.clamp(img_gen, -1, 1) + img_gen = 0.5 * img_gen + 0.5 + + # Calculate loss + loss = 0 + + # TODO : check if image always generated on 1024x1024 + # Downsample to the original size + A_img_gen = img_gen + # A_img_gen = self.downsampler_1024_image(img_gen) + # A_img_gen = F.interpolate(A_img_gen, size=, mode='bicubic') + + A_img_gen = F.interpolate(A_img_gen, size=self.object_dim, mode="bicubic") + A_img_gen = zero_padding_tensor(A_img_gen, self.psf_size) + A_img_gen = self.forward_model(A_img_gen).to(torch.float32) + + # Using the all range [0,1], as LPIPS expects [0, 1] + with torch.no_grad(): + max_vals = torch.max(torch.flatten(A_img_gen, start_dim=1), dim=1)[0] + max_vals = max_vals.unsqueeze(1).unsqueeze(1).unsqueeze(1) + A_img_gen = A_img_gen / max_vals + + # Calculate perceptual loss (LPIPS) + if self.pe[index] != 0: + if self.lpips_method == "default": + p_loss = self.percept(A_img_gen, self.input_images, normalize=True).mean() + # elif self.lpips_method == 'fill_mask': + # # TODO : maybe need to downsampeld if super resolution + # filled = self.mask * self.input_images + (1 - self.mask) * img_gen + # p_loss = self.percept(self.downsampler_image_256(A_img_gen), self.downsampler_image_256(filled), normalize=True).mean() + else: + raise NotImplementedError("LPIPS policy not implemented") + loss += self.pe[index] * p_loss + + # Calculate dead_zone_linear loss + diff = torch.abs(A_img_gen - self.input_images) - self.dead_zone_linear_alpha + loss += ( + self.dead_zone_linear[index] + * torch.max(torch.zeros(diff.shape, device=diff.device), diff).mean() + ) + + # Calculate MSE loss + mse_loss = F.mse_loss(A_img_gen, self.input_images) + loss += self.mse[index] * mse_loss + + # Calculate Geocross loss + loss += self.geocross * loss_geocross(self.latent_z[:, start_layer:]) + + # Backpropagate + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Backpropagate on projections + if self.project: + ps.step() + + if self.max_radius_gen_out[index] == float("inf"): + self.do_project_gen_out = False + + if self.max_radius_latent[index] == float("inf"): + self.do_project_latent = False + + if self.max_radius_noises[index] == float("inf"): + self.do_project_noises = False + + if start_layer != 0 and self.do_project_gen_out: + deviation = project_onto_l1_ball( + self.gen_outs[-1] - prev_gen_out, self.max_radius_gen_out[index] + ) + var_list[-1].data = (prev_gen_out + deviation).data + if self.do_project_latent: + deviation = project_onto_l1_ball( + self.latent_z - prev_latent, self.max_radius_latent[index] + ) + var_list[0].data = (prev_latent + 
deviation).data + if self.do_project_noises: + deviations = [ + project_onto_l1_ball(noise - prev_noise, self.max_radius_noises[index]) + for noise, prev_noise in zip(self.noises, prev_noises) + ] + for i, deviation in enumerate(deviations): + var_list[i + 1].data = (prev_noises[i] + deviation).data + + # Update best image + # if mse_loss < mse_min: + # mse_min = mse_loss + # self.best = img_gen + + # if self.save_forward: + # self.best_forward = A_img_gen + + self.best = img_gen + + if self.save_forward: + self.best_forward = A_img_gen + + # Update tqdm and print + pbar.set_description((f"perceptual: {p_loss:.4f};" f" mse: {mse_loss:.4f};")) + # TODO : probably broken coz of batch + # Save some intermediate images of the optimization + if self.save_gif and i % self.save_every == 0: + torchvision.utils.save_image( + img_gen, + f"gif_{start_layer}_{i}.png", + nrow=int(img_gen.shape[0] ** 0.5), + normalize=True, + ) + + # Update in between layers + with torch.no_grad(): + latent_w = self.mpl(self.latent_z) + self.gen.end_layer = self.gen.start_layer + intermediate_out, _ = self.gen( + [latent_w], + input_is_latent=True, + noise=self.noises, + layer_in=self.gen_outs[-1], + skip=None, + ) + self.gen_outs.append(intermediate_out) + self.gen.end_layer = self.end_layer + print() + + def invert(self): + print("Start of the invertion") + for i, steps in enumerate(self.steps): + begin_from = i + self.start_layer + if begin_from > self.end_layer: + raise Exception("Attempting to go after end layer...") + self.invert_(begin_from, range(5 + 2 * begin_from), int(steps), i) + + return ( + self.input_images, + (self.latent_z, self.noises, self.gen_outs), + (self.best, self.best_forward), + ) diff --git a/lensless/recon/ilo_stylegan2/lpips/__init__.py b/lensless/recon/ilo_stylegan2/lpips/__init__.py new file mode 100644 index 00000000..3927b179 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/__init__.py @@ -0,0 +1,48 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import torch +from . import dist_model + + +class PerceptualLoss(torch.nn.Module): + def __init__( + self, + model="net-lin", + net="alex", + colorspace="rgb", + spatial=False, + use_gpu=True, + gpu_ids=[0], + ): + super(PerceptualLoss, self).__init__() + self.use_gpu = use_gpu + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model = dist_model.DistModel() + self.model.initialize( + model=model, + net=net, + use_gpu=use_gpu, + colorspace=colorspace, + spatial=self.spatial, + gpu_ids=gpu_ids, + ) + + def forward(self, pred, target, normalize=False): + """ + Pred and target are Variables. 
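+        Typical usage (as in ilo.py): percept = PerceptualLoss(model='net-lin', net='vgg');
+        loss = percept(pred, target, normalize=True) with pred/target in [0, 1].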
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model.forward(target, pred) diff --git a/lensless/recon/ilo_stylegan2/lpips/base_model.py b/lensless/recon/ilo_stylegan2/lpips/base_model.py new file mode 100644 index 00000000..5c851ac8 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/base_model.py @@ -0,0 +1,59 @@ +import os +import numpy as np +import torch + + +class BaseModel: + def __init__(self): + pass + + def name(self): + return "BaseModel" + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print("Loading network from %s" % save_path) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, "done_flag"), flag) + np.savetxt( + os.path.join(self.save_dir, "done_flag"), + [ + flag, + ], + fmt="%i", + ) diff --git a/lensless/recon/ilo_stylegan2/lpips/dist_model.py b/lensless/recon/ilo_stylegan2/lpips/dist_model.py new file mode 100644 index 00000000..ccfc340d --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/dist_model.py @@ -0,0 +1,307 @@ +from __future__ import absolute_import + +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +from .base_model import BaseModel +from scipy.ndimage import zoom +from tqdm import tqdm + + +from . import networks_basic as networks +from . import utils as util + + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize( + self, + model="net-lin", + net="alex", + colorspace="Lab", + pnet_rand=False, + pnet_tune=False, + model_path=None, + use_gpu=True, + printNet=False, + spatial=False, + is_train=False, + lr=0.0001, + beta1=0.5, + version="0.1", + gpu_ids=[0], + ): + """ + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. 
if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + """ + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = "%s [%s]" % (model, net) + + if self.model == "net-lin": # pretrained net + linear layer + self.net = networks.PNetLin( + pnet_rand=pnet_rand, + pnet_tune=pnet_tune, + pnet_type=net, + use_dropout=True, + spatial=spatial, + version=version, + lpips=True, + ) + kw = {} + if not use_gpu: + kw["map_location"] = "cpu" + if model_path is None: + import inspect + + model_path = os.path.abspath( + os.path.join( + inspect.getfile(self.initialize), + "..", + "weights/v%s/%s.pth" % (version, net), + ) + ) + + if not is_train: + print("Loading model from: %s" % model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif self.model == "net": # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif self.model in ["L2", "l2"]: + self.net = networks.L2( + use_gpu=use_gpu, colorspace=colorspace + ) # not really a network, only for testing + self.model_name = "L2" + elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]: + self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace) + self.model_name = "SSIM" + else: + raise ValueError("Model [%s] not recognized." 
% self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if use_gpu: + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if self.is_train: + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if printNet: + print("---------- Networks initialized -------------") + networks.print_network(self.net) + print("-----------------------------------------------") + + def forward(self, in0, in1, retPerLayer=False): + """Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + """ + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if hasattr(module, "weight") and module.kernel_size == (1, 1): + module.weight.data = torch.clamp(module.weight.data, min=0) + + def set_input(self, data): + self.input_ref = data["ref"] + self.input_p0 = data["p0"] + self.input_p1 = data["p1"] + self.input_judge = data["judge"] + + if self.use_gpu: + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref, requires_grad=True) + self.var_p0 = Variable(self.input_p0, requires_grad=True) + self.var_p1 = Variable(self.input_p1, requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) + + self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2.0 - 1.0) + + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self, d0, d1, judge): + """d0, d1 are Variables, judge is a Tensor""" + d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) + + def get_current_errors(self): + retDict = OrderedDict( + [("loss_total", self.loss_total.data.cpu().numpy()), ("acc_r", self.acc_r)] + ) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256 / self.var_ref.data.size()[2] + + ref_img = util.tensor2im(self.var_ref.data) + p0_img = util.tensor2im(self.var_p0.data) + p1_img = util.tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img, 
[zoom_factor, zoom_factor, 1], order=0) + p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) + p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) + + return OrderedDict([("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)]) + + def save(self, path, label): + if self.use_gpu: + self.save_network(self.net.module, path, "", label) + else: + self.save_network(self.net, path, "", label) + self.save_network(self.rankLoss.net, path, "rank", label) + + def update_learning_rate(self, nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group["lr"] = lr + + print("update lr [%s] decay: %f -> %f" % (type, self.old_lr, lr)) + self.old_lr = lr + + +def score_2afc_dataset(data_loader, func, name=""): + """Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + """ + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s += func(data["ref"], data["p0"]).data.cpu().numpy().flatten().tolist() + d1s += func(data["ref"], data["p1"]).data.cpu().numpy().flatten().tolist() + gts += data["judge"].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5 + + return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) + + +def score_jnd_dataset(data_loader, func, name=""): + """Function computes JND score using distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return pytorch array of length N + OUTPUTS + [0] - JND score in [0,1], mAP score (area under precision-recall curve) + [1] - dictionary with following elements + ds - N array containing distances between two patches shown to human evaluator + sames - N array containing fraction of people who thought the two patches were identical + CONSTS + N - number of test triplets in data_loader + """ + + ds = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + ds += func(data["p0"], data["p1"]).data.cpu().numpy().tolist() + gts += data["same"].cpu().numpy().flatten().tolist() + + sames = np.array(gts) + ds = np.array(ds) + + sorted_inds = np.argsort(ds) + sames_sorted = sames[sorted_inds] + + TPs = np.cumsum(sames_sorted) + FPs = np.cumsum(1 - sames_sorted) + FNs = np.sum(sames_sorted) - TPs + + precs = TPs / (TPs + FPs) + recs = TPs / (TPs + FNs) + score = util.voc_ap(recs, precs) + + return (score, 
dict(ds=ds, sames=sames)) diff --git a/lensless/recon/ilo_stylegan2/lpips/networks_basic.py b/lensless/recon/ilo_stylegan2/lpips/networks_basic.py new file mode 100644 index 00000000..07a5330f --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/networks_basic.py @@ -0,0 +1,250 @@ +from __future__ import absolute_import + +import torch +import torch.nn as nn +from torch.autograd import Variable +from . import pretrained_networks as pn +from . import utils as util + + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2, 3], keepdim=keepdim) + + +def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W + in_H = in_tens.shape[2] + scale_factor = 1.0 * out_H / in_H + + return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)(in_tens) + + +# Learned perceptual metric +class PNetLin(nn.Module): + def __init__( + self, + pnet_type="vgg", + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + spatial=False, + version="0.1", + lpips=True, + ): + super(PNetLin, self).__init__() + + self.pnet_type = pnet_type + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips + self.version = version + self.scaling_layer = ScalingLayer() + + if self.pnet_type in ["vgg", "vgg16"]: + net_type = pn.vgg16 + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == "alex": + net_type = pn.alexnet + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == "squeeze": + net_type = pn.squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if lpips: + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == "squeeze": # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + + def forward(self, in0, in1, retPerLayer=False): + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = ( + (self.scaling_layer(in0), self.scaling_layer(in1)) + if self.version == "0.1" + else (in0, in1) + ) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if self.lpips: + if self.spatial: + res = [ + upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) + for kk in range(self.L) + ] + else: + res = [ + spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(self.L) + ] + else: + if self.spatial: + res = [ + upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) + for kk in range(self.L) + ] + else: + res = [ + spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) + for kk in range(self.L) + ] + + val = res[0] + for idx in range(1, self.L): + val += res[idx] + + if retPerLayer: + return (val, res) + else: + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, 
self).__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class Dist2LogitLayer(nn.Module): + """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)""" + + def __init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [ + nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), + ] + layers += [ + nn.LeakyReLU(0.2, True), + ] + layers += [ + nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), + ] + layers += [ + nn.LeakyReLU(0.2, True), + ] + layers += [ + nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), + ] + if use_sigmoid: + layers += [ + nn.Sigmoid(), + ] + self.model = nn.Sequential(*layers) + + def forward(self, d0, d1, eps=0.1): + return self.model.forward( + torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1) + ) + + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge + 1.0) / 2.0 + self.logit = self.net.forward(d0, d1) + return self.loss(self.logit, per) + + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace="Lab"): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace = colorspace + + +class L2(FakeNet): + def forward(self, in0, in1, retPerLayer=None): + assert in0.size()[0] == 1 # currently only supports batchSize 1 + + if self.colorspace == "RGB": + (N, C, X, Y) = in0.size() + value = torch.mean( + torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view( + N, 1, 1, Y + ), + dim=3, + ).view(N) + return value + elif self.colorspace == "Lab": + value = util.l2( + util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), + range=100.0, + ).astype("float") + ret_var = Variable(torch.Tensor((value,))) + if self.use_gpu: + ret_var = ret_var.cuda() + return ret_var + + +class DSSIM(FakeNet): + def forward(self, in0, in1, retPerLayer=None): + assert in0.size()[0] == 1 # currently only supports batchSize 1 + + if self.colorspace == "RGB": + value = util.dssim( + 1.0 * util.tensor2im(in0.data), 1.0 * util.tensor2im(in1.data), range=255.0 + ).astype("float") + elif self.colorspace == "Lab": + value = util.dssim( + util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), + range=100.0, + ).astype("float") + ret_var = Variable(torch.Tensor((value,))) + if self.use_gpu: + ret_var = ret_var.cuda() + return ret_var + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print("Network", net) + print("Total number of parameters: %d" % num_params) diff --git 
a/lensless/recon/ilo_stylegan2/lpips/pretrained_networks.py b/lensless/recon/ilo_stylegan2/lpips/pretrained_networks.py new file mode 100644 index 00000000..9b594646 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/pretrained_networks.py @@ -0,0 +1,187 @@ +from collections import namedtuple +import torch +from torchvision import models as tv + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple( + "SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"] + ) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple( + "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] + ) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = 
torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if num == 18: + self.net = tv.resnet18(pretrained=pretrained) + elif num == 34: + self.net = tv.resnet34(pretrained=pretrained) + elif num == 50: + self.net = tv.resnet50(pretrained=pretrained) + elif num == 101: + self.net = tv.resnet101(pretrained=pretrained) + elif num == 152: + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/lensless/recon/ilo_stylegan2/lpips/utils.py b/lensless/recon/ilo_stylegan2/lpips/utils.py new file mode 100644 index 00000000..0839c7db --- /dev/null +++ b/lensless/recon/ilo_stylegan2/lpips/utils.py @@ -0,0 +1,118 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from skimage.metrics import structural_similarity + +import torch + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +def l2(p0, p1, range=255.0): + return 0.5 * np.mean((p0 / range - p1 / range) ** 2) + + +def psnr(p0, p1, peak=255.0): + return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) + + +def dssim(p0, p1, range=255.0): + return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.0 + + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) + + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2tensorlab(image_tensor, to_norm=True, 
mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if mc_only: + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + if to_norm and not mc_only: + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + img_lab = img_lab / 100.0 + + return np2tensor(img_lab) + + +def tensorlab2tensor(lab_tensor, return_inbnd=False): + from skimage import color + import warnings + + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor) * 100.0 + lab[:, :, 0] = lab[:, :, 0] + 50 + + rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) + if return_inbnd: + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype("uint8")) + mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) + mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) + return (im2tensor(rgb_back), mask) + else: + return im2tensor(rgb_back) + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + + +def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): + return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + + +def voc_ap(rec, prec, use_07_metric=False): + """ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0.0 + for t in np.arange(0.0, 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11.0 + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.0], rec, [1.0])) + mpre = np.concatenate(([0.0], prec, [0.0])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/alex.pth b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/alex.pth new file mode 100644 index 00000000..256709fa Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/alex.pth differ diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/squeeze.pth b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/squeeze.pth new file mode 100644 index 00000000..1e8faa77 Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/squeeze.pth differ diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/vgg.pth b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/vgg.pth new file mode 100644 index 00000000..90f4e134 Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.0/vgg.pth differ diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/alex.pth b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/alex.pth new file mode 100644 index 00000000..1df9dfe6 Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/alex.pth differ diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/squeeze.pth 
b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/squeeze.pth new file mode 100644 index 00000000..a3bd383b Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/squeeze.pth differ diff --git a/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/vgg.pth b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/vgg.pth new file mode 100644 index 00000000..47e943cf Binary files /dev/null and b/lensless/recon/ilo_stylegan2/lpips/weights/v0.1/vgg.pth differ diff --git a/lensless/recon/ilo_stylegan2/op/__init__.py b/lensless/recon/ilo_stylegan2/op/__init__.py new file mode 100644 index 00000000..d0918d92 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/lensless/recon/ilo_stylegan2/op/fused_act.py b/lensless/recon/ilo_stylegan2/op/fused_act.py new file mode 100644 index 00000000..7a072af0 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/fused_act.py @@ -0,0 +1,92 @@ +import os + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +fused = load( + "fused", + sources=[ + os.path.join(module_path, "fused_bias_act.cpp"), + os.path.join(module_path, "fused_bias_act_kernel.cu"), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + (out,) = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + (out,) = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + if input.device.type == "cpu": + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale + ) + + else: + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/lensless/recon/ilo_stylegan2/op/fused_bias_act.cpp b/lensless/recon/ilo_stylegan2/op/fused_bias_act.cpp new file mode 100644 
index 00000000..a0543187 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/lensless/recon/ilo_stylegan2/op/fused_bias_act_kernel.cu b/lensless/recon/ilo_stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 00000000..8d2f03c7 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( + y.data_ptr<scalar_t>(), + x.data_ptr<scalar_t>(), + b.data_ptr<scalar_t>(), + ref.data_ptr<scalar_t>(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/lensless/recon/ilo_stylegan2/op/upfirdn2d.cpp b/lensless/recon/ilo_stylegan2/op/upfirdn2d.cpp new file mode 100644 index 00000000..b07aa205 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include <torch/extension.h> + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/lensless/recon/ilo_stylegan2/op/upfirdn2d.py b/lensless/recon/ilo_stylegan2/op/upfirdn2d.py new file mode 100644 index 00000000..1fd2750b --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/upfirdn2d.py @@ -0,0 +1,190 @@ +import os + +import torch +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "upfirdn2d.cpp"), + os.path.join(module_path, "upfirdn2d_kernel.cu"), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + (kernel,) = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + #
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + else: + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/lensless/recon/ilo_stylegan2/op/upfirdn2d_kernel.cu b/lensless/recon/ilo_stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 00000000..ed3eea30 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 
2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * 
kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if 
(p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} \ No newline at end of file diff --git a/lensless/recon/ilo_stylegan2/stylegan2.py b/lensless/recon/ilo_stylegan2/stylegan2.py new file mode 100644 index 00000000..cf21644d --- /dev/null +++ b/lensless/recon/ilo_stylegan2/stylegan2.py @@ -0,0 +1,674 @@ +""" +TODO : authorship from original stylegan2 repo / ILO +https://github.com/giannisdaras/ilo/blob/master/model.py +""" + +import math +import random + +import torch +from torch import nn +from torch.nn import functional as F +from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input**2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor**2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + 
self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size**2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = 
EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})" + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if 
skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + self.start_layer = 0 + self.end_layer = 8 + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu") + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2**res, 2**res] + self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self, bs=1): + device = self.input.input.device + + noises = [torch.randn(bs, 1, 2**2, 2**2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(bs, 1, 2**i, 2**i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + layer_in=None, + skip=None, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + truncation * (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + ## Changed by ILO + out = self.input(latent) + + if self.start_layer == 
0: + out = self.conv1(out, latent[:, 0], noise=noise[0]) # 0th layer + skip = self.to_rgb1(out, latent[:, 1]) + if self.end_layer == 0: + return out, skip + i = 1 + current_layer = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + if current_layer < self.start_layer: + pass + elif current_layer == self.start_layer: + out = conv1(layer_in, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + elif current_layer > self.end_layer: + return out, skip + else: + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + current_layer += 1 + i += 2 + + ## + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = 
torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/lensless/recon/ilo_stylegan2/utils.py b/lensless/recon/ilo_stylegan2/utils.py new file mode 100644 index 00000000..d3d39495 --- /dev/null +++ b/lensless/recon/ilo_stylegan2/utils.py @@ -0,0 +1,139 @@ +import re +import numpy as np +from PIL import Image + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def zero_padding_tensor(images, final_size): + # -- pad rest with zeros + # negative cropping = center crop + + padding = np.array(final_size[-2:]) - np.array(images[0].shape[-2:]) + left = padding[1] // 2 + right = padding[1] - left + top = padding[0] // 2 + bottom = padding[0] - top + padder = torch.nn.ConstantPad2d((left, right, top, bottom), 0.0) + + return padder(images) + + +def project_onto_l1_ball(x, eps): + """ + See: https://gist.github.com/tonyduan/1329998205d88c566588e57e3e2c0c55 + """ + original_shape = x.shape + x = x.view(x.shape[0], -1) + mask = (torch.norm(x, p=1, dim=1) < eps).float().unsqueeze(1) + mu, _ = torch.sort(torch.abs(x), dim=1, descending=True) + cumsum = torch.cumsum(mu, dim=1) + arange = torch.arange(1, x.shape[1] + 1, device=x.device) + rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1) + theta = (cumsum[torch.arange(x.shape[0]), rho.cpu() - 1] - eps) / rho + proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0) + x = mask * x + (1 - mask) * proj * torch.sign(x) + return x.view(original_shape) + + +class BicubicDownSample(nn.Module): + def bicubic_kernel(self, x, a=-0.50): + """ + This equation is exactly copied from the website below: + https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic + """ + abs_x = torch.abs(x) + if abs_x <= 1.0: + return (a + 2.0) * torch.pow(abs_x, 3.0) - (a + 3.0) * torch.pow(abs_x, 2.0) + 1 + elif 1.0 < abs_x < 2.0: + return ( + a * torch.pow(abs_x, 3) + - 5.0 * a * torch.pow(abs_x, 2.0) + + 8.0 * a * abs_x + - 4.0 * a + ) + else: + return 0.0 + + def __init__(self, factor=4, cuda=True, padding="reflect"): + super().__init__() + self.factor = factor + size = factor * 4 + k = torch.tensor( + [ + self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) + for i in range(size) + ], + dtype=torch.float32, + ) + k = k / torch.sum(k) + # k = torch.einsum('i,j->ij', (k, k)) + k1 = torch.reshape(k, shape=(1, 1, size, 1)) + self.k1 = torch.cat([k1, k1, k1], dim=0) + k2 = torch.reshape(k, shape=(1, 1, 1, size)) + self.k2 = torch.cat([k2, k2, k2], dim=0) + self.cuda = ".cuda" if cuda else "" + self.padding = padding + # self.padding = 'constant' + # self.padding = 'replicate' + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, nhwc=False, clip_round=False, byte_output=False): + filter_height = self.factor * 4 + filter_width = self.factor * 4 + stride = self.factor + pad_along_height = max(filter_height - stride, 0) + pad_along_width = max(filter_width - stride, 0) + filters1 = self.k1.type("torch{}.FloatTensor".format(self.cuda)) + filters2 = self.k2.type("torch{}.FloatTensor".format(self.cuda)) + # compute actual padding values for each side + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + # apply mirror padding + if nhwc: + x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW + # downscaling performed by 1-d convolution + x = F.pad(x, (0, 0, pad_top, 
pad_bottom), self.padding) + x = F.conv2d(input=x.float(), weight=filters1, stride=(stride, 1), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.0) + x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) + x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.0) + if nhwc: + x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) + if byte_output: + return x.type("torch.ByteTensor") + else: + return x + + +# Utils +def get_array(file): + img = np.array(Image.open(file).convert("RGB")) + img = img / 255 + return img + + +def atof(text): + try: + retval = float(text) + except ValueError: + retval = text + return retval + + +def natural_keys(text): + """ + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + float regex comes from https://stackoverflow.com/a/12643073/190597 + """ + return [atof(c) for c in re.split(r"[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)", text)] diff --git a/lensless/utils/align_face.py b/lensless/utils/align_face.py new file mode 100644 index 00000000..d3724ab3 --- /dev/null +++ b/lensless/utils/align_face.py @@ -0,0 +1,192 @@ +import os +import numpy as np +import PIL.Image +import scipy.ndimage +import dlib + +# try: +# import torch +# import torchvision.transforms as tf +# from torchvision.datasets.utils import download_and_extract_archive +# +# torch_available = True +# except ImportError: +# torch_available = False + + +def get_predictor(predictor_path): + + if not os.path.exists(predictor_path): + print(f"Predictor for alignment not found at {predictor_path}.") + + remote_aligner = "shape_predictor_68_face_landmarks.dat" + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = f"Do you want to download and use {remote_aligner} from SwitchDrive? This file and others needed for ILO will be downloaded (560MB)." + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + + current_path = os.path.dirname(__file__) + + model_dir = os.path.join(current_path, "..", "..", "models") + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + ilo_path = os.path.join(model_dir, "ilo") + if os.path.exists(ilo_path): + print("ILO already exists (no need to download).") + else: + url = "https://drive.switch.ch/index.php/s/hD0JqMJemFJ7FKo/download" + filename = "ilo.zip" + download_and_extract_archive( + url, model_dir, filename=filename, remove_finished=True + ) + + predictor_path = os.path.join(model_dir, "ilo", remote_aligner) + + predictor = dlib.shape_predictor(predictor_path) + return predictor + + +def get_landmark(filepath, predictor_path): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + + predictor = get_predictor(predictor_path) + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + print("Number of faces detected: {}".format(len(dets))) + for k, d in enumerate(dets): + print( + "Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format( + k, d.left(), d.top(), d.right(), d.bottom() + ) + ) + # Get the landmarks/parts for the face in box d. 
+ shape = predictor(img, d) + print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1))) + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + # lm is a shape=(68,2) np.array + return lm + + +def align_face(filepath, predictor_path): + """ + :param filepath: str + :return: PIL Image + """ + + lm = get_landmark(filepath, predictor_path) + + lm_eye_left = lm[36:42] # left-clockwise + lm_eye_right = lm[42:48] # left-clockwise + lm_mouth_outer = lm[48:60] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath) + + output_size = 1024 + transform_size = 4096 + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = ( + int(np.rint(float(img.size[0]) / shrink)), + int(np.rint(float(img.size[1]) / shrink)), + ) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = ( + int(np.floor(min(quad[:, 0]))), + int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1]))), + ) + crop = ( + max(crop[0] - border, 0), + max(crop[1] - border, 0), + min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1]), + ) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = ( + int(np.floor(min(quad[:, 0]))), + int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1]))), + ) + pad = ( + max(-pad[0] + border, 0), + max(-pad[1] + border, 0), + max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0), + ) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect") + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum( + 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]), + ) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip( + mask * 3.0 + 1.0, 0.0, 1.0 + ) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), "RGB") + quad += pad[:2] + + # Transform. + img = img.transform( + (transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR + ) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Save aligned image. 
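+ # (the aligned face is returned as a 1024x1024 PIL image; the calling script, e.g. the simulation scripts, saves or further processes it)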
+ return img diff --git a/lensless_environment.yml b/lensless_environment.yml new file mode 100644 index 00000000..19055e5b --- /dev/null +++ b/lensless_environment.yml @@ -0,0 +1,302 @@ +name: lensless_class +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotlipy=0.7.0=py39h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2022.07.19=h06a4308_0 + - certifi=2022.6.15=py39h06a4308_0 + - cffi=1.15.1=py39h74dc2b5_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cryptography=37.0.1=py39h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.11.0=h70c0345_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.3=pyhd3eb1b0_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jpeg=9e=h7f8727e_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.8=h7f8727e_5 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.4.0=hecacb30_0 + - libunistring=0.9.10=h27cfd23_0 + - libwebp=1.2.2=h55f646e_0 + - libwebp-base=1.2.2=h7f8727e_0 + - lz4-c=1.9.3=h295c915_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py39h7f8727e_0 + - mkl_fft=1.3.1=py39hd3c417c_0 + - mkl_random=1.2.2=py39h51133e4_0 + - ncurses=6.3=h5eee18b_3 + - nettle=3.7.3=hbbd107a_1 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1q=h7f8727e_0 + - pip=22.1.2=py39h06a4308_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.13=haa1d7c7_1 + - pytorch-mutex=1.0=cuda + - readline=8.1.2=h7f8727e_1 + - requests=2.28.1=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.39.2=h5082296_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.12.1=py39_cu113 + - typing_extensions=4.3.0=py39h06a4308_0 + - tzdata=2022a=hda174b7_0 + - urllib3=1.26.11=py39h06a4308_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.5=h7f8727e_1 + - zlib=1.2.12=h7f8727e_2 + - zstd=1.5.2=ha4553b6_0 + - pip: + - absl-py==1.4.0 + - adafruit-blinka==8.20.0 + - adafruit-circuitpython-busdevice==5.2.6 + - adafruit-circuitpython-framebuf==1.6.3 + - adafruit-circuitpython-pcd8544==1.2.14 + - adafruit-circuitpython-requests==2.0.0 + - adafruit-circuitpython-rgb-display==3.12.0 + - adafruit-circuitpython-sharpmemorydisplay==1.4.9 + - adafruit-circuitpython-typing==1.9.4 + - adafruit-platformdetect==3.47.0 + - adafruit-pureio==1.1.11 + - aiohttp==3.8.5 + - aiosignal==1.3.1 + - alabaster==0.7.12 + - antlr4-python3-runtime==4.9.3 + - argon2-cffi==21.3.0 + - argon2-cffi-bindings==21.2.0 + - asgiref==3.5.2 + - asttokens==2.0.8 + - async-timeout==4.0.2 + - attrs==22.1.0 + - babel==2.10.3 + - backcall==0.2.0 + - bcrypt==4.0.1 + - beautifulsoup4==4.11.1 + - black==22.8.0 + - bleach==5.0.1 + - bokeh==2.4.3 + - boto3==1.26.57 + - botocore==1.29.57 + - cachetools==5.3.0 + - click==8.0.4 + - cloudpickle==2.2.0 + - cmake==3.27.0 + - colorzero==2.0 + - cycler==0.11.0 + - dask==2023.7.1 + - datasets==2.14.0 + - debugpy==1.6.3 + - decorator==5.1.1 + - defusedxml==0.7.1 + - dill==0.3.7 + - distributed==2023.7.1 + - django==4.1.1 + - docutils==0.19 + - entrypoints==0.4 + - executing==1.0.0 + - fastjsonschema==2.16.1 + - filelock==3.9.0 + - fire==0.5.0 + - frozenlist==1.4.0 + - fsspec==2022.8.2 + - gitdb==4.0.10 + - gitpython==3.1.30 + - 
google-auth==2.16.0 + - google-auth-oauthlib==0.4.6 + - grpcio==1.51.1 + - heapdict==1.0.1 + - huggingface-hub==0.16.4 + - hydra-core==1.3.2 + - image==1.5.33 + - imageio==2.21.2 + - imageio-ffmpeg==0.4.3 + - imagesize==1.4.1 + - importlib-metadata==6.8.0 + - iniconfig==1.1.1 + - ipykernel==6.15.2 + - ipython==8.5.0 + - ipython-genutils==0.2.0 + - ipywidgets==8.0.2 + - jedi==0.18.0 + - jinja2==3.1.2 + - jmespath==1.0.1 + - joblib==1.1.0 + - jsonschema==4.15.0 + - jupyter==1.0.0 + - jupyter-client==7.3.5 + - jupyter-console==6.4.4 + - jupyter-core==4.11.1 + - jupyterlab-pygments==0.2.2 + - jupyterlab-widgets==3.0.3 + - kiwisolver==1.4.4 + - lensless==1.0.4 + - lit==16.0.6 + - llvmlite==0.39.1 + - locket==1.0.0 + - lpips==0.1.4 + - lxml==4.9.1 + - markdown==3.4.1 + - markupsafe==2.1.1 + - matplotlib==3.4.2 + - matplotlib-inline==0.1.6 + - mistune==2.0.4 + - mpmath==1.3.0 + - msgpack==1.0.4 + - multidict==6.0.4 + - multiprocess==0.70.15 + - mypy-extensions==0.4.3 + - nbclient==0.6.7 + - nbconvert==7.0.0 + - nbformat==5.4.0 + - nest-asyncio==1.5.5 + - networkx==2.8.6 + - ninja==1.10.2.3 + - notebook==6.4.12 + - numba==0.56.2 + - numpy==1.23.5 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-cupti-cu11==11.7.101 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cufft-cu11==10.9.0.58 + - nvidia-curand-cu11==10.2.10.91 + - nvidia-cusolver-cu11==11.4.0.1 + - nvidia-cusparse-cu11==11.7.4.91 + - nvidia-nccl-cu11==2.14.3 + - nvidia-nvtx-cu11==11.7.91 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - opencv-python==4.5.1.48 + - packaging==21.3 + - pandas==1.4.4 + - pandocfilters==1.5.0 + - paramiko==3.2.0 + - parso==0.8.3 + - partd==1.3.0 + - pathspec==0.10.1 + - perlin-numpy==0.0.0 + - pexpect==4.8.0 + - picamerax==20.9.1 + - pickleshare==0.7.5 + - pillow==9.2.0 + - platformdirs==2.5.2 + - pluggy==1.0.0 + - progressbar==2.5 + - prometheus-client==0.14.1 + - prompt-toolkit==3.0.31 + - protobuf==3.20.3 + - psutil==5.9.2 + - ptyprocess==0.7.0 + - pudb==2022.1.2 + - pure-eval==0.2.2 + - py==1.11.0 + - pyarrow==12.0.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pybboxes==0.1.6 + - pycsou==1.0.7.dev1687 + - pyffs==2.2.1 + - pyftdi==0.54.0 + - pygments==2.13.0 + - pylops==1.18.0 + - pynacl==1.5.0 + - pyparsing==3.0.9 + - pyrsistent==0.18.1 + - pyserial==3.5 + - pyspng==0.1.0 + - pytest==7.1.3 + - python-dateutil==2.8.2 + - pytz==2022.2.1 + - pyusb==1.2.1 + - pywavelets==1.3.0 + - pyyaml==6.0 + - pyzmq==23.2.1 + - qtconsole==5.3.2 + - qtpy==2.2.0 + - rawpy==0.17.2 + - regex==2022.10.31 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - s3transfer==0.6.0 + - sahi==0.11.11 + - scikit-image==0.19.3 + - scikit-learn==1.1.2 + - scipy==1.11.1 + - seaborn==0.12.0 + - send2trash==1.8.0 + - setuptools==59.8.0 + - shapely==2.0.0 + - slm-controller==0.0.2 + - smmap==5.0.0 + - snowballstemmer==2.2.0 + - sortedcontainers==2.4.0 + - soupsieve==2.3.2.post1 + - sparse==0.14.0 + - sphinx==2.1.2 + - sphinx-rtd-theme==0.4.3 + - sphinxcontrib-applehelp==1.0.2 + - sphinxcontrib-devhelp==1.0.2 + - sphinxcontrib-htmlhelp==2.0.0 + - sphinxcontrib-jsmath==1.0.1 + - sphinxcontrib-qthelp==1.0.3 + - sphinxcontrib-serializinghtml==1.1.5 + - sqlparse==0.4.2 + - stack-data==0.5.0 + - sympy==1.12 + - tblib==1.7.0 + - tensorboard==2.11.2 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - termcolor==2.2.0 + - terminado==0.15.0 + - terminaltables==3.1.10 + - thop==0.1.1-2209072238 + - threadpoolctl==3.1.0 + - tifffile==2022.8.12 + - tinycss2==1.1.1 + - 
tokenizers==0.13.2 + - tomli==2.0.1 + - toolz==0.12.0 + - torch==2.0.1 + - torchmetrics==0.9.3 + - torchvision==0.15.2 + - tornado==6.1 + - tqdm==4.64.1 + - traitlets==5.3.0 + - transformers==4.26.0 + - triton==2.0.0 + - urwid==2.1.2 + - urwid-readline==0.13 + - waveprop==0.0.4 + - wcwidth==0.2.5 + - webencodings==0.5.1 + - werkzeug==2.2.2 + - widgetsnbextension==4.0.3 + - xxhash==3.2.0 + - yarl==1.9.2 + - yolov5==7.0.7 + - zict==2.2.0 + - zipp==3.8.1 +prefix: /home/bezzam/.conda/envs/lensless_class diff --git a/recon_requirements.txt b/recon_requirements.txt index 5d142936..d5999bd3 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -4,9 +4,12 @@ pylops==1.18.0 scikit-image>=0.19.0rc0 hydra-core click>=8.0.1 -waveprop>=0.0.3 # for simulation +waveprop>=0.0.5 # for simulation # Library for learning algorithm torch >= 1.8.0 torchvision -lpips \ No newline at end of file +lpips + +# Library for ILO +dlib==19.24.2 \ No newline at end of file diff --git a/scripts/sim/ilo_dataset.py b/scripts/sim/ilo_dataset.py new file mode 100644 index 00000000..340e478c --- /dev/null +++ b/scripts/sim/ilo_dataset.py @@ -0,0 +1,255 @@ +""" + +Simulate a mask, use face alignment on a face image and simulate a measurement with the mask on the image. + +Procedure is as follows: + +1) Simulate the mask. +2) Align the face. +3) Simulate a measurement with the mask and specified physical parameters. + +Example usage: + +Simulate FlatCam with separable simulation (https://arxiv.org/abs/1509.00116, Eq 7): +``` +python scripts/sim/ilo_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation: +``` +python scripts/sim/ilo_single_file.py mask.type=MLS simulation.flatcam=False +``` + +Simulate Fresnel Zone Aperture camera with PSF simulation (https://www.nature.com/articles/s41377-020-0289-9): +``` +python scripts/sim/ilo_single_file.py mask.type=FZA +``` + +Simulate PhaseContour camera with PSF simulation (https://ieeexplore.ieee.org/document/9076617): +``` +python scripts/sim/ilo_single_file.py mask.type=PhaseContour +``` + +""" + +import hydra +import warnings +from hydra.utils import to_absolute_path +from lensless.utils.io import load_psf, load_image # , save_image +from lensless.utils.image import rgb2gray, rgb2bayer +from lensless.utils.align_face import align_face +from lensless.recon.ilo_stylegan2.ilo import LatentOptimizer +import glob +from tqdm import tqdm +import math as ma +import torchvision + +import numpy as np +import matplotlib.pyplot as plt +from lensless.utils.plot import plot_image + +# from lensless.eval.metric import mse, psnr, ssim, lpips +from waveprop.simulation import FarFieldSimulator +import os +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture + + +@hydra.main(version_base=None, config_path="../../configs", config_name="ilo_single_file") +def simulate(config): + + fp = to_absolute_path(config.files.original) + assert os.path.exists(fp), f"File {fp} does not exist." 
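+ # This script first simulates a single measurement from the file above (mask simulation + face alignment + far-field propagation), then runs batched ILO reconstruction over all images found in files.preprocess_dir.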
+ + # simulation parameters + object_height = config.lensless_imaging.object_height + scene2mask = config.lensless_imaging.scene2mask + mask2sensor = config.lensless_imaging.mask2sensor + sensor = config.lensless_imaging.sensor + snr_db = config.lensless_imaging.simulated.snr + downsample = config.lensless_imaging.downsample + max_val = config.lensless_imaging.simulated.max_val + + image_format = config.lensless_imaging.image_format.lower() + grayscale = False + if image_format == "grayscale": + grayscale = True + + # 1) simulate mask + mask_type = config.mask.type + if mask_type.upper() in ["MURA", "MLS"]: + mask = CodedAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + method=mask_type, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type.upper() == "FZA": + mask = FresnelZoneAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type == "PhaseContour": + mask = PhaseContour.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + else: + mask = None + psf_fp = to_absolute_path(config.files.psf) + assert os.path.exists(psf_fp), f"PSF {psf_fp} does not exist." + psf = load_psf(psf_fp, verbose=True, downsample=downsample) + psf = psf.squeeze() + + if grayscale and psf.shape[-1] == 3: + psf = rgb2gray(psf) + if downsample > 1: + print(f"Downsampled to {psf.shape}.") + + # 2) simulate measurement + face_aligner = to_absolute_path(config.files.face_aligner) + image = np.array(align_face(fp, face_aligner)) / 255 + + if grayscale and len(image.shape) == 3: + image = rgb2gray(image) + + flatcam_sim = config.simulation.flatcam + if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: + warnings.warn( + "Flatcam simulation only supported for MURA and MLS masks. Using far field simulation with PSF." 
+ ) + flatcam_sim = False + + # use far field simulator to get correct object plane sizing + simulator = FarFieldSimulator( + psf=psf, # only support one depth plane + object_height=object_height, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + sensor=sensor, + snr_db=snr_db, + max_val=max_val, + ) + image_plane, object_plane = simulator.propagate(image, return_object_plane=True) + + if image_format == "grayscale": + image_plane = rgb2gray(image_plane) + object_plane = rgb2gray(object_plane) + elif "bayer" in image_format: + image_plane = rgb2bayer(image_plane, pattern=image_format[-4:]) + object_plane = rgb2bayer(object_plane, pattern=image_format[-4:]) + else: + # make sure image is RGB + assert image_plane.shape[-1] == 3, "Image plane must be RGB" + assert object_plane.shape[-1] == 3, "Object plane must be RGB" + + if flatcam_sim: + # apply flatcam simulation to object plane + image_plane = mask.simulate(object_plane, snr_db=snr_db) + + # -- plot + fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) + plot_image(object_plane, ax=ax[0]) + ax[0].set_title("Object plane") + plot_image(psf, ax=ax[1], gamma=3.5) + ax[1].set_title("PSF") + plot_image(image_plane, ax=ax[2]) + ax[2].set_title("Raw data") + plt.savefig("result.png") + fig.tight_layout() + plt.show() + + for a in ax: + a.set_axis_off() + + # 3) reconstruction (TODO) + latent_optimizer = LatentOptimizer(config, mask=mask) + + # Load input files name + input_files = [ + x + for x in glob.iglob( + os.path.join( + to_absolute_path(config["files"]["preprocess_dir"]), + "*" + config["files"]["files_ext"], + ) + ) + ] + print(input_files) + + # Configure output files name + if not (os.path.isdir(config["files"]["output_dir"])): + os.mkdir(config["files"]["output_dir"]) + + base_names = [x.split("/")[-1].split("_")[-2] for x in input_files] + new_names_output = [ + base_name + "_processed" + config["files"]["files_ext"] for base_name in base_names + ] + output_files = [ + os.path.join(config["files"]["output_dir"], new_name) for new_name in new_names_output + ] + + if config["logs"]["save_forward"]: + if not (os.path.isdir(config["logs"]["forward_dir"])): + os.mkdir(config["logs"]["forward_dir"]) + new_names_output_forward = [ + base_name + "_forward" + config["files"]["files_ext"] for base_name in base_names + ] + output_forward_files = [ + os.path.join(config["logs"]["forward_dir"], new_name) + for new_name in new_names_output_forward + ] + + # Optimize in batches + batchsize = config["opti_params"]["batchsize_process"] + n_batchs = ma.ceil(len(input_files) / batchsize) + + for i in tqdm(range(n_batchs), desc="Batch"): + + input_files_batch = input_files[i * batchsize : (i + 1) * batchsize] + output_files_batch = output_files[i * batchsize : (i + 1) * batchsize] + output_forward_files_batch = output_forward_files[i * batchsize : (i + 1) * batchsize] + + # Initialize state of optimizer + latent_optimizer.set_data(input_files_batch) + + # Invert + _, z, best = latent_optimizer.apply() + + best_img = best[0].detach().cpu() + if config["logs"]["save_forward"]: + best_img_forward = best[1].detach().cpu() + + print() + + for j in range(len(input_files_batch)): + # Save output image + print(f"Saving file: {output_files_batch[j]}") + torchvision.utils.save_image( + best_img[j], + output_files_batch[j], + nrow=int(best_img[j].shape[0] ** 0.5), + normalize=True, + ) + + if config["logs"]["save_forward"]: + # Save debug image + torchvision.utils.save_image( + best_img_forward[j], + output_forward_files_batch[j], + 
nrow=int(best_img_forward[j].shape[0] ** 0.5), + normalize=False, + ) + + # plt.show() + + +if __name__ == "__main__": + simulate() diff --git a/scripts/sim/ilo_single_file.py b/scripts/sim/ilo_single_file.py new file mode 100644 index 00000000..cc208883 --- /dev/null +++ b/scripts/sim/ilo_single_file.py @@ -0,0 +1,250 @@ +""" + +Simulate a mask, use face alignment on a face image and simulate a measurement with the mask on the image. + +Procedure is as follows: + +1) Simulate the mask. +2) Align the face. +3) Simulate a measurement with the mask and specified physical parameters. + +Example usage: + +Simulate FlatCam with separable simulation (https://arxiv.org/abs/1509.00116, Eq 7): +``` +python scripts/sim/ilo_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation: +``` +python scripts/sim/ilo_single_file.py mask.type=MLS simulation.flatcam=False +``` + +Simulate Fresnel Zone Aperture camera with PSF simulation (https://www.nature.com/articles/s41377-020-0289-9): +``` +python scripts/sim/ilo_single_file.py mask.type=FZA +``` + +Simulate PhaseContour camera with PSF simulation (https://ieeexplore.ieee.org/document/9076617): +``` +python scripts/sim/ilo_single_file.py mask.type=PhaseContour +``` + +""" + +import hydra +import warnings +from hydra.utils import to_absolute_path +from lensless.utils.io import load_psf, load_image, save_image +from lensless.utils.image import rgb2gray, rgb2bayer +from lensless.utils.align_face import align_face +from lensless.recon.ilo_stylegan2.ilo import LatentOptimizer +import glob +from tqdm import tqdm +import math as ma +import torchvision + +import numpy as np +import matplotlib.pyplot as plt +from lensless.utils.plot import plot_image + +# from lensless.eval.metric import mse, psnr, ssim, lpips +from waveprop.simulation import FarFieldSimulator +import os +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture + + +@hydra.main(version_base=None, config_path="../../configs", config_name="ilo_single_file") +def simulate(config): + + fp = to_absolute_path(config.files.original_fp) + bn = os.path.basename(fp).split(".")[0] + assert os.path.exists(fp), f"File {fp} does not exist." 
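+
+    # Optional sketch (an assumption, not part of the original flow, mirroring
+    # scripts/sim/import_try.py): ``align_face`` needs the dlib face-landmark model at
+    # ``config.files.face_aligner``; if it may be missing, it could be fetched like this.
+    predictor_path = to_absolute_path(config.files.face_aligner)
+    if not os.path.exists(predictor_path):
+        from torchvision.datasets.utils import download_and_extract_archive
+
+        download_and_extract_archive(
+            "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
+            os.path.dirname(predictor_path),
+            filename="shape_predictor_68_face_landmarks.dat.bz2",
+            remove_finished=True,
+        )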
+ + # simulation parameters + object_height = config.lensless_imaging.object_height + scene2mask = config.lensless_imaging.scene2mask + mask2sensor = config.lensless_imaging.mask2sensor + sensor = config.lensless_imaging.sensor + snr_db = config.lensless_imaging.simulated.snr + downsample = config.lensless_imaging.downsample + max_val = config.lensless_imaging.simulated.max_val + + image_format = config.lensless_imaging.image_format.lower() + grayscale = False + if image_format == "grayscale": + grayscale = True + + # 1) simulate mask + mask_type = config.mask.type + if mask_type.upper() in ["MURA", "MLS"]: + mask = CodedAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + method=mask_type, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type.upper() == "FZA": + mask = FresnelZoneAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type == "PhaseContour": + mask = PhaseContour.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + psf = mask.psf / np.linalg.norm(mask.psf.ravel()) + else: + mask = None + psf_fp = to_absolute_path(config.files.psf) + assert os.path.exists(psf_fp), f"PSF {psf_fp} does not exist." + psf = load_psf(psf_fp, verbose=True, downsample=downsample) + psf = psf.squeeze() + + if grayscale and psf.shape[-1] == 3: + psf = rgb2gray(psf) + if downsample > 1: + print(f"Downsampled to {psf.shape}.") + + # 2) simulate measurement + # -- first align face + face_aligner = to_absolute_path(config.files.face_aligner) + image_aligned = np.array(align_face(fp, face_aligner)) / 255 + aligned_fn = bn + "_aligned.png" + save_image(image_aligned, aligned_fn) + print(f"Saved aligned image to {aligned_fn}.") + + if grayscale and len(image_aligned.shape) == 3: + image_aligned = rgb2gray(image_aligned) + + flatcam_sim = config.simulation.flatcam + if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: + warnings.warn( + "Flatcam simulation only supported for MURA and MLS masks. Using far field simulation with PSF." 
+ ) + flatcam_sim = False + + # use far field simulator to get correct object plane sizing + simulator = FarFieldSimulator( + psf=psf, # only support one depth plane + object_height=object_height, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + sensor=sensor, + snr_db=snr_db, + max_val=max_val, + ) + image_plane, object_plane = simulator.propagate(image_aligned, return_object_plane=True) + simulated_fn = bn + "_sim_meas.png" + save_image(image_plane, simulated_fn) + print(f"Saved simulated image to {simulated_fn}.") + + if image_format == "grayscale": + image_plane = rgb2gray(image_plane) + object_plane = rgb2gray(object_plane) + elif "bayer" in image_format: + image_plane = rgb2bayer(image_plane, pattern=image_format[-4:]) + object_plane = rgb2bayer(object_plane, pattern=image_format[-4:]) + else: + # make sure image is RGB + assert image_plane.shape[-1] == 3, "Image plane must be RGB" + assert object_plane.shape[-1] == 3, "Object plane must be RGB" + + if flatcam_sim: + # apply flatcam simulation to object plane + image_plane = mask.simulate(object_plane, snr_db=snr_db) + + # -- plot + fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) + plot_image(object_plane, ax=ax[0]) + ax[0].set_title("Object plane") + plot_image(psf, ax=ax[1], gamma=3.5) + ax[1].set_title("PSF") + plot_image(image_plane, ax=ax[2]) + ax[2].set_title("Raw data") + plt.savefig("result.png") + fig.tight_layout() + + for a in ax: + a.set_axis_off() + + # 3) reconstruction (TODO) + latent_optimizer = LatentOptimizer(config, psf=psf, mask=mask) + + # input_file = to_absolute_path(config.files.preprocess_dir + config.files.image_name) + + """ + # Configure output files name + if not (os.path.isdir(config["files"]["output_dir"])): + os.mkdir(config["files"]["output_dir"]) + + base_name = input_file.split("/")[-1].split("_")[-2] + new_names_output = [ + base_name + "_processed" + config["files"]["files_ext"] for base_name in base_names + ] + output_files = [ + os.path.join(config["files"]["output_dir"], new_name) for new_name in new_names_output + ] + + if config["logs"]["save_forward"]: + if not (os.path.isdir(config["logs"]["forward_dir"])): + os.mkdir(config["logs"]["forward_dir"]) + new_names_output_forward = [ + base_name + "_forward" + config["files"]["files_ext"] for base_name in base_names + ] + output_forward_files = [ + os.path.join(config["logs"]["forward_dir"], new_name) + for new_name in new_names_output_forward + ] + """ + + # Initialize state of optimizer + latent_optimizer.init_state([simulated_fn]) + + # Invert + _, _, best = latent_optimizer.invert() + + best_img = best[0].detach().cpu().squeeze().permute(1, 2, 0) + plt.figure() + plt.imshow(best_img) + plt.savefig("ILO_result.png") + + # import pudb ; pudb.set_trace() + + # best_img = best[0].detach().cpu() + # if config["logs"]["save_forward"]: + # best_img_forward = best[1].detach().cpu() + + # print() + + # Save output image + # print(f"Saving file: {output_files_batch[j]}") + # torchvision.utils.save_image( + # best_img[j], + # output_files_batch[j], + # nrow=int(best_img[j].shape[0] ** 0.5), + # normalize=True, + # ) + + # if config["logs"]["save_forward"]: + # # Save debug image + # torchvision.utils.save_image( + # best_img_forward[j], + # output_forward_files_batch[j], + # nrow=int(best_img_forward[j].shape[0] ** 0.5), + # normalize=False, + # ) + + # plt.show() + + +if __name__ == "__main__": + simulate() diff --git a/scripts/sim/import_try.py b/scripts/sim/import_try.py new file mode 100644 index 00000000..7be016fe --- /dev/null +++ 
b/scripts/sim/import_try.py
@@ -0,0 +1,47 @@
+from lensless.utils.align_face import align_face
+# from lensless.recon.ilo_stylegan2.ilo import LatentOptimizer
+# print('Done')
+
+import os
+from torchvision.datasets.utils import download_and_extract_archive
+import matplotlib.pyplot as plt
+import shutil
+import numpy as np
+
+fp = 'data/celeba_mini/000019.jpg'
+
+predictor_path = os.path.join("models", "shape_predictor_68_face_landmarks.dat")
+
+if not os.path.exists("models"):
+    os.makedirs("models")
+if not os.path.exists(predictor_path):
+    msg = "Do you want to download the face landmark model (61.1 MB)?"
+    valid = input("%s (Y/n) " % msg).lower() != "n"
+    if valid:
+        if os.path.exists(predictor_path + '.bz2'):
+            os.remove(predictor_path + '.bz2')
+        url = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
+        filename = "shape_predictor_68_face_landmarks.dat.bz2"
+        download_and_extract_archive(url, "models", filename=filename, remove_finished=True)
+
+
+if not os.path.exists(fp):
+    msg = "Do you want to download the sample CelebA dataset (764 KB)?"
+    valid = input("%s (Y/n) " % msg).lower() != "n"
+    if valid:
+        if os.path.exists("data/celeba_mini"):
+            shutil.rmtree("data/celeba_mini")
+        url = "https://drive.switch.ch/index.php/s/Q5OdDQMwhucIlt8/download"
+        filename = "celeb_mini.zip"
+        download_and_extract_archive(
+            url, "data", filename=filename, remove_finished=True
+        )
+
+
+aligned = np.array(align_face(fp, predictor_path)) / 255
+
+print(aligned.shape, aligned.min(), aligned.max())
+
+plt.figure(figsize=(10, 10))
+plt.imshow(aligned)
+plt.show()
diff --git a/scripts/sim/mask_single_file.py b/scripts/sim/mask_single_file.py
index e8a741b5..60633e7e 100644
--- a/scripts/sim/mask_single_file.py
+++ b/scripts/sim/mask_single_file.py
@@ -104,6 +104,8 @@ def simulate(config):
             n_iter=config.mask.phase_mask_iter,
             **config.mask,
         )
+    assert mask is not None, "Unsupported mask type."
+    psf = mask.psf / np.linalg.norm(mask.psf.ravel())
 
     # 2) simulate measurement
     image = load_image(fp, verbose=True) / 255
@@ -117,7 +119,7 @@ def simulate(config):
 
     # use far field simulator to get correct object plane sizing
     simulator = FarFieldSimulator(
-        psf=mask.psf,
+        psf=psf,
         object_height=object_height,
         scene2mask=scene2mask,
         mask2sensor=mask2sensor,
@@ -218,7 +220,7 @@ def simulate(config):
     ax[4].set_title("Reconstruction")
 
     for a in ax:
-        a.set_xticks([]), a.set_yticks([])
+        a.set_axis_off()
 
     plt.tight_layout()
     plt.savefig("result.png")