diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py
index 885766f3..3ec13120 100644
--- a/lensless/eval/benchmark.py
+++ b/lensless/eval/benchmark.py
@@ -22,7 +22,7 @@
 )
 
 
-def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
+def benchmark(model, dataset, batchsize=1, metrics=None, mask_crop=None, **kwargs):
     """
     Compute multiple metrics for a reconstruction algorithm.
 
@@ -36,6 +36,8 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
         Batch size for processing. For maximum compatibility use 1 (batchsize above 1 are not supported on all algorithm), by default 1
     metrics : dict, optional
         Dictionary of metrics to compute. If None, MSE, MAE, SSIM, LPIPS and PSNR are computed.
+    mask_crop : torch.Tensor, optional
+        Binary mask applied to both the reconstruction output and the ground truth before computing metrics, by default None.
 
     Returns
     -------
@@ -80,6 +82,11 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
             # Convert to [N*D, C, H, W] for torchmetrics
             prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
             lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)
+
+            if mask_crop is not None:
+                prediction = prediction * mask_crop
+                lensed = lensed * mask_crop
+
             # normalization
             prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
             if torch.all(prediction_max != 0):
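
For reference, a minimal sketch of how ``benchmark`` can be called with the new ``mask_crop`` argument; the ``model``, ``test_set``, image dimensions, and crop bounds below are hypothetical placeholders.

import torch
from lensless.eval.benchmark import benchmark

# hypothetical region of interest on a [depth=1, H=480, W=640, C=3] reconstruction
mask_crop = torch.zeros((1, 480, 640, 3), dtype=torch.bool)
mask_crop = mask_crop.movedim(-1, -3)  # -> [depth, C, H, W], matching the [N*D, C, H, W] layout above
mask_crop[:, :, 100:380, 150:500] = True

# `model` (reconstruction algorithm) and `test_set` (DualDataset) are assumed defined;
# both the prediction and the ground truth are multiplied by the mask, so metrics
# are computed only on the True region
metrics = benchmark(model, test_set, batchsize=4, mask_crop=mask_crop)
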
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 2ca758c6..b1e64a8e 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -259,6 +259,7 @@ def __init__(
         save_every=None,
         gamma=None,
         logger=None,
+        crop=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__.
@@ -309,6 +310,8 @@ def __init__(
             Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
         logger : :py:class:`logging.Logger`, optional
             Logger to use for logging. If None, just print to terminal. Default is None.
+        crop : dict, optional
+            Crop to apply to images before computing the loss (via a binary mask), with ``vertical`` and ``horizontal`` index ranges. If None, no crop is applied. Default is None.
         """
 
@@ -370,6 +373,21 @@ def __init__(
                 "lpips package is need for LPIPS loss. Install using : pip install lpips"
             )
 
+        if crop is not None:
+            datashape = train_dataset[0][0].shape
+            # create binary mask to multiply with before computing loss
+            self.mask_crop = torch.zeros(datashape, dtype=torch.bool).to(self.device)
+
+            # move channel dimension to third to last
+            self.mask_crop = self.mask_crop.movedim(-1, -3)
+
+            # set values to True in mask
+            self.mask_crop[
+                :, :, crop.vertical[0] : crop.vertical[1], crop.horizontal[0] : crop.horizontal[1]
+            ] = True
+        else:
+            self.mask_crop = None
+
         # optimizer
         if optimizer == "Adam":
             # the parameters of the base model and non torch.Module process must be added separatly
@@ -495,6 +513,11 @@ def train_epoch(self, data_loader, disp=-1):
             y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
             y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)
 
+            # crop
+            if self.mask_crop is not None:
+                y_pred = y_pred * self.mask_crop
+                y = y * self.mask_crop
+
             loss_v = self.Loss(y_pred, y)
 
             if self.lpips:
@@ -556,7 +579,9 @@ def evaluate(self, mean_loss, save_pt):
         if self.test_dataset is None:
             return
         # benchmarking
-        current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size)
+        current_metrics = benchmark(
+            self.recon, self.test_dataset, batchsize=self.eval_batch_size, mask_crop=self.mask_crop
+        )
 
         # update metrics with current metrics
         self.metrics["LOSS"].append(mean_loss)
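
A sketch of the corresponding ``Trainer`` usage; the crop bounds are placeholders, OmegaConf is used so that the ``crop.vertical`` / ``crop.horizontal`` attribute access above works, and the keyword names ``recon``, ``train_dataset``, and ``test_dataset`` are assumptions about the rest of the constructor signature.

from omegaconf import OmegaConf
from lensless.recon.utils import Trainer

# hypothetical crop: rows 100-380, columns 150-500
crop = OmegaConf.create({"vertical": [100, 380], "horizontal": [150, 500]})

# `recon`, `train_set`, and `test_set` are assumed defined
trainer = Trainer(
    recon=recon,
    train_dataset=train_set,
    test_dataset=test_set,
    crop=crop,  # loss and evaluation metrics restricted to this region
)
trainer.train(n_epoch=10)
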
diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index a5a2e8a9..5fd32f56 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -133,8 +133,8 @@ def __getitem__(self, idx):
 
         # flip image x and y if needed
         if self.flip:
-            lensless = torch.rot90(lensless, dims=(-3, -2))
-            lensed = torch.rot90(lensed, dims=(-3, -2))
+            lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
+            lensed = torch.rot90(lensed, dims=(-3, -2), k=2)
         if self.transform_lensless:
             lensless = self.transform_lensless(lensless)
         if self.transform_lensed:
@@ -230,20 +230,27 @@ def __len__(self):
 
 class MeasuredDatasetSimulatedOriginal(DualDataset):
     """
+    Abstract class for defining a dataset of paired lensed and lensless images.
+
     Dataset consisting of lensless image captured from a screen and the corresponding image shown on the screen.
     Unlike :py:class:`lensless.utils.dataset.MeasuredDataset`, the ground-truth lensed image is simulated using
     a :py:class:`lensless.utils.simulation.FarFieldSimulator` object rather than measured with a lensed camera.
+
+    The class assumes that ``measured_dir`` and ``original_dir`` contain files whose names match.
+
+    The method ``_get_images_pair`` must be implemented by child classes.
     """
 
     def __init__(
         self,
-        root_dir,
+        measured_dir,
+        original_dir,
         simulator,
-        lensless_fn="diffuser",
-        original_fn="lensed",
-        image_ext="npy",
-        original_ext=None,
+        measurement_ext="png",
+        original_ext="jpg",
         downsample=1,
+        background=None,
+        flip=False,
         **kwargs,
     ):
         """
@@ -251,42 +258,34 @@
 
         Parameters
         ----------
-        root_dir : str
-            Path to the test dataset. It is expected to contain two folders: one of lensless images and one of original images.
-        simulator : :py:class:`lensless.utils.simulatorFarFieldSimulator`
-            Simulator to use for the projection of the original image to object space. The PSF **should not** be specified, and it is expect to have ``is_torch = True``.
-        lensless_fn : str, optional
-            Name of the folder containing the lensless images, by default "diffuser".
-        lensed_fn : str, optional
-            Name of the folder containing the lensed images, by default "lensed".
-        image_ext : str, optional
-            Extension of the images, by default "npy".
-        original_ext : str, optional
-            Extension of the original image if different from lenless, by default None.
-        downsample : int, optional
-            Downsample factor of the lensless images, by default 1.
         """
-        super(MeasuredDatasetSimulatedOriginal, self).__init__(downsample=1, **kwargs)
+        super(MeasuredDatasetSimulatedOriginal, self).__init__(
+            downsample=1, background=background, flip=flip, **kwargs
+        )
         self.pre_downsample = downsample
 
-        self.root_dir = root_dir
-        self.lensless_dir = os.path.join(root_dir, lensless_fn)
-        self.original_dir = os.path.join(root_dir, original_fn)
-        assert os.path.isdir(self.lensless_dir)
+        self.measured_dir = measured_dir
+        self.original_dir = original_dir
+        assert os.path.isdir(self.measured_dir)
         assert os.path.isdir(self.original_dir)
 
-        self.image_ext = image_ext.lower()
-        self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower()
+        self.measurement_ext = measurement_ext.lower()
+        self.original_ext = original_ext.lower()
 
-        files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext))
+        files = glob.glob(os.path.join(self.measured_dir, "*." + self.measurement_ext))
         files.sort()
         self.files = [os.path.basename(fn) for fn in files]
 
         if len(self.files) == 0:
             raise FileNotFoundError(
-                f"No files found in {self.lensless_dir} with extension {image_ext}"
+                f"No files found in {self.measured_dir} with extension {self.measurement_ext}"
             )
 
+        # check that corresponding original files exist
+        for fn in self.files:
+            original_fp = os.path.join(self.original_dir, fn[:-3] + self.original_ext)
+            assert os.path.exists(original_fp), f"File {original_fp} does not exist"
+
         # check simulator
         assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator"
         assert simulator.is_torch, "Simulator should be a pytorch simulator"
@@ -299,30 +298,176 @@ def __len__(self):
         else:
             return len([i for i in self.indices if i < len(self.files)])
 
-    def _get_images_pair(self, idx):
-        if self.image_ext == "npy" or self.image_ext == "npz":
-            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
-            original_fp = os.path.join(self.original_dir, self.files[idx])
-            lensless = np.load(lensless_fp)
-            lensless = resize(lensless, factor=1 / self.downsample)
-            original = np.load(original_fp[:-3] + self.original_ext)
-        else:
-            # more standard image formats: png, jpg, tiff, etc.
-            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
-            original_fp = os.path.join(self.original_dir, self.files[idx])
-            lensless = load_image(lensless_fp, downsample=self.pre_downsample)
-            original = load_image(
-                original_fp[:-3] + self.original_ext, downsample=self.pre_downsample
+    # def _get_images_pair(self, idx):
+    #     if self.image_ext == "npy" or self.image_ext == "npz":
+    #         lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+    #         original_fp = os.path.join(self.original_dir, self.files[idx])
+    #         lensless = np.load(lensless_fp)
+    #         lensless = resize(lensless, factor=1 / self.downsample)
+    #         original = np.load(original_fp[:-3] + self.original_ext)
+    #     else:
+    #         # more standard image formats: png, jpg, tiff, etc.
+    #         lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+    #         original_fp = os.path.join(self.original_dir, self.files[idx])
+    #         lensless = load_image(lensless_fp, downsample=self.pre_downsample)
+    #         original = load_image(
+    #             original_fp[:-3] + self.original_ext, downsample=self.pre_downsample
+    #         )
+
+    #     # convert to float
+    #     if lensless.dtype == np.uint8:
+    #         lensless = lensless.astype(np.float32) / 255
+    #         original = original.astype(np.float32) / 255
+    #     else:
+    #         # 16 bit
+    #         lensless = lensless.astype(np.float32) / 65535
+    #         original = original.astype(np.float32) / 65535
+
+    #     # convert to torch
+    #     lensless = torch.from_numpy(lensless)
+    #     original = torch.from_numpy(original)
+
+    #     # project original image to lensed space
+    #     with torch.no_grad():
+    #         lensed = self.sim.propagate_image()
+
+    #     return lensless, lensed
+
+
+class DigiCamCelebA(MeasuredDatasetSimulatedOriginal):
+    def __init__(
+        self,
+        celeba_root,
+        data_dir=None,
+        psf_path=None,
+        downsample=1,
+        flip=True,
+        vertical_shift=-85,
+        horizontal_shift=-15,
+        simulation_config=None,
+        **kwargs,
+    ):
+        """
+
+        Parameters
+        ----------
+        celeba_root : str
+            Path to the CelebA dataset.
+        data_dir : str, optional
+            Path to the lensless images, by default looks inside the ``data`` folder. Can download if not available.
+        psf_path : str, optional
+            Path to the PSF of the imaging system, by default looks inside the ``data/psf`` folder. Can download if not available.
+        downsample : int, optional
+            Downsample factor of the lensless images, by default 1.
+        flip : bool, optional
+            If True, measurements are flipped, by default ``True``. Not applied to the original images.
+        vertical_shift : int, optional
+            Vertical shift (in pixels) of the lensed images to align, by default -85.
+        horizontal_shift : int, optional
+            Horizontal shift (in pixels) of the lensed images to align, by default -15.
+        """
+
+        # download dataset if necessary
+        if data_dir is None:
+            data_dir = os.path.join(
+                os.path.dirname(__file__),
+                "..",
+                "..",
+                "data",
+                "celeba_adafruit_random_2mm_20230720_10K",
             )
+            if not os.path.isdir(data_dir):
+                main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data")
+                print("DigiCam CelebA dataset not found.")
+                try:
+                    from torchvision.datasets.utils import download_and_extract_archive
+                except ImportError:
+                    exit("torchvision is needed to download the dataset: pip install torchvision")
+                msg = "Do you want to download this dataset of 10K examples (12.2GB)?"
 
-        # convert to float
-        if lensless.dtype == np.uint8:
-            lensless = lensless.astype(np.float32) / 255
-            original = original.astype(np.float32) / 255
-        else:
-            # 16 bit
-            lensless = lensless.astype(np.float32) / 65535
-            original = original.astype(np.float32) / 65535
+                # default to yes if no input is given
+                valid = input("%s (Y/n) " % msg).lower() != "n"
+                if valid:
+                    url = "https://drive.switch.ch/index.php/s/9NNGCJs3DoBDGlY/download"
+                    filename = "celeba_adafruit_random_2mm_20230720_10K.zip"
+                    download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True)
+
+        # download PSF if necessary
+        if psf_path is None:
+            psf_path = os.path.join(
+                os.path.dirname(__file__),
+                "..",
+                "..",
+                "data",
+                "psf",
+                "adafruit_random_2mm_20231907.png",
+            )
+            if not os.path.exists(psf_path):
+                try:
+                    from torchvision.datasets.utils import download_url
+                except ImportError:
+                    exit("torchvision is needed to download the PSF: pip install torchvision")
+                msg = "Do you want to download the PSF (38.8MB)?"
+
+                # default to yes if no input is given
+                valid = input("%s (Y/n) " % msg).lower() != "n"
+                output_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "psf")
+                if valid:
+                    url = "https://drive.switch.ch/index.php/s/kfN5vOqvVkNyHmc/download"
+                    filename = "adafruit_random_2mm_20231907.png"
+                    download_url(url, output_path, filename=filename)
+
+        # load PSF
+        self.flip_measurement = flip
+        self.vertical_shift = vertical_shift
+        self.horizontal_shift = horizontal_shift
+        psf, background = load_psf(
+            psf_path,
+            downsample=downsample * 4,  # PSF is 4x the resolution of the images
+            return_float=True,
+            return_bg=True,
+            flip=flip,
+            bg_pix=(0, 15),
+        )
+        self.psf = torch.from_numpy(psf)
+
+        # create simulator
+        simulation_config["output_dim"] = tuple(self.psf.shape[-3:-1])
+        simulator = FarFieldSimulator(
+            is_torch=True,
+            **simulation_config,
+        )
+
+        super().__init__(
+            measured_dir=data_dir,
+            original_dir=os.path.join(celeba_root, "celeba", "img_align_celeba"),
+            simulator=simulator,
+            measurement_ext="png",
+            original_ext="jpg",
+            downsample=downsample,
+            background=background,
+            flip=False,  # will do flipping only on measurement
+            **kwargs,
+        )
+
+    def _get_images_pair(self, idx):
+
+        # more standard image formats: png, jpg, tiff, etc.
+        lensless_fp = os.path.join(self.measured_dir, self.files[idx])
+        original_fp = os.path.join(self.original_dir, self.files[idx][:-3] + self.original_ext)
+        lensless = load_image(
+            lensless_fp, downsample=self.pre_downsample, flip=self.flip_measurement
+        )
+        original = load_image(original_fp)  # extension already swapped above
+
+        # convert to float
+        if lensless.dtype == np.uint8:
+            lensless = lensless.astype(np.float32) / 255
+            original = original.astype(np.float32) / 255
+        else:
+            # 16 bit
+            lensless = lensless.astype(np.float32) / 65535
+            original = original.astype(np.float32) / 65535
 
         # convert to torch
         lensless = torch.from_numpy(lensless)
         original = torch.from_numpy(original)
@@ -330,7 +475,12 @@ def _get_images_pair(self, idx):
 
         # project original image to lensed space
         with torch.no_grad():
-            lensed = self.sim.propagate_image()
+            lensed = self.sim.propagate_image(original, return_object_plane=True)
+
+            if self.vertical_shift != 0:
+                lensed = torch.roll(lensed, self.vertical_shift, dims=-3)
+            if self.horizontal_shift != 0:
+                lensed = torch.roll(lensed, self.horizontal_shift, dims=-2)
 
         return lensless, lensed
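
A sketch of instantiating the new ``DigiCamCelebA`` dataset; the paths are placeholders, and the simulation keys (``object_height``, ``scene2mask``, etc.) are assumptions about the project's simulation configs, passed through to ``FarFieldSimulator``.

import torch
from lensless.utils.dataset import DigiCamCelebA

# hypothetical simulation settings (normally taken from a Hydra config)
simulation_config = {
    "object_height": 0.33,
    "scene2mask": 0.25,
    "mask2sensor": 0.002,
    "sensor": "rpi_hq",
    "snr_db": None,
}

dataset = DigiCamCelebA(
    celeba_root="/path/to/celeba",  # placeholder: folder containing celeba/img_align_celeba
    downsample=8,
    simulation_config=simulation_config,
)
lensless, lensed = dataset[0]  # measurement and simulated ground truth

# random train-test split, as done in the training script below
train_size = int(0.85 * len(dataset))
train_set, test_set = torch.utils.data.random_split(
    dataset, [train_size, len(dataset) - train_size]
)
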
diff --git a/lensless/utils/image.py b/lensless/utils/image.py
index 748aaf50..5dc1bef2 100644
--- a/lensless/utils/image.py
+++ b/lensless/utils/image.py
@@ -8,6 +8,7 @@
 
 import cv2
+import scipy.signal
 import numpy as np
 from lensless.hardware.constants import RPI_HQ_CAMERA_CCM_MATRIX, RPI_HQ_CAMERA_BLACK_LEVEL
 
@@ -57,9 +58,8 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC):
         if torch_available:
             # torch resize expects an input of form [color, depth, width, height]
             tmp = np.moveaxis(img, -1, 0)
-            resized = tf.Resize(size=new_shape, interpolation=interpolation)(
-                torch.from_numpy(tmp)
-            ).numpy()
+            tmp = torch.from_numpy(tmp.copy())
+            resized = tf.Resize(size=new_shape)(tmp).numpy()  # torchvision default interpolation (cv2 flags are not valid here)
 
             resized = np.moveaxis(resized, 0, -1)
         else:
@@ -327,6 +327,25 @@ def autocorr2d(vals, pad_mode="reflect"):
 
     return autocorr[shape[0] // 2 : -shape[0] // 2, shape[1] // 2 : -shape[1] // 2]
 
+
+def corr2d(im1, im2):
+    """
+    Cross-correlation of two images. Source: https://stackoverflow.com/a/24769222
+
+    """
+
+    # get rid of the color channels by performing a grayscale transform
+    # the type cast into 'float' is to avoid overflows
+    im1_gray = np.sum(im1.astype("float"), axis=2)
+    im2_gray = np.sum(im2.astype("float"), axis=2)
+
+    # get rid of the averages, otherwise the results are not good
+    im1_gray -= np.mean(im1_gray)
+    im2_gray -= np.mean(im2_gray)
+
+    # calculate the correlation image; note the flipping of one of the images
+    return scipy.signal.fftconvolve(im1_gray, im2_gray[::-1, ::-1], mode="same")
+
+
 def rgb2bayer(img, pattern):
     """
     Converting RGB image to separated Bayer channels.
@@ -442,133 +461,3 @@ def bayer2rgb(X_bayer, pattern):
         X_rgb[:, :, 2] = X_bayer[:, :, b_channel]
 
     return X_rgb
-
-
-def load_drunet(model_path, n_channels=3, requires_grad=False):
-    """
-    Load a pre-trained Drunet model.
-
-    Parameters
-    ----------
-    model_path : str
-        Path to pre-trained model.
-    n_channels : int
-        Number of channels in input image.
-    requires_grad : bool
-        Whether to require gradients for model parameters.
-
-    Returns
-    -------
-    model : :py:class:`~torch.nn.Module`
-        Loaded model.
-    """
-    from lensless.recon.drunet.network_unet import UNetRes
-
-    model = UNetRes(
-        in_nc=n_channels + 1,
-        out_nc=n_channels,
-        nc=[64, 128, 256, 512],
-        nb=4,
-        act_mode="R",
-        downsample_mode="strideconv",
-        upsample_mode="convtranspose",
-    )
-    model.load_state_dict(torch.load(model_path), strict=True)
-    model.eval()
-    for k, v in model.named_parameters():
-        v.requires_grad = requires_grad
-
-    return model
-
-
-def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"):
-    """
-    Apply a pre-trained denoising model with input in the format Channel, Height, Width.
-    An additionnal channel is added for the noise level as done in Drunet.
-
-    Parameters
-    ----------
-    model : :py:class:`~torch.nn.Module`
-        Drunet compatible model. Its input must concist of 4 channels ( RGB + noise level) and outbut an RGB image both in CHW format.
-    image : :py:class:`~torch.Tensor`
-        Input image.
-    noise_level : float or :py:class:`~torch.Tensor`
-        Noise level in the image.
-    device : str
-        Device to use for computation. Can be "cpu" or "cuda".
-    mode : str
-        Mode to use for model. Can be "inference" or "train".
-
-    Returns
-    -------
-    image : :py:class:`~torch.Tensor`
-        Reconstructed image.
-    """
-    # convert from NDHWC to NCHW
-    depth = image.shape[-4]
-    image = image.movedim(-1, -3)
-    image = image.reshape(-1, *image.shape[-3:])
-    # pad image H and W to next multiple of 8
-    top = (8 - image.shape[-2] % 8) // 2
-    bottom = (8 - image.shape[-2] % 8) - top
-    left = (8 - image.shape[-1] % 8) // 2
-    right = (8 - image.shape[-1] % 8) - left
-    image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0)
-    # add noise level as extra channel
-    image = image.to(device)
-    if isinstance(noise_level, torch.Tensor):
-        noise_level = noise_level / 255.0
-    else:
-        noise_level = torch.tensor([noise_level / 255.0]).to(device)
-    image = torch.cat(
-        (
-            image,
-            noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]),
-        ),
-        dim=1,
-    )
-
-    # apply model
-    if mode == "inference":
-        with torch.no_grad():
-            image = model(image)
-    elif mode == "train":
-        image = model(image)
-    else:
-        raise ValueError("mode must be 'inference' or 'train'")
-
-    # remove padding
-    image = image[:, :, top:-bottom, left:-right]
-    # convert back to NDHWC
-    image = image.movedim(-3, -1)
-    image = image.reshape(-1, depth, *image.shape[-3:])
-    return image
-
-
-def process_with_DruNet(model, device="cpu", mode="inference"):
-    """
-    Return a porcessing function that applies the DruNet model to an image.
-
-    Parameters
-    ----------
-    model : torch.nn.Module
-        DruNet like denoiser model
-    device : str
-        Device to use for computation. Can be "cpu" or "cuda".
-    mode : str
-        Mode to use for model. Can be "inference" or "train".
-    """
-
-    def process(image, noise_level):
-        x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6
-        image = apply_denoiser(
-            model,
-            image,
-            noise_level=noise_level,
-            device=device,
-            mode="train",
-        )
-        image = torch.clip(image, min=0.0) * x_max
-        return image
-
-    return process
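
The new ``corr2d`` can be used, for instance, to estimate the offset between a simulated ground truth and a reconstruction (presumably how shifts like ``vertical_shift=-85`` above would be determined). A sketch with placeholder images; the sign of the recovered shift depends on the argument order.

import numpy as np
from lensless.utils.image import corr2d

# im1, im2: (H, W, 3) arrays assumed defined, e.g. a reconstruction and the simulated ground truth
corr = corr2d(im1, im2)

# the correlation peak, relative to the image center, gives the relative shift
peak = np.unravel_index(np.argmax(corr), corr.shape)
vertical_shift = peak[0] - corr.shape[0] // 2
horizontal_shift = peak[1] - corr.shape[1] // 2
print(vertical_shift, horizontal_shift)
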
diff --git a/lensless/utils/io.py b/lensless/utils/io.py
index 1b2b234f..d65f5532 100644
--- a/lensless/utils/io.py
+++ b/lensless/utils/io.py
@@ -504,17 +504,19 @@ def load_data(
 
 def save_image(img, fp, max_val=255):
     """Save as uint8 image."""
-    if img.dtype == np.uint16:
-        img = img.astype(np.float32)
+    img_tmp = img.copy()
 
-    if img.dtype == np.float64 or img.dtype == np.float32:
-        img -= img.min()
-        img /= img.max()
-        img *= max_val
-        img = img.astype(np.uint8)
+    if img_tmp.dtype == np.uint16:
+        img_tmp = img_tmp.astype(np.float32)
 
-    img = Image.fromarray(img)
-    img.save(fp)
+    if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
+        img_tmp -= img_tmp.min()
+        img_tmp /= img_tmp.max()
+        img_tmp *= max_val
+        img_tmp = img_tmp.astype(np.uint8)
+
+    img_tmp = Image.fromarray(img_tmp)
+    img_tmp.save(fp)
 
 
 def get_dtype(dtype=None, is_torch=False):
diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py
index b77fabcb..53d6257b 100644
--- a/lensless/utils/simulation.py
+++ b/lensless/utils/simulation.py
@@ -58,15 +58,16 @@ def __init__(
             Whether to quantize image, by default True.
         """
 
-        assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"
+        if psf is not None:
+            assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"
 
-        if torch.is_tensor(psf):
-            # drop depth dimension, and convert HWC to CHW
-            psf = psf[0].movedim(-1, 0)
-            assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels"
-        else:
-            psf = psf[0]
-            assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels"
+            if torch.is_tensor(psf):
+                # drop depth dimension, and convert HWC to CHW
+                psf = psf[0].movedim(-1, 0)
+                assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels"
+            else:
+                psf = psf[0]
+                assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels"
 
         super().__init__(
             object_height,
@@ -84,12 +85,15 @@ def __init__(
             **kwargs
         )
 
-        if self.is_torch:
-            assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels"
-        else:
-            assert (
-                self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3
-            ), "PSF must have 1 or 3 channels"
+        if psf is not None:
+            if self.is_torch:
+                assert (
+                    self.psf.shape[0] == 1 or self.psf.shape[0] == 3
+                ), "PSF must have 1 or 3 channels"
+            else:
+                assert (
+                    self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3
+                ), "PSF must have 1 or 3 channels"
 
         # save all the parameters in a dict
         self.params = {
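
With the changes above, a ``FarFieldSimulator`` can now be built without a PSF, e.g. to obtain only the object-plane image as ``DigiCamCelebA`` does. A sketch with placeholder values; the keyword names beyond ``object_height``, ``psf``, ``output_dim``, and ``is_torch`` are assumptions, and this relies on the underlying waveprop simulator accepting a missing PSF.

import torch
from lensless.utils.simulation import FarFieldSimulator

simulator = FarFieldSimulator(
    psf=None,  # now allowed: no convolution with a PSF
    object_height=0.33,  # placeholder values
    scene2mask=0.25,
    mask2sensor=0.002,
    sensor="rpi_hq",
    output_dim=(380, 507),
    is_torch=True,
)

# image: (H, W, 3) float tensor assumed defined
with torch.no_grad():
    lensed = simulator.propagate_image(image, return_object_plane=True)
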
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index c9be1ee4..62b29038 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -43,6 +43,7 @@
     DiffuserCamMirflickr,
     SimulatedFarFieldDataset,
     SimulatedDatasetTrainableMask,
+    DigiCamCelebA,
 )
 from torch.utils.data import Subset
 import lensless.hardware.trainable_mask
@@ -257,6 +258,75 @@ def train_unrolled(config):
 
         psf = dataset.psf
 
+    elif "celeba_adafruit" in config.files.dataset:
+
+        dataset = DigiCamCelebA(
+            data_dir=os.path.join(get_original_cwd(), config.files.dataset),
+            celeba_root=config.files.celeba_root,
+            psf_path=os.path.join(get_original_cwd(), config.files.psf),
+            downsample=config.files.downsample,
+            vertical_shift=config.files.vertical_shift,
+            horizontal_shift=config.files.horizontal_shift,
+            simulation_config=config.simulation,
+        )
+        dataset.psf = dataset.psf.to(device)
+        psf = dataset.psf
+        log.info(f"Data shape : {dataset[0][0].shape}")
+
+        # reconstruct lensless with ADMM
+        lensless, lensed = dataset[0]
+        from lensless import ADMM
+
+        recon = ADMM(psf)
+        recon.set_data(lensless.to(psf.device))
+        print("Reconstructing lensless image with ADMM...")
+        start_time = time.time()
+        res = recon.apply(disp_iter=None, plot=False, n_iter=10)
+        print(f"Processing time : {time.time() - start_time} s")
+        res_np = res[0].cpu().numpy()
+        res_np = res_np / res_np.max()
+        save_image(res_np, "lensless_recon.png")
+        lensed_np = lensed[0].cpu().numpy()
+        save_image(lensed_np, "lensed.png")
+        lensless_np = lensless[0].cpu().numpy()
+        save_image(lensless_np, "lensless_raw.png")
+
+        # -- plot lensed and res on top of each other
+        if config.training.crop is not None:
+            res_np = res_np[
+                config.training.crop.vertical[0] : config.training.crop.vertical[1],
+                config.training.crop.horizontal[0] : config.training.crop.horizontal[1],
+            ]
+            lensed_np = lensed_np[
+                config.training.crop.vertical[0] : config.training.crop.vertical[1],
+                config.training.crop.horizontal[0] : config.training.crop.horizontal[1],
+            ]
+            log.info(f"Cropped shape : {res_np.shape}")
+        plt.figure()
+        plt.imshow(lensed_np, alpha=0.5)
+        plt.imshow(res_np, alpha=0.7)
+        plt.savefig("overlay_lensed_recon.png")
+
+        # train-test split
+        train_size = int((1 - config.files.test_size) * len(dataset))
+        test_size = len(dataset) - train_size
+        train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
+        if config.files.n_files is not None:
+            train_set = Subset(train_set, np.arange(config.files.n_files))
+            test_set = Subset(test_set, np.arange(config.files.n_files))
+
+        # -- if learning mask
+        mask = prep_trainable_mask(config, dataset.psf)
+        if mask is not None:
+            # plot initial PSF
+            psf_np = mask.get_psf().detach().cpu().numpy()[0, ...]
+            if config.trainable_mask.grayscale:
+                psf_np = psf_np[:, :, -1]
+
+            save_image(psf_np, os.path.join(save, "psf_initial.png"))
+            plot_image(psf_np, gamma=config.display.gamma)
+            plt.savefig(os.path.join(save, "psf_initial_plot.png"))
+
     else:
 
         train_set, test_set, mask = simulate_dataset(config)
@@ -330,6 +400,7 @@ def train_unrolled(config):
         test_dataset=test_set,
         mask=mask,
         batch_size=config.training.batch_size,
+        eval_batch_size=config.training.eval_batch_size,
        loss=config.loss,
         lpips=config.lpips,
         l1_mask=config.trainable_mask.L1_strength,
@@ -342,6 +413,7 @@ def train_unrolled(config):
         save_every=config.training.save_every,
         gamma=config.display.gamma,
         logger=log,
+        crop=config.training.crop,
     )
 
     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
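
The script now reads ``eval_batch_size`` and ``crop`` from the ``training`` section of the Hydra config. A sketch of the corresponding entries, built here with OmegaConf; all values are placeholders, and only the fields actually accessed in the script are shown.

from omegaconf import OmegaConf

# sketch of the config entries consumed by train_unrolled.py
config = OmegaConf.create(
    {
        "training": {
            "batch_size": 8,
            "eval_batch_size": 4,  # new: batch size for the benchmark/evaluation pass
            "epoch": 25,
            "crop": {  # new: region used for the loss and metrics
                "vertical": [100, 380],
                "horizontal": [150, 500],
            },
        }
    }
)
assert config.training.crop.vertical[0] == 100  # attribute access, as used in the script
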