From ee98362a9107d7afcd44b95af5ac510f94d1ae58 Mon Sep 17 00:00:00 2001 From: Yohann PERRON Date: Fri, 5 May 2023 14:23:08 +0200 Subject: [PATCH 01/11] Fix ADMM ordering --- lensless/admm.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lensless/admm.py b/lensless/admm.py index 581700f4..e4b55482 100644 --- a/lensless/admm.py +++ b/lensless/admm.py @@ -181,6 +181,15 @@ def _X_update(self): self._xi + self._mu1 * self._forward_out + self._convolver._pad(self._data) ) + def _W_update(self): + """Non-negativity update""" + if self.is_torch: + self._W = torch.maximum( + self._rho / self._mu3 + self._image_est, torch.zeros_like(self._image_est) + ) + else: + self._W = np.maximum(self._rho / self._mu3 + self._image_est, 0) + def _image_update(self): rk = ( (self._mu3 * self._W - self._rho) @@ -199,15 +208,6 @@ def _image_update(self): # self._image_est = self._convolver._crop(res) - def _W_update(self): - """Non-negativity update""" - if self.is_torch: - self._W = torch.maximum( - self._rho / self._mu3 + self._image_est, torch.zeros_like(self._image_est) - ) - else: - self._W = np.maximum(self._rho / self._mu3 + self._image_est, 0) - def _xi_update(self): # to avoid computing forward model twice self._xi += self._mu1 * (self._forward_out - self._X) @@ -223,13 +223,14 @@ def _update(self): self._U_update() self._X_update() + self._W_update() + self._image_update() # update forward and sparse operators self._forward_out = self._convolver.convolve(self._image_est) self._Psi_out = self._Psi(self._image_est) - self._W_update() self._xi_update() self._eta_update() self._rho_update() From bba42c0ce588e2d6484328f26cad1391d6e34b54 Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Tue, 29 Aug 2023 22:17:49 +0200 Subject: [PATCH 02/11] Improved dataset (#68) * New simulated dataset (moved old dataset) * Move dataset to utils * Added parent class DualDataset * Use new dataset structure for training * Fix doc and bugs * New dataset for lensless only * Fixes for downscaling * Update change * Disclaimer for LenslessDataset * Added header * Updated documentation * Fix typos and wording. * Move dataset docs to data section. * Fixed docstring * Fix for flip in simulated dataset * Add wrapper arounf FarFieldSimulator * Fix import error * Fix docstrings * FIx typos. * Fix doc rendering of FarFieldSimulator. * Refactor. * Refactor. * Fix import. * Refactor and rephrase for clearer dataset diff * Fixed no attribute psf * add new simulation to training script * Remove print. * Update changelog. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 3 + docs/requirements.txt | 3 +- docs/source/conf.py | 8 +- docs/source/dataset.rst | 27 ++ docs/source/evaluation.rst | 4 - docs/source/index.rst | 1 + docs/source/simulation.rst | 12 + lensless/eval/benchmark.py | 209 +-------------- lensless/utils/dataset.py | 448 ++++++++++++++++++++++++++++++++ lensless/utils/simulation.py | 100 +++++++ scripts/eval/benchmark_recon.py | 3 +- scripts/recon/train_unrolled.py | 36 +-- 12 files changed, 620 insertions(+), 234 deletions(-) create mode 100644 docs/source/dataset.rst create mode 100644 lensless/utils/dataset.py create mode 100644 lensless/utils/simulation.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 847fa0f7..99be0bb1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,9 @@ Added - Script for measuring arbitrary dataset (from Raspberry Pi). 
- Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fix postprocessing can be used. - Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``. +- Unified interface for dataset. See ``utils.dataset.DualDataset``. +- New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedFarFieldDataset``. +- New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA). diff --git a/docs/requirements.txt b/docs/requirements.txt index f8146fac..e105c9f7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,4 +4,5 @@ docutils==0.16 # >0.17 doesn't render bullets numpy>=1.22 # so that default dtype are correctly rendered torch>=1.10 torchvision>=0.15.2 -torchmetrics>=0.11.4 \ No newline at end of file +torchmetrics>=0.11.4 +waveprop>=0.0.5 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index fc01f75b..02d3e0b0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,14 +28,14 @@ "pycsou.util", "pycsou.util.ptype", "PIL", + "PIL.Image", "tqdm", "paramiko", "paramiko.ssh_exception", "perlin_numpy", - "waveprop", - "waveprop.fresnel", - "waveprop.rs", - "waveprop.noise", + "scipy.special", + "matplotlib.cm", + "pyffs", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst new file mode 100644 index 00000000..1312e1cc --- /dev/null +++ b/docs/source/dataset.rst @@ -0,0 +1,27 @@ +Dataset objects (for training and testing) +========================================== + +The software below provides functionality (with PyTorch) to load +datasets for training and testing. + +.. automodule:: lensless.utils.dataset + +.. autoclass:: lensless.utils.dataset.DualDataset + :members: _get_images_pair + :special-members: __init__, __len__ + +.. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ diff --git a/docs/source/evaluation.rst b/docs/source/evaluation.rst index f3f381d2..0f2c9d93 100644 --- a/docs/source/evaluation.rst +++ b/docs/source/evaluation.rst @@ -23,8 +23,4 @@ .. automodule:: lensless.eval.benchmark - .. autoclass:: lensless.eval.benchmark.ParallelDataset - :members: - :special-members: __init__ - .. autofunction:: lensless.eval.benchmark.benchmark \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 94c236e6..3fba13d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,7 @@ Contents simulation data + dataset .. 
toctree:: :hidden: diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index d5ecaa34..12739ad2 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -16,6 +16,18 @@ library is used with the following simulation steps: PyTorch support is available to speed up simulation on GPU, and to create Dataset and DataLoader objects for training and testing! +FarFieldSimulator +------------------ + +A wrapper around `waveprop.simulation.FarFieldSimulator `__ +is implemented as :py:class:`lensless.utils.simulation.FarFieldSimulator`. +It handles the conversion between the HWC and CHW dimension orderings so that the convention of LenslessPiCam can be maintained (namely HWC). + +.. autoclass:: lensless.utils.simulation.FarFieldSimulator + :members: + :special-members: __init__ + + Simulating 3D data ------------------ diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index b4aa6b79..2f78f402 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -7,18 +7,14 @@ # ############################################################################# -import glob -import os -from lensless.utils.io import load_psf -from lensless.utils.image import resize -import numpy as np +from lensless.utils.dataset import DiffuserCamTestDataset from tqdm import tqdm from lensless.utils.io import load_image try: import torch - from torch.utils.data import Dataset, DataLoader + from torch.utils.data import DataLoader from torch.nn import MSELoss, L1Loss from torchmetrics import StructuralSimilarityIndexMeasure from torchmetrics.image import lpip, psnr @@ -28,207 +24,6 @@ ) -class ParallelDataset(Dataset): - """ - Dataset consisting of lensless and corresponding lensed image. - - It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images. - - """ - - def __init__( - self, - root_dir, - n_files=False, - background=None, - downsample=4, - flip=False, - transform_lensless=None, - transform_lensed=None, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - **kwargs, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - - root_dir : str - Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images. - n_files : int or None, optional - Metrics will be computed only on the first ``n_files`` images. If None, all images are used, by default False - background : :py:class:`~torch.Tensor` or None, optional - If not ``None``, background is removed from lensless images, by default ``None``. - downsample : int, optional - Downsample factor of the lensless images, by default 4. - flip : bool, optional - If ``True``, lensless images are flipped, by default ``False``. - transform_lensless : PyTorch Transform or None, optional - Transform to apply to the lensless images, by default None - transform_lensed : PyTorch Transform or None, optional - Transform to apply to the lensed images, by default None - lensless_fn : str, optional - Name of the folder containing the lensless images, by default "diffuser". - lensed_fn : str, optional - Name of the folder containing the lensed images, by default "lensed". - image_ext : str, optional - Extension of the images, by default "npy". 
- """ - - self.root_dir = root_dir - self.lensless_dir = os.path.join(root_dir, lensless_fn) - self.lensed_dir = os.path.join(root_dir, lensed_fn) - self.image_ext = image_ext.lower() - - files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) - if n_files: - files = files[:n_files] - self.files = [os.path.basename(fn) for fn in files] - - if len(self.files) == 0: - raise FileNotFoundError( - f"No files found in {self.lensless_dir} with extension {image_ext}" - ) - - self.background = background - self.downsample = downsample / 4 - self.flip = flip - self.transform_lensless = transform_lensless - self.transform_lensed = transform_lensed - - def __len__(self): - return len(self.files) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - - if self.image_ext == "npy": - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = np.load(lensless_fp) - lensed = np.load(lensed_fp) - else: - # more standard image formats: png, jpg, tiff, etc. - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = load_image(lensless_fp) - lensed = load_image(lensed_fp) - - # convert to float - if lensless.dtype == np.uint8: - lensless = lensless.astype(np.float32) / 255 - lensed = lensed.astype(np.float32) / 255 - else: - # 16 bit - lensless = lensless.astype(np.float32) / 65535 - lensed = lensed.astype(np.float32) / 65535 - - if self.downsample != 1.0: - lensless = resize(lensless, factor=1 / self.downsample) - lensed = resize(lensed, factor=1 / self.downsample) - - lensless = torch.from_numpy(lensless) - lensed = torch.from_numpy(lensed) - - # If [H, W, C] -> [D, H, W, C] - if len(lensless.shape) == 3: - lensless = lensless.unsqueeze(0) - if len(lensed.shape) == 3: - lensed = lensed.unsqueeze(0) - - if self.background is not None: - lensless = lensless - self.background - - # flip image x and y if needed - if self.flip: - lensless = torch.rot90(lensless, dims=(-3, -2)) - lensed = torch.rot90(lensed, dims=(-3, -2)) - if self.transform_lensless: - lensless = self.transform_lensless(lensless) - - if self.transform_lensed: - lensed = self.transform_lensed(lensed) - - return lensless, lensed - - -class DiffuserCamTestDataset(ParallelDataset): - """ - Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. - """ - - def __init__( - self, - data_dir="data", - n_files=200, - downsample=8, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - data_dir : str, optional - The path to the folder containing the DiffuserCam_Test dataset, by default "data" - n_files : int, optional - Number of image pair to load in the dataset , by default 200 - downsample : int, optional - Downsample factor of the lensless images, by default 8 - """ - # download dataset if necessary - main_dir = data_dir - data_dir = os.path.join(data_dir, "DiffuserCam_Test") - if not os.path.isdir(data_dir): - print("No dataset found for benchmarking.") - try: - from torchvision.datasets.utils import download_and_extract_archive - except ImportError: - exit() - msg = "Do you want to download the sample dataset (3.5GB)?" 
- - # default to yes if no input is given - valid = input("%s (Y/n) " % msg).lower() != "n" - if valid: - url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" - filename = "DiffuserCam_Test.zip" - download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) - - psf_fp = os.path.join(data_dir, "psf.tiff") - psf, background = load_psf( - psf_fp, - downsample=downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - - # transform from BGR to RGB - from torchvision import transforms - - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - self.psf = transform_BRG2RGB(torch.from_numpy(psf)) - - super().__init__( - data_dir, - n_files, - background, - downsample, - flip=False, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - ) - - def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): """ Compute multiple metrics for a reconstruction algorithm. diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py new file mode 100644 index 00000000..2634cb7c --- /dev/null +++ b/lensless/utils/dataset.py @@ -0,0 +1,448 @@ +# ############################################################################# +# dataset.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# ############################################################################# + +import numpy as np +import glob +import os +import torch +from abc import abstractmethod +from torch.utils.data import Dataset +from torchvision import transforms +from lensless.utils.simulation import FarFieldSimulator +from lensless.utils.io import load_image, load_psf +from lensless.utils.image import resize + + +class DualDataset(Dataset): + """ + Abstract class for defining a dataset of paired lensed and lensless images. + """ + + def __init__( + self, + indices=None, + background=None, + downsample=1, + flip=False, + transform_lensless=None, + transform_lensed=None, + **kwargs, + ): + """ + Dataset consisting of lensless and corresponding lensed image. + + Parameters + ---------- + indices : range or int or None + Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. + background : :py:class:`~torch.Tensor` or None, optional + If not ``None``, background is removed from lensless images, by default ``None``. + downsample : int, optional + Downsample factor of the lensless images, by default 1. + flip : bool, optional + If ``True``, lensless images are flipped, by default ``False``. + transform_lensless : PyTorch Transform or None, optional + Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + transform_lensed : PyTorch Transform or None, optional + Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + """ + if isinstance(indices, int): + indices = range(indices) + self.indices = indices + self.background = background + self.downsample = downsample + self.flip = flip + self.transform_lensless = transform_lensless + self.transform_lensed = transform_lensed + + @abstractmethod + def __len__(self): + """ + Abstract method to get the length of the dataset. It should take into account the indices parameter. 
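+
+        For a file-based subclass, an implementation might look as follows
+        (illustrative sketch only; ``self.files`` is a hypothetical attribute)::
+
+            def __len__(self):
+                if self.indices is None:
+                    return len(self.files)
+                return len([i for i in self.indices if i < len(self.files)])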
+ """ + raise NotImplementedError + + @abstractmethod + def _get_images_pair(self, idx): + """ + Abstract method to get the lensed and lensless images. Should return a pair (lensless, lensed) of numpy arrays with values in [0,1]. + + Parameters + ---------- + idx : int + images index + """ + raise NotImplementedError + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.item() + + if self.indices is not None: + idx = self.indices[idx] + lensless, lensed = self._get_images_pair(idx) + + if isinstance(lensless, np.ndarray): + # expected case + if self.downsample != 1.0: + lensless = resize(lensless, factor=1 / self.downsample) + lensed = resize(lensed, factor=1 / self.downsample) + + lensless = torch.from_numpy(lensless) + lensed = torch.from_numpy(lensed) + else: + # torch tensor + # This mean get_images_pair returned a torch tensor. This isn't recommended, if possible get_images_pair should return a numpy array + # In this case it should also have applied the downsampling + pass + + # If [H, W, C] -> [D, H, W, C] + if len(lensless.shape) == 3: + lensless = lensless.unsqueeze(0) + if len(lensed.shape) == 3: + lensed = lensed.unsqueeze(0) + + if self.background is not None: + lensless = lensless - self.background + + # flip image x and y if needed + if self.flip: + lensless = torch.rot90(lensless, dims=(-3, -2)) + lensed = torch.rot90(lensed, dims=(-3, -2)) + if self.transform_lensless: + lensless = self.transform_lensless(lensless) + if self.transform_lensed: + lensed = self.transform_lensed(lensed) + + return lensless, lensed + + +class SimulatedFarFieldDataset(DualDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset. :py:class:`lensless.utils.simulation.FarFieldSimulator` is used for simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + + """ + + def __init__( + self, + dataset, + simulator, + pre_transform=None, + dataset_is_CHW=False, + flip=False, + **kwargs, + ): + """ + Parameters + ---------- + + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``.Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + pre_transform : PyTorch Transform or None, optional + Transform to apply to the images before simulation, by default ``None``. Note that this transform is applied on HCW images (different from torchvision). + dataset_is_CHW : bool, optional + If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``. + flip : bool, optional + If True, images are flipped beffore the simulation, by default ``False``.. 
+ """ + + # we do the flipping before the simualtion + super(SimulatedFarFieldDataset, self).__init__(flip=False, **kwargs) + + assert isinstance(dataset, Dataset) + self.dataset = dataset + self.n_files = len(dataset) + + self.dataset_is_CHW = dataset_is_CHW + self._pre_transform = pre_transform + self.flip_pre_sim = flip + + # check simulator + assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator" + assert simulator.is_torch, "Simulator should be a pytorch simulator" + assert simulator.fft_shape is not None, "Simulator should have a psf" + self.sim = simulator + + def get_image(self, index): + return self.dataset[index] + + def _get_images_pair(self, index): + # load image + img, _ = self.get_image(index) + # convert to CHW for simulator and transform + if self.dataset_is_CHW: + img = img.moveaxis(-3, -1) + if self.flip_pre_sim: + img = torch.rot90(img, dims=(-3, -2)) + if self._pre_transform is not None: + img = self._pre_transform(img) + + lensless, lensed = self.sim.propagate(img, return_object_plane=True) + + return lensless, lensed + + def __len__(self): + if self.indices is None: + return self.n_files + else: + return len([x for x in self.indices if x < self.n_files]) + + +class MeasuredDatasetSimulatedOriginal(DualDataset): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on the screen. + Unlike :py:class:`lensless.utils.dataset.MeasuredDataset`, the ground-truth lensed image is simulated using a :py:class:`lensless.utils.simulation.FarFieldSimulator` + object rather than measured with a lensed camera. + """ + + def __init__( + self, + root_dir, + simulator, + lensless_fn="diffuser", + original_fn="lensed", + image_ext="npy", + original_ext=None, + downsample=1, + **kwargs, + ): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on screen. + + Parameters + ---------- + root_dir : str + Path to the test dataset. It is expected to contain two folders: one of lensless images and one of original images. + simulator : :py:class:`lensless.utils.simulatorFarFieldSimulator` + Simulator to use for the projection of the original image to object space. The PSF **should not** be specified, and it is expect to have ``is_torch = True``. + lensless_fn : str, optional + Name of the folder containing the lensless images, by default "diffuser". + lensed_fn : str, optional + Name of the folder containing the lensed images, by default "lensed". + image_ext : str, optional + Extension of the images, by default "npy". + original_ext : str, optional + Extension of the original image if different from lenless, by default None. + downsample : int, optional + Downsample factor of the lensless images, by default 1. + """ + super(MeasuredDatasetSimulatedOriginal, self).__init__(downsample=1, **kwargs) + self.pre_downsample = downsample + + self.root_dir = root_dir + self.lensless_dir = os.path.join(root_dir, lensless_fn) + self.original_dir = os.path.join(root_dir, original_fn) + self.image_ext = image_ext.lower() + self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower() + + files = glob.glob(os.path.join(self.lensless_dir, "*." 
+        files.sort()
+        self.files = [os.path.basename(fn) for fn in files]
+
+        if len(self.files) == 0:
+            raise FileNotFoundError(
+                f"No files found in {self.lensless_dir} with extension {image_ext}"
+            )
+
+        # check simulator
+        assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator"
+        assert simulator.is_torch, "Simulator should be a pytorch simulator"
+        assert simulator.fft_shape is None, "Simulator should not have a psf"
+        self.sim = simulator
+
+    def __len__(self):
+        if self.indices is None:
+            return len(self.files)
+        else:
+            return len([i for i in self.indices if i < len(self.files)])
+
+    def _get_images_pair(self, idx):
+        if self.image_ext == "npy" or self.image_ext == "npz":
+            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+            original_fp = os.path.join(self.original_dir, self.files[idx])
+            lensless = np.load(lensless_fp)
+            lensless = resize(lensless, factor=1 / self.downsample)
+            original = np.load(original_fp[:-3] + self.original_ext)
+        else:
+            # more standard image formats: png, jpg, tiff, etc.
+            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+            original_fp = os.path.join(self.original_dir, self.files[idx])
+            lensless = load_image(lensless_fp, downsample=self.pre_downsample)
+            original = load_image(
+                original_fp[:-3] + self.original_ext, downsample=self.pre_downsample
+            )
+
+        # convert to float
+        if lensless.dtype == np.uint8:
+            lensless = lensless.astype(np.float32) / 255
+            original = original.astype(np.float32) / 255
+        else:
+            # 16 bit
+            lensless = lensless.astype(np.float32) / 65535
+            original = original.astype(np.float32) / 65535
+
+        # convert to torch
+        lensless = torch.from_numpy(lensless)
+        original = torch.from_numpy(original)
+
+        # project original image to object space (the simulator has no PSF, so this
+        # returns the image at the object plane)
+        with torch.no_grad():
+            lensed = self.sim.propagate(original)
+
+        return lensless, lensed
+
+
+class MeasuredDataset(DualDataset):
+    """
+    Dataset consisting of lensless and corresponding lensed image.
+    It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images.
+    Unless the setup is perfectly calibrated, one should expect to have to use ``transform_lensed`` to adjust the alignment and rotation.
+    """
+
+    def __init__(
+        self,
+        root_dir,
+        lensless_fn="diffuser",
+        lensed_fn="lensed",
+        image_ext="npy",
+        **kwargs,
+    ):
+        """
+        Dataset consisting of lensless and corresponding lensed image. Default parameters are for the
+        `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_.
+
+        Parameters
+        ----------
+        root_dir : str
+            Path to the test dataset. It is expected to contain two folders: one of lensless images and one of lensed images.
+        lensless_fn : str, optional
+            Name of the folder containing the lensless images, by default "diffuser".
+        lensed_fn : str, optional
+            Name of the folder containing the lensed images, by default "lensed".
+        image_ext : str, optional
+            Extension of the images, by default "npy".
+        """
+
+        super(MeasuredDataset, self).__init__(**kwargs)
+
+        self.root_dir = root_dir
+        self.lensless_dir = os.path.join(root_dir, lensless_fn)
+        self.lensed_dir = os.path.join(root_dir, lensed_fn)
+        self.image_ext = image_ext.lower()
+
+        files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext))
+        files.sort()
+        self.files = [os.path.basename(fn) for fn in files]
+
+        if len(self.files) == 0:
+            raise FileNotFoundError(
+                f"No files found in {self.lensless_dir} with extension {image_ext}"
+            )
+
+    def __len__(self):
+        if self.indices is None:
+            return len(self.files)
+        else:
+            return len([i for i in self.indices if i < len(self.files)])
+
+    def _get_images_pair(self, idx):
+        if self.image_ext == "npy" or self.image_ext == "npz":
+            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+            lensed_fp = os.path.join(self.lensed_dir, self.files[idx])
+            lensless = np.load(lensless_fp)
+            lensed = np.load(lensed_fp)
+        else:
+            # more standard image formats: png, jpg, tiff, etc.
+            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
+            lensed_fp = os.path.join(self.lensed_dir, self.files[idx])
+            lensless = load_image(lensless_fp)
+            lensed = load_image(lensed_fp)
+
+        # convert to float
+        if lensless.dtype == np.uint8:
+            lensless = lensless.astype(np.float32) / 255
+            lensed = lensed.astype(np.float32) / 255
+        else:
+            # 16 bit
+            lensless = lensless.astype(np.float32) / 65535
+            lensed = lensed.astype(np.float32) / 65535
+
+        return lensless, lensed
+
+
+class DiffuserCamTestDataset(MeasuredDataset):
+    """
+    Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking.
+    """
+
+    def __init__(
+        self,
+        data_dir="data",
+        n_files=200,
+        downsample=2,
+    ):
+        """
+        Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of
+        `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_.
+
+        Parameters
+        ----------
+        data_dir : str, optional
+            The path to the folder containing the DiffuserCam_Test dataset, by default "data".
+        n_files : int, optional
+            Number of image pairs to load in the dataset, by default 200.
+        downsample : int, optional
+            Downsample factor of the lensless images, by default 2.
+        """
+
+        # download dataset if necessary
+        main_dir = data_dir
+        data_dir = os.path.join(data_dir, "DiffuserCam_Test")
+        if not os.path.isdir(data_dir):
+            print("No dataset found for benchmarking.")
+            try:
+                from torchvision.datasets.utils import download_and_extract_archive
+            except ImportError:
+                exit()
+            msg = "Do you want to download the sample dataset (3.5GB)?"
+ + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" + filename = "DiffuserCam_Test.zip" + download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) + + psf_fp = os.path.join(data_dir, "psf.tiff") + psf, background = load_psf( + psf_fp, + downsample=downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + + # transform from BGR to RGB + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + + super().__init__( + root_dir=data_dir, + indices=range(n_files), + background=background, + downsample=downsample / 4, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser", + lensed_fn="lensed", + image_ext="npy", + ) diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py new file mode 100644 index 00000000..36aac243 --- /dev/null +++ b/lensless/utils/simulation.py @@ -0,0 +1,100 @@ +# ############################################################################# +# simulation.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# ############################################################################# + +import numpy as np +from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp + + +class FarFieldSimulator(FarFieldSimulator_wp): + """ + LenslessPiCam-compatible wrapper for :py:class:`~waveprop.simulation.FarFieldSimulator` (source code on `GitHub `__). + """ + + def __init__( + self, + object_height, + scene2mask, + mask2sensor, + sensor, + psf=None, + output_dim=None, + snr_db=None, + max_val=255, + device_conv="cpu", + random_shift=False, + is_torch=False, + **kwargs + ): + """ + Parameters + ---------- + psf : np.ndarray, optional. + Point spread function. If not provided, return image at object plane. + object_height : float or (float, float) + Height of object in meters. Or range of values to randomly sample from. + scene2mask : float + Distance from scene to mask in meters. + mask2sensor : float + Distance from mask to sensor in meters. + sensor : str + Sensor name. + snr_db : float, optional + Signal-to-noise ratio in dB, by default None. + max_val : int, optional + Maximum value of image, by default 255. + device_conv : str, optional + Device to use for convolution (when using pytorch), by default "cpu". + random_shift : bool, optional + Whether to randomly shift the image, by default False. + is_torch : bool, optional + Whether to use pytorch, by default False. + """ + + if psf is not None: + # convert HWC to CHW + psf = psf.squeeze().movedim(-1, 0) + + super().__init__( + object_height, + scene2mask, + mask2sensor, + sensor, + psf, + output_dim, + snr_db, + max_val, + device_conv, + random_shift, + is_torch, + **kwargs + ) + + def propagate(self, obj, return_object_plane=False): + """ + Parameters + ---------- + obj : np.ndarray or torch.Tensor + Single image to propagate at format HWC. + return_object_plane : bool, optional + Whether to return object plane, by default False. 
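+
+        A minimal usage sketch (illustrative only; ``obj`` is assumed to be an HWC
+        image, as a torch tensor when the simulator was created with ``is_torch=True``)::
+
+            lensless = simulator.propagate(obj)
+            lensless, object_plane = simulator.propagate(obj, return_object_plane=True)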
+ """ + if self.is_torch: + obj = obj.moveaxis(-1, 0) + res = super().propagate(obj, return_object_plane) + if isinstance(res, tuple): + res = res[0].moveaxis(-3, -1), res[1].moveaxis(-3, -1) + else: + res = res.moveaxis(-3, -1) + return res + else: + obj = np.moveaxis(obj, -1, 0) + res = super().propagate(obj, return_object_plane) + if isinstance(res, tuple): + res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1) + else: + res = np.moveaxis(res, -3, -1) + return res diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index de6a1c68..6611ceec 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -20,9 +20,10 @@ import json import os import pathlib as plib -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless.utils.dataset import DiffuserCamTestDataset try: import torch diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 883f1819..a608ce97 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -22,9 +22,10 @@ import time import matplotlib.pyplot as plt from lensless import UnrolledFISTA, UnrolledADMM -from waveprop.dataset_util import SimulatedPytorchDataset +from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset from lensless.utils.image import rgb2gray -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.utils.simulation import FarFieldSimulator +from lensless.eval.benchmark import benchmark import torch from torchvision import transforms, datasets from tqdm import tqdm @@ -58,15 +59,11 @@ def simulate_dataset(config, psf): psf = rgb2gray(psf) if not isinstance(psf, torch.Tensor): psf = transforms.ToTensor()(psf) - elif psf.shape[-1] == 3: - # Waveprop syntetic dataset expect C H W - psf = psf.permute(2, 0, 1) # batch_size = config.files.batch_size batch_size = config.training.batch_size n_files = config.files.n_files device_conv = config.torch_device - target = config.target # check if gpu is available if device_conv == "cuda" and torch.cuda.is_available(): @@ -74,11 +71,17 @@ def simulate_dataset(config, psf): else: device_conv = "cpu" + # create simulator + simulator = FarFieldSimulator( + psf=psf, + is_torch=True, + **config.simulation, + ) # create Pytorch dataset and dataloader if n_files is not None: ds = torch.utils.data.Subset(ds, np.arange(n_files)) - ds_prop = SimulatedPytorchDataset( - dataset=ds, psf=psf, device_conv=device_conv, target=target, **config.simulation + ds_prop = SimulatedFarFieldDataset( + dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv ) ds_loader = torch.utils.data.DataLoader( dataset=ds_prop, batch_size=batch_size, shuffle=True, pin_memory=(psf.device != "cpu") @@ -138,9 +141,6 @@ def train_unrolled( # torch.autograd.set_detect_anomaly(True) - # if using a portrait dataset rotate the PSF - flip = config.files.dataset in ["CelebA"] - # benchmarking dataset: path = os.path.join(get_original_cwd(), "data") benchmark_dataset = DiffuserCamTestDataset( @@ -155,8 +155,6 @@ def train_unrolled( psf = psf[..., [2, 1, 0]] # if using a portrait dataset rotate the PSF - if flip: - psf = torch.rot90(psf, dims=[0, 1]) disp = config.display.disp if disp < 0: @@ -222,17 +220,21 @@ def train_unrolled( # load dataset and create dataloader if config.files.dataset == 
"DiffuserCam": # Use a ParallelDataset - from lensless.eval.benchmark import ParallelDataset + from lensless.utils.dataset import MeasuredDataset + + max_indices = 30000 + if config.files.n_files is not None: + max_indices = config.files.n_files + 1000 data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") - dataset = ParallelDataset( + dataset = MeasuredDataset( root_dir=data_path, - n_files=config.files.n_files, + indices=range(1000, max_indices), background=background, psf=psf, lensless_fn="diffuser_images", lensed_fn="ground_truth_lensed", - downsample=config.simulation.downsample, + downsample=config.simulation.downsample / 4, transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, ) From 58f747adeba5d0007f6c4484acbc60342387cede Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Wed, 30 Aug 2023 01:04:18 +0200 Subject: [PATCH 03/11] Streamlined training with new Trainer class (#77) * move utility function outside of script * New trainer class for training reconstruction * Update docstring * Update changelog * Update to trainer save * Fix partial mask support bug * Fix docstrings. * Fix APGD rendering. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 1 + docs/requirements.txt | 1 + docs/source/conf.py | 4 + docs/source/reconstruction.rst | 22 +- lensless/eval/benchmark.py | 2 - lensless/recon/utils.py | 410 +++++++++++++++++++++++++++++++- recon_requirements.txt | 1 - scripts/recon/admm.py | 6 +- scripts/recon/train_unrolled.py | 238 ++---------------- setup.py | 1 + 10 files changed, 453 insertions(+), 233 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 99be0bb1..6ca03bd3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,7 @@ Added - New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA). +- New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. Changed diff --git a/docs/requirements.txt b/docs/requirements.txt index e105c9f7..484c5d20 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,4 +5,5 @@ numpy>=1.22 # so that default dtype are correctly rendered torch>=1.10 torchvision>=0.15.2 torchmetrics>=0.11.4 +pyFFS>=2.2.3 # for waveprop waveprop>=0.0.5 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 02d3e0b0..60ee9e96 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,9 @@ "torchmetrics.image", "scipy.ndimage", "pycsou.abc", + "pycsou.operator", "pycsou.operator.func", + "pycsou.operator.linop", "pycsou.opt.solver", "pycsou.opt.stop", "pycsou.runtime", @@ -33,6 +35,8 @@ "paramiko", "paramiko.ssh_exception", "perlin_numpy", + "hydra", + "hydra.utils", "scipy.special", "matplotlib.cm", "pyffs", diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index 27434c40..e5b927f4 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -55,7 +55,7 @@ Accelerated Proximal Gradient Descent (APGD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - .. autoclass:: lensless.APGD + .. autoclass:: lensless.recon.apgd.APGD :special-members: __init__ @@ -88,4 +88,22 @@ .. 
autoclass:: lensless.UnrolledADMM :members: batch_call :special-members: __init__ - :show-inheritance: \ No newline at end of file + :show-inheritance: + + + Reconstruction Utilities + ------------------------ + + .. autoclass:: lensless.recon.utils.Trainer + :members: + :special-members: __init__ + + .. autofunction:: lensless.recon.utils.load_drunet + + .. autofunction:: lensless.recon.utils.apply_denoiser + + .. autofunction:: lensless.recon.utils.get_drunet_function + + .. autofunction:: lensless.recon.utils.measure_gradient + + .. autofunction:: lensless.recon.utils.create_process_network diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 2f78f402..f93b754d 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -10,8 +10,6 @@ from lensless.utils.dataset import DiffuserCamTestDataset from tqdm import tqdm -from lensless.utils.io import load_image - try: import torch from torch.utils.data import DataLoader diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 7fad0400..54d23a1d 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -1,4 +1,21 @@ +# ############################################################################# +# dataset.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import json +import math +import time +from hydra.utils import get_original_cwd +import os +import matplotlib.pyplot as plt import torch +from lensless.eval.benchmark import benchmark +from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes @@ -17,7 +34,7 @@ def load_drunet(model_path, n_channels=3, requires_grad=False): Returns ------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Loaded model. """ @@ -45,11 +62,11 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Parameters ---------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Drunet compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image both in CHW format. - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Input image. - noise_level : float or :py:class:`~torch.Tensor` + noise_level : float or :py:class:`torch.Tensor` Noise level in the image. device : str Device to use for computation. Can be "cpu" or "cuda". @@ -58,7 +75,7 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Returns ------- - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Reconstructed image. """ # convert from NDHWC to NCHW @@ -108,7 +125,7 @@ def get_drunet_function(model, device="cpu", mode="inference"): Parameters ---------- - model : torch.nn.Module + model : :py:class:`torch.nn.Module` DruNet like denoiser model device : str Device to use for computation. Can be "cpu" or "cuda". @@ -129,3 +146,384 @@ def process(image, noise_level): return image return process + + +def measure_gradient(model): + """ + Helper function to measure L2 norm of the gradient of a model. + + Parameters + ---------- + model : :py:class:`torch.nn.Module` + Model to measure gradient of. + + Returns + ------- + Float + L2 norm of the gradient of the model. 
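+
+    A minimal usage sketch; gradients are assumed to have been populated by a
+    backward pass (``criterion``, ``x`` and ``y`` are placeholders)::
+
+        loss = criterion(model(x), y)
+        loss.backward()
+        grad_norm = measure_gradient(model)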
+ """ + total_norm = 0.0 + for p in model.parameters(): + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm**0.5 + return total_norm + + +def create_process_network(network, depth, device="cpu"): + """ + Helper function to create a process network. + + Parameters + ---------- + network : str + Name of network to use. Can be "DruNet" or "UnetRes". + depth : int + Depth of network. + device : str + Device to use for computation. Can be "cpu" or "cuda". Defaults to "cpu". + + Returns + ------- + :py:class:`torch.nn.Module` + New process network. Already trained for Drunet. + """ + if network == "DruNet": + from lensless.recon.utils import load_drunet + + process = load_drunet( + os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True + ).to(device) + process_name = "DruNet" + elif network == "UnetRes": + from lensless.recon.drunet.network_unet import UNetRes + + n_channels = 3 + process = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=depth, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ).to(device) + process_name = "UnetRes_d" + str(depth) + else: + process = None + process_name = None + + return (process, process_name) + + +class Trainer: + def __init__( + self, + recon, + train_dataset, + test_dataset, + batch_size=4, + loss="l2", + lpips=None, + optimizer="Adam", + optimizer_lr=1e-6, + slow_start=None, + skip_NAN=False, + algorithm_name="Unknown", + ): + """ + Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__. + + Parameters + ---------- + recon : :py:class:`lensless.TrainableReconstructionAlgorithm` + Reconstruction algorithm to train. + train_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for training. + test_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for testing. + batch_size : int, optional + Batch size to use for training, by default 4 + loss : str, optional + Loss function to use for training "l1" or "l2", by default "l2" + lpips : float, optional + the weight of the lpips(VGG) in the total loss. If None ignore. By default None + optimizer : str, optional + Optimizer to use durring training. Available : "Adam". By default "Adam" + optimizer_lr : float, optional + Learning rate for the optimizer, by default 1e-6 + slow_start : float, optional + Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None. + skip_NAN : bool, optional + Whether to skip update if any gradiant are NAN (True) or to throw an error(False), by default False + algorithm_name : str, optional + Algorithm name for logging, by default "Unknown". + + """ + self.device = recon._psf.device + + self.recon = recon + self.train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=(self.device != "cpu"), + ) + self.test_dataset = test_dataset + self.lpips = lpips + self.skip_NAN = skip_NAN + + # loss + if loss == "l2": + self.Loss = torch.nn.MSELoss() + elif loss == "l1": + self.Loss = torch.nn.L1Loss() + else: + raise ValueError(f"Unsuported loss : {loss}") + + # Lpips loss + if lpips: + try: + import lpips + + self.Loss_lpips = lpips.LPIPS(net="vgg").to(self.device) + except ImportError: + return ImportError( + "lpips package is need for LPIPS loss. 
Install using : pip install lpips" + ) + + # optimizer + if optimizer == "Adam": + # the parameters of the base model and non torch.Module process must be added separatly + parameters = [{"params": recon.parameters()}] + self.optimizer = torch.optim.Adam(parameters, lr=optimizer_lr) + else: + raise ValueError(f"Unsuported optimizer : {optimizer}") + # Scheduler + if slow_start: + + def learning_rate_function(epoch): + if epoch == 0: + return slow_start + elif epoch == 1: + return math.sqrt(slow_start) + else: + return 1 + + else: + + def learning_rate_function(epoch): + return 1 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=learning_rate_function + ) + + self.metrics = { + "LOSS": [], + "MSE": [], + "MAE": [], + "LPIPS_Vgg": [], + "LPIPS_Alex": [], + "PSNR": [], + "SSIM": [], + "ReconstructionError": [], + "n_iter": self.recon._n_iter, + "algorithm": algorithm_name, + } + + # Backward hook that detect NAN in the gradient and print the layer weights + if not self.skip_NAN: + + def detect_nan(grad): + if torch.isnan(grad).any(): + print(grad, flush=True) + for name, param in recon.named_parameters(): + if param.requires_grad: + print(name, param) + raise ValueError("Gradient is NaN") + return grad + + for param in recon.parameters(): + if param.requires_grad: + param.register_hook(detect_nan) + if param.requires_grad: + param.register_hook(detect_nan) + + def train_epoch(self, data_loader, disp=-1): + """ + Train for one epoch. + + Parameters + ---------- + data_loader : :py:class:`torch.utils.data.DataLoader` + Data loader to use for training. + disp : int, optional + Display interval, if -1, no display, by default -1 + + Returns + ------- + float + Mean loss of the epoch. + """ + mean_loss = 0.0 + i = 1.0 + pbar = tqdm(data_loader) + for X, y in pbar: + # send to device + X = X.to(self.device) + y = y.to(self.device) + + y_pred = self.recon.batch_call(X.to(self.device)) + # normalizing each output + eps = 1e-12 + y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps + y_pred = y_pred / y_pred_max + + # normalizing y + y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps + y = y / y_max + + if i % disp == 1: + img_pred = y_pred[0, 0].cpu().detach().numpy() + img_truth = y[0, 0].cpu().detach().numpy() + + plt.imshow(img_pred) + plt.savefig(f"y_pred_{i-1}.png") + plt.imshow(img_truth) + plt.savefig(f"y_{i-1}.png") + + self.optimizer.zero_grad(set_to_none=True) + # convert to CHW for loss and remove depth + y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) + y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) + + loss_v = self.Loss(y_pred, y) + if self.lpips: + # value for LPIPS needs to be in range [-1, 1] + loss_v = loss_v + self.lpips * torch.mean( + self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) + ) + loss_v.backward() + + torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0) + + # if any gradient is NaN, skip training step + if self.skip_NAN: + is_NAN = False + for param in self.recon.parameters(): + if torch.isnan(param.grad).any(): + is_NAN = True + break + if is_NAN: + print("NAN detected in gradiant, skipping training step") + i += 1 + continue + self.optimizer.step() + + mean_loss += (loss_v.item() - mean_loss) * (1 / i) + pbar.set_description(f"loss : {mean_loss}") + i += 1 + + return mean_loss + + def evaluate(self, mean_loss, save_pt): + """ + Evaluate the reconstruction algorithm on the test dataset. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. 
+ save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + """ + if self.test_dataset is None: + return + # benchmarking + current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10) + + # update metrics with current metrics + self.metrics["LOSS"].append(mean_loss) + for key in current_metrics: + self.metrics[key].append(current_metrics[key]) + + if save_pt: + # save dictionary metrics to file with json + with open(os.path.join(save_pt, "metrics.json"), "w") as f: + json.dump(self.metrics, f) + + def on_epoch_end(self, mean_loss, save_pt): + """ + Called at the end of each epoch. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. + save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + """ + if save_pt is None: + # Use current directory + save_pt = os.getcwd() + + # save model + self.save(path=save_pt, include_optimizer=False) + self.evaluate(mean_loss, save_pt) + + def train(self, n_epoch=1, save_pt=None, disp=-1): + """ + Train the reconstruction algorithm. + + Parameters + ---------- + n_epoch : int, optional + Number of epochs to train for, by default 1 + save_pt : str, optional + Path to save metrics dictionary to. If None, use current directory, by default None + disp : int, optional + Display interval, if -1, no display. Default is -1. + """ + + start_time = time.time() + + for epoch in range(n_epoch): + print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + mean_loss = self.train_epoch(self.train_dataloader, disp=disp) + self.on_epoch_end(mean_loss, save_pt) + self.scheduler.step() + + print(f"Train time : {time.time() - start_time} s") + + def save(self, path="recon", include_optimizer=False): + """ + Save state of reconstruction algorithm. + + Parameters + ---------- + path : str, optional + Path to save model to, by default "recon" + include_optimizer : bool, optional + Whether to include optimizer state, by default False + + """ + # create directory if it does not exist + if not os.path.exists(path): + os.makedirs(path) + + # TODO : ADD mask support + # # save mask + # if self.use_mask: + # torch.save(self.mask._mask, os.path.join(path, "mask.pt")) + # torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) + # import matplotlib.pyplot as plt + + # plt.imsave( + # os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] 
+ # ) + # save optimizer + if include_optimizer: + torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) + # save recon + torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt")) diff --git a/recon_requirements.txt b/recon_requirements.txt index 4ebe4412..b9e9f324 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -2,7 +2,6 @@ jedi==0.18.0 lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 -hydra-core click>=8.0.1 waveprop>=0.0.3 # for simulation diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 3ba3de1f..2a053722 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -58,14 +58,14 @@ def admm(config): else: assert config.torch, "Unrolled ADMM only works with torch" from lensless.recon.unrolled_admm import UnrolledADMM - import train_unrolled + import lensless.recon.utils - pre_process = train_unrolled.create_process_network( + pre_process = lensless.recon.utils.create_process_network( network=config.admm.pre_process_model.network, depth=config.admm.pre_process_depth.depth, device=config.torch_device, ) - post_process = train_unrolled.create_process_network( + post_process = lensless.recon.utils.create_process_network( network=config.admm.post_process_model.network, depth=config.admm.post_process_depth.depth, device=config.torch_device, diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index a608ce97..7d0a31e1 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -14,26 +14,19 @@ """ -import math import hydra from hydra.utils import get_original_cwd import os import numpy as np import time -import matplotlib.pyplot as plt from lensless import UnrolledFISTA, UnrolledADMM from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset +from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray from lensless.utils.simulation import FarFieldSimulator -from lensless.eval.benchmark import benchmark +from lensless.recon.utils import Trainer import torch from torchvision import transforms, datasets -from tqdm import tqdm - -try: - import json -except ImportError: - print("json package not found, metrics will not be saved") def simulate_dataset(config, psf): @@ -60,8 +53,6 @@ def simulate_dataset(config, psf): if not isinstance(psf, torch.Tensor): psf = transforms.ToTensor()(psf) - # batch_size = config.files.batch_size - batch_size = config.training.batch_size n_files = config.files.n_files device_conv = config.torch_device @@ -83,49 +74,7 @@ def simulate_dataset(config, psf): ds_prop = SimulatedFarFieldDataset( dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv ) - ds_loader = torch.utils.data.DataLoader( - dataset=ds_prop, batch_size=batch_size, shuffle=True, pin_memory=(psf.device != "cpu") - ) - return ds_loader - - -def create_process_network(network, depth, device="cpu"): - if network == "DruNet": - from lensless.recon.utils import load_drunet - - process = load_drunet( - os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True - ).to(device) - process_name = "DruNet" - elif network == "UnetRes": - from lensless.recon.drunet.network_unet import UNetRes - - n_channels = 3 - process = UNetRes( - in_nc=n_channels + 1, - out_nc=n_channels, - nc=[64, 128, 256, 512], - nb=depth, - act_mode="R", - downsample_mode="strideconv", - upsample_mode="convtranspose", - ).to(device) - process_name = "UnetRes_d" + str(depth) - else: - process = None - process_name = None 
- - return (process, process_name) - - -def measure_gradient(model): - # return the L2 norm of the gradient - total_norm = 0.0 - for p in model.parameters(): - param_norm = p.grad.detach().data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm**0.5 - return total_norm + return ds_prop @hydra.main(version_base=None, config_path="../../configs", config_name="unrolled_recon") @@ -189,7 +138,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_fista.n_iter elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( psf, @@ -201,7 +149,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_admm.n_iter else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") @@ -238,175 +185,28 @@ def train_unrolled( transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config.training.batch_size, - shuffle=True, - pin_memory=(device != "cpu"), - ) else: # Use a simulated dataset - data_loader = simulate_dataset(config, psf) + dataset = simulate_dataset(config, psf) print(f"Setup time : {time.time() - start_time} s") - start_time = time.time() - - # loss - if config.loss == "l2": - Loss = torch.nn.MSELoss() - elif config.loss == "l1": - Loss = torch.nn.L1Loss() - else: - raise ValueError(f"Unsuported loss : {config.loss}") - - # Lpips loss - if config.lpips: - try: - import lpips - - loss_lpips = lpips.LPIPS(net="vgg").to(device) - except ImportError: - return ImportError( - "lpips package is need for LPIPS loss. Install using : pip install lpips" - ) - - # optimizer - if config.optimizer.type == "Adam": - # the parameters of the base model and non torch.Module process must be added separatly - parameters = [{"params": recon.parameters()}] - optimizer = torch.optim.Adam(parameters, lr=config.optimizer.lr) - else: - raise ValueError(f"Unsuported optimizer : {config.optimizer.type}") - # Scheduler - if config.training.slow_start: - - def learning_rate_function(epoch): - if epoch == 0: - return config.training.slow_start - elif epoch == 1: - return math.sqrt(config.training.slow_start) - else: - return 1 - - else: - - def learning_rate_function(epoch): - return 1 - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=learning_rate_function) - - metrics = { - "LOSS": [], - "MSE": [], - "MAE": [], - "LPIPS_Vgg": [], - "LPIPS_Alex": [], - "PSNR": [], - "SSIM": [], - "ReconstructionError": [], - "n_iter": n_iter, - "algorithm": algorithm_name, - } - - # Backward hook that detect NAN in the gradient and print the layer weights - if not config.training.skip_NAN: - - def detect_nan(grad): - if torch.isnan(grad).any(): - print(grad, flush=True) - for name, param in recon.named_parameters(): - if param.requires_grad: - print(name, param) - raise ValueError("Gradient is NaN") - return grad - - for param in recon.parameters(): - if param.requires_grad: - param.register_hook(detect_nan) - if param.requires_grad: - param.register_hook(detect_nan) - - # Training loop - for epoch in range(config.training.epoch): - print(f"Epoch {epoch} with learning rate {scheduler.get_last_lr()}") - mean_loss = 0.0 - i = 1.0 - pbar = tqdm(data_loader) - for X, y in pbar: - # send to device - X = X.to(device) - y = y.to(device) - if X.shape[3] == 3: - X = X - y = y - - y_pred = recon.batch_call(X.to(device)) - # 
normalizing each output - eps = 1e-12 - y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps - y_pred = y_pred / y_pred_max - - # normalizing y - y = y.to(device) - y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps - y = y / y_max - - if i % disp == 1 and config.display.plot: - img_pred = y_pred[0, 0].cpu().detach().numpy() - img_truth = y[0, 0].cpu().detach().numpy() - - plt.imshow(img_pred) - plt.savefig(f"y_pred_{i-1}.png") - plt.imshow(img_truth) - plt.savefig(f"y_{i-1}.png") - - optimizer.zero_grad(set_to_none=True) - # convert to CHW for loss and remove depth - y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) - y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) - - loss_v = Loss(y_pred, y) - if config.lpips: - # value for LPIPS needs to be in range [-1, 1] - loss_v = loss_v + config.lpips * torch.mean(loss_lpips(2 * y_pred - 1, 2 * y - 1)) - loss_v.backward() - torch.nn.utils.clip_grad_norm_(recon.parameters(), 1.0) - - # if any gradient is NaN, skip training step - is_NAN = False - for param in recon.parameters(): - if torch.isnan(param.grad).any(): - is_NAN = True - break - if is_NAN: - print("NAN detected in gradiant, skipping training step") - i += 1 - continue - optimizer.step() - - mean_loss += (loss_v.item() - mean_loss) * (1 / i) - pbar.set_description(f"loss : {mean_loss}") - i += 1 - - # benchmarking - current_metrics = benchmark(recon, benchmark_dataset, batchsize=10) - # update metrics with current metrics - metrics["LOSS"].append(mean_loss) - for key in current_metrics: - metrics[key].append(current_metrics[key]) - - # Update learning rate - scheduler.step() - - print(f"Train time : {time.time() - start_time} s") - - # save dictionary metrics to file with json - with open(os.path.join(save, "metrics.json"), "w") as f: - json.dump(metrics, f) + trainer = Trainer( + recon, + dataset, + benchmark_dataset, + batch_size=config.training.batch_size, + loss=config.loss, + lpips=config.lpips, + optimizer=config.optimizer.type, + optimizer_lr=config.optimizer.lr, + slow_start=config.training.slow_start, + skip_NAN=config.training.skip_NAN, + algorithm_name=algorithm_name, + ) - # save pytorch model recon - torch.save(recon.state_dict(), "recon.pt") + trainer.train(n_epoch=config.training.epoch, save_pt=save) + trainer.save(path=os.path.join(save, "recon.pt")) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 20c07f7a..79468810 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "matplotlib>=3.4.2", "rawpy>=0.16.0", "paramiko>=3.2.0", + "hydra-core", ], extra_requires={"dev": ["pudb", "black"]}, ) From ff86fb27fb86078ee0d279d12e0e95f355e90ec8 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 30 Aug 2023 10:31:37 -0700 Subject: [PATCH 04/11] Torch support for Coded Aperture reconstruction (#79) * added pytorch version of tikhonov reconstruction * replaced type(...) == ... by isintance (..., ...) * Fix torch support for tikhonov. * Change docstring. * Change docstring. * Update changelog. 
* Added "try" before SSIM computation * removed 'try' for SSIM --------- Co-authored-by: Aaron Fargeon --- CHANGELOG.rst | 2 +- configs/mask_sim_single.yaml | 1 + lensless/hardware/mask.py | 20 ++++- lensless/recon/tikhonov.py | 132 +++++++++++++++++++++++--------- scripts/sim/mask_single_file.py | 31 +++++++- 5 files changed, 145 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6ca03bd3..b1657928 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,7 +26,7 @@ Added - New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedFarFieldDataset``. - New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. -- Tikhonov reconstruction for coded aperture measurements (MLS / MURA). +- Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. diff --git a/configs/mask_sim_single.yaml b/configs/mask_sim_single.yaml index f793d302..0d20efa5 100644 --- a/configs/mask_sim_single.yaml +++ b/configs/mask_sim_single.yaml @@ -8,6 +8,7 @@ files: #original: data/original/mnist_3.png save: True +use_torch: False simulation: object_height: 0.3 diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index 126d21f1..f9597bf5 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -33,6 +33,13 @@ from lensless.hardware.sensor import VirtualSensor from lensless.utils.image import resize +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + class Mask(abc.ABC): """ @@ -295,12 +302,23 @@ def simulate(self, obj, snr_db=20): # Convolve image n_channels = obj.shape[-1] - meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) + + if torch_available and isinstance(obj, torch.Tensor): + P = torch.from_numpy(P).float() + Q = torch.from_numpy(Q).float() + meas = torch.dstack( + [torch.linalg.multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)] + ).float() + else: + meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) # Add noise if snr_db is not None: meas = add_shot_noise(meas, snr_db=snr_db) + if torch_available and isinstance(obj, torch.Tensor): + meas = meas.to(obj) + return meas diff --git a/lensless/recon/tikhonov.py b/lensless/recon/tikhonov.py index 84a88011..fb9a182d 100644 --- a/lensless/recon/tikhonov.py +++ b/lensless/recon/tikhonov.py @@ -2,8 +2,8 @@ # tikhonov.py # ================= # Authors : -# Aaron FARGEON [aa.fargeon@gmail.com] # Eric BEZZAM [ebezzam@gmail.com] +# Aaron FARGEON [aa.fargeon@gmail.com] # ############################################################################# """ @@ -20,6 +20,13 @@ import numpy as np from numpy.linalg import multi_dot +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + class CodedApertureReconstruction: """ @@ -32,7 +39,7 @@ def __init__(self, mask, image_shape, P=None, Q=None, lmbd=3e-4): """ Parameters ---------- - mask : py:class:`~lensless.hardware.mask.CodedAperture` + mask : py:class:`lensless.hardware.mask.CodedAperture` Coded aperture mask object. image_shape : (`array-like` or `tuple`) The shape of the image to reconstruct. 
@@ -67,46 +74,97 @@ def apply(self, img): Parameters ---------- - img : :py:class:`~numpy.ndarray` + img : :py:class:`~numpy.ndarray` or :py:class:`torch.Tensor` Lensless capture measurement. Must be 3D even if grayscale. Returns ------- - :py:class:`~numpy.ndarray` + :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Reconstructed image, in the same format as the measurement. """ - assert len(img.shape) == 3, "Object should be a 3D array (HxWxC) even if grayscale." - - # Empty matrix for reconstruction - n_channels = img.shape[-1] - x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) - - # Applying reconstruction for each channel - for c in range(n_channels): - - # SVD of left matrix - UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) - VL = VLh.T - DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) - singLsq = np.square(SL) - - # SVD of right matrix - UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) - VR = VRh.T - DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) - singRsq = np.square(SR) - - # Applying analytical reconstruction - Yc = img[:, :, c] - inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( - np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) - ) - x_est[:, :, c] = multi_dot([VL, inner, VR.T]) - - # Non-negativity constraint: setting all negative values to 0 - x_est = x_est.clip(min=0) - - # Normalizing the image - x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + assert ( + len(img.shape) == 3 + ), "Object should be a 3D array or tensor (HxWxC) even if grayscale." + + if torch_available and isinstance(img, torch.Tensor): + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = torch.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + self.P = torch.from_numpy(self.P).float() + self.Q = torch.from_numpy(self.Q).float() + + # Applying reconstruction for each channel + for c in range(n_channels): + Yc = img[:, :, c] + + # SVD of left matrix + UL, SL, VLh = torch.linalg.svd(self.P) + VL = VLh.T + DL = torch.cat( + ( + torch.diag(SL), + torch.zeros([self.P.shape[0] - SL.size(0), SL.size(0)], device=SL.device), + ) + ) + singLsq = SL**2 + + # SVD of right matrix + UR, SR, VRh = torch.linalg.svd(self.Q) + VR = VRh.T + DR = torch.cat( + ( + torch.diag(SR), + torch.zeros([self.Q.shape[0] - SR.size(0), SR.size(0)], device=SR.device), + ) + ) + singRsq = SR**2 + + # Applying analytical reconstruction + inner = torch.linalg.multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + torch.outer(singLsq, singRsq) + torch.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = torch.linalg.multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = torch.clamp(x_est, min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + + else: + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + # Applying reconstruction for each channel + for c in range(n_channels): + + # SVD of left matrix + UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) + VL = VLh.T + DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) + singLsq = np.square(SL) + + # SVD of right matrix + UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) + VR = VRh.T + DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) + singRsq = np.square(SR) + + # Applying analytical 
reconstruction + Yc = img[:, :, c] + inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = x_est.clip(min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) return x_est diff --git a/scripts/sim/mask_single_file.py b/scripts/sim/mask_single_file.py index e8a741b5..8513e75c 100644 --- a/scripts/sim/mask_single_file.py +++ b/scripts/sim/mask_single_file.py @@ -19,6 +19,11 @@ python scripts/sim/mask_single_file.py mask.type=MURA mask.n_bits=99 simulation.flatcam=True recon.algo=tikhonov ``` +Using Torch +``` +python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov use_torch=True +``` + Simulate FlatCam with PSF simulation and Tikhonov reconstuction: ``` python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=False recon.algo=tikhonov @@ -56,6 +61,7 @@ import os from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture from lensless.recon.tikhonov import CodedApertureReconstruction +import torch @hydra.main(version_base=None, config_path="../../configs", config_name="mask_sim_single") @@ -107,6 +113,9 @@ def simulate(config): # 2) simulate measurement image = load_image(fp, verbose=True) / 255 + if config.use_torch: + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).float() flatcam_sim = config.simulation.flatcam if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: @@ -116,17 +125,29 @@ def simulate(config): flatcam_sim = False # use far field simulator to get correct object plane sizing + psf = mask.psf + if config.use_torch: + psf = psf.transpose(2, 0, 1) + psf = torch.from_numpy(psf).float() + simulator = FarFieldSimulator( - psf=mask.psf, + psf=psf, object_height=object_height, scene2mask=scene2mask, mask2sensor=mask2sensor, sensor=sensor, snr_db=snr_db, max_val=max_val, + is_torch=config.use_torch, ) image_plane, object_plane = simulator.propagate(image, return_object_plane=True) + # channels as last dimension + if config.use_torch: + image_plane = image_plane.permute(1, 2, 0) + object_plane = object_plane.permute(1, 2, 0) + image = image.permute(1, 2, 0) + if image_format == "grayscale": image_plane = rgb2gray(image_plane) object_plane = rgb2gray(object_plane) @@ -178,6 +199,12 @@ def simulate(config): else: raise ValueError(f"Reconstruction algorithm {config.recon.algo} not recognized.") + # back to numpy for evaluation and plotting + if config.use_torch: + recovered = recovered.numpy() + object_plane = object_plane.numpy() + image_plane = image_plane.numpy() + # 4) evaluate if image_format == "grayscale": object_plane = object_plane[:, :, 0] @@ -218,7 +245,7 @@ def simulate(config): ax[4].set_title("Reconstruction") for a in ax: - a.set_xticks([]), a.set_yticks([]) + a.set_axis_off() plt.tight_layout() plt.savefig("result.png") From f67985ef02b8a9cc6fc061f41dd2f428b6a49822 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 31 Aug 2023 19:15:55 -0700 Subject: [PATCH 05/11] Add torch support for rgb2gray. (#85) * Add torch support for rgb2gray. * Fix pycsou install. * Update CHANGELOG. 
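A short sketch of the resulting behavior (shapes are illustrative; the ``weights`` argument only applies to the NumPy path, while the torch path relies on torchvision's ``rgb_to_grayscale``):

```python
import torch
from lensless.utils.image import rgb2gray

rgb = torch.rand(1, 380, 507, 3)            # ([depth,] height, width, channel) tensor
gray = rgb2gray(rgb)                        # -> (1, 380, 507, 1), channel dimension kept by default
flat = rgb2gray(rgb, keepchanneldim=False)  # -> (1, 380, 507)
```
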
--- .github/workflows/python_pycsou.yml | 2 +- CHANGELOG.rst | 1 + README.rst | 8 ++-- lensless/utils/image.py | 62 ++++++++++++++++++++++------- test/test_io.py | 32 +++++++++++++-- 5 files changed, 82 insertions(+), 23 deletions(-) diff --git a/.github/workflows/python_pycsou.yml b/.github/workflows/python_pycsou.yml index d5cf1e91..61f89fa5 100644 --- a/.github/workflows/python_pycsou.yml +++ b/.github/workflows/python_pycsou.yml @@ -59,5 +59,5 @@ jobs: pip install -U pytest pip install -r recon_requirements.txt pip install -r mask_requirements.txt - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e pytest \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b1657928..90db0c99 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,7 @@ Added - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. +- PyTorch support for ``lensless.utils.io.rgb2gray``. Changed diff --git a/README.rst b/README.rst index 23066e12..eb5e7e72 100644 --- a/README.rst +++ b/README.rst @@ -84,15 +84,15 @@ install the library locally. python scripts/recon/admm.py -Note (25-04-2023): for using reconstruction method based on Pycsou ``lensless.apgd.APGD``, -V2 has to be installed: +Note (25-04-2023): for using the reconstruction method based on Pycsou (now [Pyxu](https://github.com/matthieumeo/pyxu)) +``lensless.apgd.APGD``, a specific commit has to be installed (as there was no release at the time of implementation): .. code:: bash - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e If PyTorch is installed, you will need to be sure to have PyTorch 2.0 or higher, -as Pycsou V2 is not compatible with earlier versions of PyTorch. Moreover, +as Pycsou is not compatible with earlier versions of PyTorch. Moreover, Pycsou requires Python within `[3.9, 3.11) `__. diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 19c977e2..f3bbe28f 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -14,6 +14,7 @@ try: import torch import torchvision.transforms as tf + from torchvision.transforms.functional import rgb_to_grayscale torch_available = True except ImportError: @@ -82,10 +83,10 @@ def rgb2gray(rgb, weights=None, keepchanneldim=True): Parameters ---------- - rgb : :py:class:`~numpy.ndarray` + rgb : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` ([Depth,] Height, Width, Channel) image. weights : :py:class:`~numpy.ndarray` - [Optional] (3,) weights to convert from RGB to grayscale. + [Optional] (3,) weights to convert from RGB to grayscale. Only used for NumPy arrays. keepchanneldim : bool Whether to keep the channel dimension. Default is True. @@ -95,22 +96,53 @@ def rgb2gray(rgb, weights=None, keepchanneldim=True): Grayscale image of dimension ([depth,] height, width [, 1]). 
""" - if weights is None: - weights = np.array([0.299, 0.587, 0.114]) - assert len(weights) == 3 - - if len(rgb.shape) == 4: - image = np.tensordot(rgb, weights, axes=((3,), 0)) - elif len(rgb.shape) == 3: - image = np.tensordot(rgb, weights, axes=((2,), 0)) - else: - raise ValueError("Input must be at least 3D.") - if keepchanneldim: - return image[..., np.newaxis] - else: + use_torch = False + if torch_available: + if torch.is_tensor(rgb): + use_torch = True + + if use_torch: + + # move channel dimension to third to last + if len(rgb.shape) == 4: + rgb = rgb.permute(0, 3, 1, 2) + elif len(rgb.shape) == 3: + rgb = rgb.permute(2, 0, 1) + else: + raise ValueError("Input must be at least 3D.") + + image = rgb_to_grayscale(rgb) + + # move channel dimension to last + if len(rgb.shape) == 4: + image = image.permute(0, 2, 3, 1) + elif len(rgb.shape) == 3: + image = image.permute(1, 2, 0) + + if not keepchanneldim: + image = image.squeeze(-1) + return image + else: + + if weights is None: + weights = np.array([0.299, 0.587, 0.114]) + assert len(weights) == 3 + + if len(rgb.shape) == 4: + image = np.tensordot(rgb, weights, axes=((3,), 0)) + elif len(rgb.shape) == 3: + image = np.tensordot(rgb, weights, axes=((2,), 0)) + else: + raise ValueError("Input must be at least 3D.") + + if keepchanneldim: + return image[..., np.newaxis] + else: + return image + def gamma_correction(vals, gamma=2.2): """ diff --git a/test/test_io.py b/test/test_io.py index 16823e1f..5c2f8884 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,5 +1,4 @@ -from lensless.utils.io import load_data -import numpy as np +from lensless.utils.io import load_data, rgb2gray psf_fp = "data/psf/tape_rgb.png" data_fp = "data/raw_data/thumbs_up_rgb.png" @@ -26,4 +25,31 @@ def test_load_data(): assert data.dtype == dtype, dtype -test_load_data() +def test_rgb2gray(): + for is_torch in [True, False]: + psf, data = load_data( + psf_fp=psf_fp, + data_fp=data_fp, + downsample=downsample, + plot=False, + dtype="float32", + torch=is_torch, + ) + data = data[0] # drop first depth dimension + + # try with 4D + psf_gray = rgb2gray(psf, keepchanneldim=False) + assert len(psf_gray.shape) == 3 + psf_gray = rgb2gray(psf, keepchanneldim=True) + assert len(psf_gray.shape) == 4 + + # try with 3D + data_gray = rgb2gray(data, keepchanneldim=False) + assert len(data_gray.shape) == 2 + data_gray = rgb2gray(data, keepchanneldim=True) + assert len(data_gray.shape) == 3 + + +if __name__ == "__main__": + test_load_data() + test_rgb2gray() From 8dfdc554bdfbe00514095099204a9c65e8d9c25d Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Tue, 5 Sep 2023 23:01:51 +0200 Subject: [PATCH 06/11] Trainable mask (#81) * Add support for changing the psf * First implementation of trainable mask * Fix to projection * add support for trainable mask * new datased with trainable mask * Fix comment and dataset name * Fix for SimulatedDatasetTrainableMask * Update to trainer save * If no test dataset, sample from test * Add support for l1 regularisation on mask * Support for gray mask to rgb psf * remove update frequency param * add auto gray to rgb conversion * fix update bug * Update simulation for TrainableMask * Fix SimulatedDatasetTrainableMask * Clean and changelog * Fix simulation flip * Fix not using a mask * Default config doesn't use TrainableMask * Fix PR comment * Add config for PSF fine-tuning * Added to doc * Fix / update dataset docs. * Add method to set PSF of simulator. * Add check for dataset. 
* Move trainable mask. * Fix trainable mask documentation. * Fix docs. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 2 + configs/fine-tune_PSF.yaml | 119 ++++++++++++++++++++++++++++ configs/unrolled_recon.yaml | 34 +++++--- docs/requirements.txt | 2 +- docs/source/dataset.rst | 34 ++++++++ docs/source/mask.rst | 10 +++ lensless/hardware/trainable_mask.py | 86 ++++++++++++++++++++ lensless/recon/recon.py | 22 +++++ lensless/recon/rfft_convolve.py | 3 + lensless/recon/utils.py | 74 +++++++++++------ lensless/utils/dataset.py | 54 ++++++++++++- lensless/utils/simulation.py | 33 ++++++++ mask_requirements.txt | 2 +- recon_requirements.txt | 2 +- scripts/recon/train_unrolled.py | 91 +++++++++++++++++---- 15 files changed, 510 insertions(+), 58 deletions(-) create mode 100644 configs/fine-tune_PSF.yaml create mode 100644 lensless/hardware/trainable_mask.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 90db0c99..ddf3c78a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,8 @@ Added - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. +- New ``TrainableMask`` and ``TrainablePSF`` class to train/fine-tune a mask from a dataset. +- New ``SimulatedDatasetTrainableMask`` class to train/fine-tune a mask for measurement. - PyTorch support for ``lensless.utils.io.rgb2gray``. diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml new file mode 100644 index 00000000..040dc81b --- /dev/null +++ b/configs/fine-tune_PSF.yaml @@ -0,0 +1,119 @@ +hydra: + job: + chdir: True # change to output folder + +#Reconstruction algorithm +input: + # File path for recorded PSF + psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff + dtype: float32 + +torch: True +torch_device: 'cuda' + +preprocess: + # Image shape (height, width) for reconstruction. + shape: null + # Whether image is raw bayer data. + bayer: False + blue_gain: null + red_gain: null + # Same PSF for all channels (sum) or unique PSF for RGB. + single_psf: False + # Whether to perform construction in grayscale. + gray: False + + +display: + # How many iterations to wait for intermediate plot. + # Set to negative value for no intermediate plots. + disp: 500 + # Whether to plot results. + plot: True + # Gamma factor for plotting. + gamma: null + +# Whether to save intermediate and final reconstructions. +save: True + +reconstruction: + # Method: unrolled_admm, unrolled_fista + method: unrolled_admm + + # Hyperparameters for each method + unrolled_fista: # for unrolled_fista + # Number of iterations + n_iter: 20 + tk: 1 + learn_tk: True + unrolled_admm: + # Number of iterations + n_iter: 20 + # Hyperparameters + mu1: 1e-4 + mu2: 1e-4 + mu3: 1e-4 + tau: 2e-4 + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. 
Ignore if network is DruNet
+
+#Trainable Mask
+trainable_mask:
+  mask_type: TrainablePSF #Null or "TrainablePSF"
+  initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray"
+  mask_lr: 1e-3
+  L1_strength: 1.0 #False or float
+  use_mask_in_dataset : False # Works only with simulated dataset
+
+# Train Dataset
+files:
+  dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
+  n_files: null # null to use all
+
+target: "object_plane" # "original" or "object_plane" or "label"
+
+#for simulated dataset
+simulation:
+  grayscale: False
+  # random variations
+  object_height: 0.04 # range for random height or scalar
+  flip: True # change the orientation of the object (from vertical to horizontal)
+  random_shift: False
+  random_vflip: 0.5
+  random_hflip: 0.5
+  random_rotate: False
+  # these distance parameters are typically fixed for a given PSF
+  # for DiffuserCam psf    # for tape_rgb psf
+  scene2mask: 10e-2        # scene2mask: 40e-2
+  mask2sensor: 9e-3        # mask2sensor: 4e-3
+  # see waveprop.devices
+  sensor: "rpi_hq"
+  snr_db: 10
+  # simulate different sensor resolution
+  # output_dim: [24, 32]    # [H, W] or null
+  # Downsampling for PSF
+  downsample: 8
+  # max val in simulated measurement (quantized 8 bits)
+  max_val: 255
+
+#Training
+
+training:
+  batch_size: 8
+  epoch: 10
+  #In case of unstable training
+  skip_NAN: True
+  slow_start: False #float how much to reduce lr for first epoch
+
+
+optimizer:
+  type: Adam
+  lr: 1e-4
+
+loss: 'l2'
+# set lpips to false to deactivate. Otherwise, give the weight for the loss (the main loss l2/l1 always having a weight of 1)
+lpips: 1.0
\ No newline at end of file
diff --git a/configs/unrolled_recon.yaml b/configs/unrolled_recon.yaml
index 621e3cfa..2673f20c 100644
--- a/configs/unrolled_recon.yaml
+++ b/configs/unrolled_recon.yaml
@@ -27,7 +27,7 @@ preprocess:
 display:
   # How many iterations to wait for intermediate plot.
   # Set to negative value for no intermediate plots.
-  disp: 400
+  disp: 500
   # Whether to plot results.
   plot: True
   # Gamma factor for plotting.
@@ -48,23 +48,30 @@ reconstruction:
     learn_tk: True
   unrolled_admm:
     # Number of iterations
-    n_iter: 5
+    n_iter: 20
     # Hyperparameters
     mu1: 1e-4
     mu2: 1e-4
     mu3: 1e-4
     tau: 2e-4
   pre_process:
-    network : UnetRes  # UnetRes or DruNet or null
+    network : null  # UnetRes or DruNet or null
     depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
   post_process:
-    network : UnetRes  # UnetRes or DruNet or null
+    network : null  # UnetRes or DruNet or null
     depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
 
-# Train Dataset
+#Trainable Mask
+trainable_mask:
+  mask_type: Null #Null or "TrainablePSF"
+  initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray"
+  mask_lr: 1e-3
+  L1_strength: 1.0 #False or float
+  use_mask_in_dataset : True # Works only with simulated dataset
 
+# Train Dataset
 files:
-  dataset: "DiffuserCam" # "mnist", "fashion_mnist", "cifar10", "CelebA", "DiffuserCam"
+  dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". 
Measure :"DiffuserCam" n_files: null # null to use all target: "object_plane" # "original" or "object_plane" or "label" @@ -73,18 +80,19 @@ target: "object_plane" # "original" or "object_plane" or "label" simulation: grayscale: False # random variations - object_height: 0.6 # range for random height or scalar + object_height: 0.04 # range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) random_shift: False random_vflip: 0.5 random_hflip: 0.5 random_rotate: False # these distance parameters are typically fixed for a given PSF - # for tape_rgb psf # for DiffuserCam psf - scene2mask: 40e-2 # scene2mask: 10e-2 - mask2sensor: 4e-3 # mask2sensor: 9e-3 + # for DiffuserCam psf # for tape_rgb psf + scene2mask: 10e-2 # scene2mask: 40e-2 + mask2sensor: 9e-3 # mask2sensor: 4e-3 # see waveprop.devices sensor: "rpi_hq" - snr_db: 40 + snr_db: 10 # simulate different sensor resolution # output_dim: [24, 32] # [H, W] or null # Downsampling for PSF @@ -96,7 +104,7 @@ simulation: training: batch_size: 8 - epoch: 50 + epoch: 10 #In case of instable training skip_NAN: True slow_start: False #float how much to reduce lr for first epoch @@ -104,7 +112,7 @@ training: optimizer: type: Adam - lr: 1e-6 + lr: 1e-4 loss: 'l2' # set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) diff --git a/docs/requirements.txt b/docs/requirements.txt index 484c5d20..3eb1e15f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,4 +6,4 @@ torch>=1.10 torchvision>=0.15.2 torchmetrics>=0.11.4 pyFFS>=2.2.3 # for waveprop -waveprop>=0.0.5 \ No newline at end of file +waveprop>=0.0.7 \ No newline at end of file diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index 1312e1cc..ad21defb 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -6,14 +6,48 @@ datasets for training and testing. .. automodule:: lensless.utils.dataset +Abstract base class +------------------- + +All dataset objects derive from this abstract base class, which +lays out the notion of a dataset with pairs of images: one image +is lensed (simulated or measured), and the other is lensless (simulated +or measured). + .. autoclass:: lensless.utils.dataset.DualDataset :members: _get_images_pair :special-members: __init__, __len__ + +Simulated dataset objects +------------------------- + +These dataset objects can be used for training and testing with +simulated data. The main assumption is that the imaging system +is linear shift-invariant (LSI), and that the lensless image is +the result of a convolution of the lensed image with a point-spread +function (PSF). Check out `this Medium post `__ +for more details on the simulation procedure. + +With simulated data, we can avoid the hassle of collecting a large +amount of data. However, it's important to note that the LSI assumption +can sometimes be too idealistic, in particular for large angles. + +Nevertheless, simulating data is the only option of learning the +mask / PSF. + .. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset :members: :special-members: __init__ +.. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask + :members: + :special-members: __init__ + + +Measured dataset objects +------------------------ + .. 
autoclass:: lensless.utils.dataset.MeasuredDataset
     :members:
     :special-members: __init__
 
diff --git a/docs/source/mask.rst b/docs/source/mask.rst
index 0ad8327e..036d0f12 100644
--- a/docs/source/mask.rst
+++ b/docs/source/mask.rst
@@ -29,5 +29,15 @@
    ~~~~~~~~~~~~~~~~~~~~~
    .. autoclass:: lensless.hardware.mask.FresnelZoneAperture
+      :members:
+      :special-members: __init__
+
+   Trainable Mask
+   ~~~~~~~~~~~~~~~~~~~~~
+   .. autoclass:: lensless.hardware.trainable_mask.TrainableMask
       :members:
       :special-members: __init__
+
+   .. autoclass:: lensless.hardware.trainable_mask.TrainablePSF
+      :members:
       :special-members: __init__
\ No newline at end of file
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
new file mode 100644
index 00000000..593c2360
--- /dev/null
+++ b/lensless/hardware/trainable_mask.py
@@ -0,0 +1,86 @@
+# #############################################################################
+# trainable_mask.py
+# ==================
+# Authors :
+# Yohann PERRON [yohann.perron@gmail.com]
+# #############################################################################
+
+import abc
+import torch
+
+
+class TrainableMask(metaclass=abc.ABCMeta):
+    """
+    Abstract class for defining trainable masks.
+
+    The following abstract methods need to be defined:
+
+    - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`: returning the PSF of the mask.
+    - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.project`: projecting the mask parameters to a valid space (should be a subspace of [0,1]).
+
+    """
+
+    def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
+        """
+        Base constructor. Derived constructors may define new state variables.
+
+        Parameters
+        ----------
+        initial_mask : :py:class:`~torch.Tensor`
+            Initial mask parameters.
+        optimizer : str, optional
+            Optimizer to use for updating the mask parameters, by default "Adam"
+        lr : float, optional
+            Learning rate for the mask parameters, by default 1e-3
+        """
+        self._mask = torch.nn.Parameter(initial_mask)
+        self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs)
+        self._counter = 0
+
+    @abc.abstractmethod
+    def get_psf(self):
+        """
+        Abstract method for getting the PSF of the mask. Should be fully compatible with PyTorch autograd.
+
+        Returns
+        -------
+        :py:class:`~torch.Tensor`
+            The PSF of the mask.
+        """
+        raise NotImplementedError
+
+    def update_mask(self):
+        """Update the mask parameters according to externally updated gradients."""
+        self._optimizer.step()
+        self._optimizer.zero_grad(set_to_none=True)
+        self.project()
+        self._counter += 1
+
+    @abc.abstractmethod
+    def project(self):
+        """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1])."""
+        raise NotImplementedError
+
+
+class TrainablePSF(TrainableMask):
+    """
+    Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically.
+
+    Parameters
+    ----------
+    is_rgb : bool, optional
+        Whether the mask is RGB or not, by default True. 
+    """
+
+    def __init__(self, is_rgb=True, **kwargs):
+        super().__init__(**kwargs)
+        self._is_rgb = is_rgb
+
+    def get_psf(self):
+        if self._is_rgb:
+            return self._mask.expand(-1, -1, -1, 3)
+        else:
+            return self._mask
+
+    def project(self):
+        self._mask.data = torch.clamp(self._mask, 0, 1)
diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py
index 1124c289..444e3b0a 100644
--- a/lensless/recon/recon.py
+++ b/lensless/recon/recon.py
@@ -404,6 +404,28 @@ def get_image_estimate(self):
         """Get current image estimate as [Batch, Depth, Height, Width, Channels]."""
         return self._form_image()
 
+    def _set_psf(self, psf):
+        """
+        Set PSF.
+
+        Parameters
+        ----------
+        psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
+            PSF to set.
+        """
+        assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)."
+        assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be RGB (3) or grayscale (1)"
+        assert self._psf.shape == psf.shape, "new PSF must have same shape as old PSF"
+        assert isinstance(psf, type(self._psf)), "new PSF must have same type as old PSF"
+
+        self._psf = psf
+        self._convolver = RealFFTConvolve2D(
+            psf,
+            dtype=self._convolver._psf.dtype,
+            pad=self._convolver.pad,
+            norm=self._convolver.norm,
+        )
+
     def _progress(self):
         """
         Optional method for printing progress update, e.g. relative improvement
diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py
index 5c867cd3..34cca96a 100644
--- a/lensless/recon/rfft_convolve.py
+++ b/lensless/recon/rfft_convolve.py
@@ -57,6 +57,9 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
         self._is_rgb = psf.shape[3] == 3
         assert self._is_rgb or psf.shape[3] == 1
 
+        # save normalization
+        self.norm = norm
+
         # set dtype
         if dtype is None:
             if self.is_torch:
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 54d23a1d..5f091c53 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -1,5 +1,5 @@
 # #############################################################################
-# dataset.py
+# utils.py
 # =================
 # Authors :
 # Yohann PERRON [yohann.perron@gmail.com]
@@ -15,6 +15,7 @@
 import matplotlib.pyplot as plt
 import torch
 from lensless.eval.benchmark import benchmark
+from lensless.hardware.trainable_mask import TrainableMask
 from tqdm import tqdm
 from lensless.recon.drunet.network_unet import UNetRes
@@ -222,9 +223,11 @@ def __init__(
         recon,
         train_dataset,
         test_dataset,
+        mask=None,
         batch_size=4,
         loss="l2",
         lpips=None,
+        l1_mask=None,
         optimizer="Adam",
         optimizer_lr=1e-6,
         slow_start=None,
@@ -242,12 +245,16 @@ def __init__(
             Dataset to use for training.
         test_dataset : :py:class:`torch.utils.data.Dataset`
             Dataset to use for testing.
+        mask : TrainableMask, optional
+            Trainable mask to use for training. If None, training is done with a fixed PSF; by default None.
         batch_size : int, optional
             Batch size to use for training, by default 4
         loss : str, optional
             Loss function to use for training: "l1" or "l2", by default "l2"
         lpips : float, optional
             Weight of the LPIPS (VGG) term in the total loss. If None, ignored. By default None
+        l1_mask : float, optional
+            Weight of the L1 norm of the mask in the total loss. If None, ignored. By default None
         optimizer : str, optional
             Optimizer to use during training. Available : "Adam". 
By default "Adam" optimizer_lr : float, optional @@ -263,6 +270,15 @@ def __init__( self.device = recon._psf.device self.recon = recon + + if test_dataset is None: + # split train dataset + train_size = int(0.9 * len(train_dataset)) + test_size = len(train_dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split( + train_dataset, [train_size, test_size] + ) + self.train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, @@ -273,6 +289,15 @@ def __init__( self.lpips = lpips self.skip_NAN = skip_NAN + if mask is not None: + assert isinstance(mask, TrainableMask) + self.mask = mask + self.use_mask = True + else: + self.use_mask = False + + self.l1_mask = l1_mask + # loss if loss == "l2": self.Loss = torch.nn.MSELoss() @@ -358,8 +383,8 @@ def train_epoch(self, data_loader, disp=-1): ---------- data_loader : :py:class:`torch.utils.data.DataLoader` Data loader to use for training. - disp : int, optional - Display interval, if -1, no display, by default -1 + disp : int + Display interval, if -1, no display Returns ------- @@ -374,6 +399,11 @@ def train_epoch(self, data_loader, disp=-1): X = X.to(self.device) y = y.to(self.device) + # update psf according to mask + if self.use_mask: + self.recon._set_psf(self.mask.get_psf()) + + # forward pass y_pred = self.recon.batch_call(X.to(self.device)) # normalizing each output eps = 1e-12 @@ -404,6 +434,8 @@ def train_epoch(self, data_loader, disp=-1): loss_v = loss_v + self.lpips * torch.mean( self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) ) + if self.use_mask and self.l1_mask: + loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask)) loss_v.backward() torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0) @@ -421,6 +453,10 @@ def train_epoch(self, data_loader, disp=-1): continue self.optimizer.step() + # update mask + if self.use_mask: + self.mask.update_mask() + mean_loss += (loss_v.item() - mean_loss) * (1 / i) pbar.set_description(f"loss : {mean_loss}") i += 1 @@ -488,6 +524,7 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): start_time = time.time() + self.evaluate(-1, save_pt) for epoch in range(n_epoch): print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) @@ -497,31 +534,18 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): print(f"Train time : {time.time() - start_time} s") def save(self, path="recon", include_optimizer=False): - """ - Save state of reconstruction algorithm. - - Parameters - ---------- - path : str, optional - Path to save model to, by default "recon" - include_optimizer : bool, optional - Whether to include optimizer state, by default False - - """ # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) - - # TODO : ADD mask support - # # save mask - # if self.use_mask: - # torch.save(self.mask._mask, os.path.join(path, "mask.pt")) - # torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) - # import matplotlib.pyplot as plt - - # plt.imsave( - # os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] - # ) + # save mask + if self.use_mask: + torch.save(self.mask._mask, os.path.join(path, "mask.pt")) + torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) + import matplotlib.pyplot as plt + + plt.imsave( + os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] 
+ ) # save optimizer if include_optimizer: torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 2634cb7c..67aa7a8d 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import numpy as np @@ -144,7 +145,7 @@ def __init__( dataset : :py:class:`torch.utils.data.Dataset` Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` - Simulator object used on images from ``dataset``.Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. pre_transform : PyTorch Transform or None, optional Transform to apply to the images before simulation, by default ``None``. Note that this transform is applied on HCW images (different from torchvision). dataset_is_CHW : bool, optional @@ -176,7 +177,7 @@ def get_image(self, index): def _get_images_pair(self, index): # load image img, _ = self.get_image(index) - # convert to CHW for simulator and transform + # convert to HWC for simulator and transform if self.dataset_is_CHW: img = img.moveaxis(-3, -1) if self.flip_pre_sim: @@ -446,3 +447,52 @@ def __init__( lensed_fn="lensed", image_ext="npy", ) + + +class SimulatedDatasetTrainableMask(SimulatedFarFieldDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset with learnable mask. + The `waveprop `_ package is used for the simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + To ensure autograd compatibility, the dataloader should have ``num_workers=0``. + """ + + def __init__( + self, + mask, + dataset, + simulator, + **kwargs, + ): + """ + Parameters + ---------- + + mask : :py:class:`lensless.hardware.trainable_mask.TrainableMask` + Mask to use for simulation. Should be a 4D tensor with shape [1, H, W, C]. Simulation of multi-depth data is not supported yet. + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + """ + + self._mask = mask + + temp_psf = self._mask.get_psf() + test_sim = FarFieldSimulator(psf=temp_psf, **simulator.params) + assert ( + test_sim.conv_dim == simulator.conv_dim + ).all(), "PSF shape should match simulator shape" + assert ( + not simulator.quantize + ), "Simulator should not perform quantization to maintain differentiability. 
Please set quantize=False" + + super(SimulatedDatasetTrainableMask, self).__init__(dataset, simulator, **kwargs) + + def _get_images_pair(self, index): + # update psf + psf = self._mask.get_psf() + self.sim.set_psf(psf) + + # return simulated images + return super()._get_images_pair(index) diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py index 36aac243..e7f7af3a 100644 --- a/lensless/utils/simulation.py +++ b/lensless/utils/simulation.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import numpy as np @@ -27,6 +28,7 @@ def __init__( device_conv="cpu", random_shift=False, is_torch=False, + quantize=True, **kwargs ): """ @@ -52,6 +54,8 @@ def __init__( Whether to randomly shift the image, by default False. is_torch : bool, optional Whether to use pytorch, by default False. + quantize : bool, optional + Whether to quantize image, by default True. """ if psf is not None: @@ -70,9 +74,38 @@ def __init__( device_conv, random_shift, is_torch, + quantize, **kwargs ) + # save all the parameters in a dict + self.params = { + "object_height": object_height, + "scene2mask": scene2mask, + "mask2sensor": mask2sensor, + "sensor": sensor, + "output_dim": output_dim, + "snr_db": snr_db, + "max_val": max_val, + "device_conv": device_conv, + "random_shift": random_shift, + "is_torch": is_torch, + "quantize": quantize, + } + self.params.update(kwargs) + + def set_psf(self, psf): + """ + Set point spread function. + + Parameters + ---------- + psf : np.ndarray or torch.Tensor + Point spread function. + """ + psf = psf.squeeze().movedim(-1, 0) + return super().set_psf(psf) + def propagate(self, obj, return_object_plane=False): """ Parameters diff --git a/mask_requirements.txt b/mask_requirements.txt index ee87c51f..699ba552 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,3 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.4 \ No newline at end of file +waveprop>=0.0.7 \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index b9e9f324..33e12092 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -3,7 +3,7 @@ lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 click>=8.0.1 -waveprop>=0.0.3 # for simulation +waveprop>=0.0.7 # for simulation # Library for learning algorithm torch >= 2.0.0 diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 7d0a31e1..c669ea2e 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -12,6 +12,11 @@ python scripts/recon/train_unrolled.py ``` +To fine-tune the DiffuserCam PSF, use the following command: +``` +python scripts/recon/train_unrolled.py -cn fine-tune_PSF +``` + """ import hydra @@ -20,7 +25,12 @@ import numpy as np import time from lensless import UnrolledFISTA, UnrolledADMM -from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset +from lensless.utils.dataset import ( + DiffuserCamTestDataset, + SimulatedFarFieldDataset, + SimulatedDatasetTrainableMask, +) +import lensless.hardware.trainable_mask from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray from lensless.utils.simulation import FarFieldSimulator @@ -29,7 +39,7 @@ from torchvision import transforms, datasets -def simulate_dataset(config, psf): 
+def simulate_dataset(config, psf, mask=None): # load dataset transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") @@ -71,9 +81,23 @@ def simulate_dataset(config, psf): # create Pytorch dataset and dataloader if n_files is not None: ds = torch.utils.data.Subset(ds, np.arange(n_files)) - ds_prop = SimulatedFarFieldDataset( - dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv - ) + if mask is None: + ds_prop = SimulatedFarFieldDataset( + dataset=ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + else: + ds_prop = SimulatedDatasetTrainableMask( + dataset=ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) return ds_prop @@ -96,12 +120,11 @@ def train_unrolled( data_dir=path, downsample=config.simulation.downsample ) - psf = benchmark_dataset.psf.to(device) + diffusercam_psf = benchmark_dataset.psf.to(device) background = benchmark_dataset.background # convert psf from BGR to RGB - if config.files.dataset in ["DiffuserCam"]: - psf = psf[..., [2, 1, 0]] + diffusercam_psf = diffusercam_psf[..., [2, 1, 0]] # if using a portrait dataset rotate the PSF @@ -130,7 +153,7 @@ def train_unrolled( # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( - psf, + diffusercam_psf, n_iter=config.reconstruction.unrolled_fista.n_iter, tk=config.reconstruction.unrolled_fista.tk, pad=True, @@ -140,7 +163,7 @@ def train_unrolled( ).to(device) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( - psf, + diffusercam_psf, n_iter=config.reconstruction.unrolled_admm.n_iter, mu1=config.reconstruction.unrolled_admm.mu1, mu2=config.reconstruction.unrolled_admm.mu2, @@ -164,6 +187,25 @@ def train_unrolled( # transform from BGR to RGB transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + # create mask + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) + if config.trainable_mask.initial_value == "random": + mask = mask_class( + torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr + ) + elif config.trainable_mask.initial_value == "DiffuserCam": + mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) + elif config.trainable_mask.initial_value == "DiffuserCam_gray": + mask = mask_class( + diffusercam_psf[:, :, :, 0, None], + optimizer="Adam", + lr=config.trainable_mask.mask_lr, + is_rgb=not config.simulation.grayscale, + ) + else: + mask = None + # load dataset and create dataloader if config.files.dataset == "DiffuserCam": # Use a ParallelDataset @@ -174,11 +216,12 @@ def train_unrolled( max_indices = config.files.n_files + 1000 data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") + assert os.path.exists(data_path), "DiffuserCam dataset not found" dataset = MeasuredDataset( root_dir=data_path, indices=range(1000, max_indices), background=background, - psf=psf, + psf=diffusercam_psf, lensless_fn="diffuser_images", lensed_fn="ground_truth_lensed", downsample=config.simulation.downsample / 4, @@ -187,17 +230,25 @@ def train_unrolled( ) else: # Use a simulated dataset - dataset = simulate_dataset(config, psf) + if config.trainable_mask.use_mask_in_dataset: + dataset = simulate_dataset(config, diffusercam_psf, mask=mask) + # the mask use will differ from the one in the 
benchmark dataset + print("Trainable Mask will be used in the test dataset") + benchmark_dataset = None + else: + dataset = simulate_dataset(config, diffusercam_psf, mask=None) print(f"Setup time : {time.time() - start_time} s") - + print(f"PSF shape : {diffusercam_psf.shape}") trainer = Trainer( recon, dataset, benchmark_dataset, + mask=mask, batch_size=config.training.batch_size, loss=config.loss, lpips=config.lpips, + l1_mask=config.trainable_mask.L1_strength, optimizer=config.optimizer.type, optimizer_lr=config.optimizer.lr, slow_start=config.training.slow_start, @@ -205,8 +256,18 @@ def train_unrolled( algorithm_name=algorithm_name, ) - trainer.train(n_epoch=config.training.epoch, save_pt=save) - trainer.save(path=os.path.join(save, "recon.pt")) + trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp) + + if mask is not None: + print("Saving mask") + print(f"mask shape: {mask._mask.shape}") + torch.save(mask._mask, os.path.join(save, "mask.pt")) + # save as image using plt + import matplotlib.pyplot as plt + + print(f"mask max: {mask._mask.max()}") + print(f"mask min: {mask._mask.min()}") + plt.imsave(os.path.join(save, "mask.png"), mask._mask.detach().cpu().numpy()[0, ...]) if __name__ == "__main__": From 3f78c2421a170057718f01952c1cc53c3e8d42d7 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:23:31 -0700 Subject: [PATCH 07/11] Fix readme rendering. (#88) --- README.rst | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index eb5e7e72..5a88de08 100644 --- a/README.rst +++ b/README.rst @@ -60,7 +60,8 @@ Python 3.9, as some Python library versions may not be available with earlier versions of Python. Moreover, its `end-of-life `__ is Oct 2025. -**Local machine** +*Local machine setup* +===================== Below are commands that worked for our configuration (Ubuntu 21.04), but there are certainly other ways to download a repository and @@ -83,9 +84,13 @@ install the library locally. # (optional) try reconstruction on local machine python scripts/recon/admm.py + # (optional) try reconstruction on local machine with GPU + python scripts/recon/admm.py -cn pytorch -Note (25-04-2023): for using the reconstruction method based on Pycsou (now [Pyxu](https://github.com/matthieumeo/pyxu)) -``lensless.apgd.APGD``, a specific commit has to be installed (as there was no release at the time of implementation): + +Note (25-04-2023): for using the :py:class:`~lensless.recon.apgd.APGD` reconstruction method based on Pycsou +(now `Pyxu `__), a specific commit has +to be installed (as there was no release at the time of implementation): .. code:: bash @@ -102,7 +107,8 @@ Moreover, ``numba`` (requirement for Pycsou V2) may require an older version of pip install numpy==1.23.5 -**Raspberry Pi** +*Raspberry Pi setup* +==================== After `flashing your Raspberry Pi with SSH enabled `__, you need to set it up for `passwordless access `__. From 4ae024747a0dbf0b56728accf5087f59b61bc696 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:28:28 -0700 Subject: [PATCH 08/11] Bump version to v1.0.5. 
--- CHANGELOG.rst | 18 ++++++++++++++++++ lensless/version.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ddf3c78a..3ae1221a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,24 @@ Unreleased Added ~~~~~ +- Nothing + +Changed +~~~~~~~ + +- Nothing + +Bugfix +~~~~~~ + +- Nothing + +1.0.5 - (2023-09-05) +-------------------- + +Added +~~~~~ + - Sensor module. - Single-script and Telegram demo. - Link and citation for JOSS. diff --git a/lensless/version.py b/lensless/version.py index 92192eed..68cdeee4 100644 --- a/lensless/version.py +++ b/lensless/version.py @@ -1 +1 @@ -__version__ = "1.0.4" +__version__ = "1.0.5" From fa91052906ad1ba174cca4956f3297f819c08888 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:56:11 -0700 Subject: [PATCH 09/11] Simplify setup config for PyPI rendering. --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 79468810..392fc7fe 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ exec(f.read()) assert __version__ is not None -with open("README.rst", "r", encoding="utf-8") as fh: - long_description = fh.read() +# with open("README.rst", "r", encoding="utf-8") as fh: +# long_description = fh.read() +long_description = "See the documentation at https://lensless.readthedocs.io/en/latest/" setuptools.setup( name="lensless", version=__version__, author="Eric Bezzam", author_email="ebezzam@gmail.com", - description="Package to control and image with a lensless camera running on a Raspberry Pi.", + description="All-in-one package for lensless imaging: design, simulation, measurement, reconstruction.", long_description=long_description, long_description_content_type="text/x-rst", url="https://github.com/LCAV/LenslessPiCam", From 753a64a1a712ab521751d26528d2c757ed7ddf45 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 20 Sep 2023 17:52:44 +0200 Subject: [PATCH 10/11] Clean up unrolled training + PSF learning + simulated datasets. (#90) * New default config * Add new config to train psf from scratch Namely, it simulates the dataset with the mask/PSF that's being optimized. * Clean up configs. * Clean up config, clearer downsample. * Clean up test set loading. * Add download for drunet. * Fix message. * Improve diffusercam test dataset api. * New object for Mirflickr dataset. * Index train set correctly. * Update requirements for reconstruction. * Update documentation. * Adapt ADMM script for intermediate output. * remove raise error * Update number epochs. * Fix normalization. * Add logic for saving best model. * Clean up PSF fine-tuning. * Clean up fine-tuning PSF. * Clean up training with simulated dataset. * Update CHANGELOG. 
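The reworked configs can be launched as noted in their headers, e.g.:

```
# train pre-/post-processing networks around unrolled ADMM
python scripts/recon/train_unrolled.py -cn train_pre-post-processing

# learn the PSF from scratch, simulating the dataset with the PSF being optimized
python scripts/recon/train_unrolled.py -cn train_psf_from_scratch

# (unrolled) ADMM reconstruction of a single DiffuserCam measurement
python scripts/recon/admm.py -cn diffusercam_mirflickr_single_admm
```
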
--------- Co-authored-by: YohannPerron Co-authored-by: YohannPerron <73244423+YohannPerron@users.noreply.github.com> --- CHANGELOG.rst | 11 +- configs/defaults_recon.yaml | 2 + .../diffusercam_mirflickr_single_admm.yaml | 43 +++ configs/fine-tune_PSF.yaml | 117 +------ configs/train_pre-post-processing.yaml | 24 ++ configs/train_psf_from_scratch.yaml | 18 ++ ...led_recon.yaml => train_unrolledADMM.yaml} | 42 +-- lensless/eval/benchmark.py | 15 +- lensless/hardware/trainable_mask.py | 32 +- lensless/recon/trainable_recon.py | 44 ++- lensless/recon/utils.py | 150 +++++++-- lensless/utils/dataset.py | 152 +++++++-- lensless/utils/image.py | 17 + lensless/utils/io.py | 19 +- lensless/utils/simulation.py | 58 +++- mask_requirements.txt | 2 +- recon_requirements.txt | 3 +- scripts/eval/benchmark_recon.py | 4 +- scripts/recon/admm.py | 114 ++++++- scripts/recon/train_unrolled.py | 295 +++++++++++------- 20 files changed, 815 insertions(+), 347 deletions(-) create mode 100644 configs/diffusercam_mirflickr_single_admm.yaml create mode 100644 configs/train_pre-post-processing.yaml create mode 100644 configs/train_psf_from_scratch.yaml rename configs/{unrolled_recon.yaml => train_unrolledADMM.yaml} (74%) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3ae1221a..a0492898 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,17 +13,22 @@ Unreleased Added ~~~~~ -- Nothing +- Trainable reconstruction can return intermediate outputs (between pre- and post-processing). +- Auto-download for DRUNet model. +- ``utils.dataset.DiffuserCamMirflickr`` helper class for Mirflickr dataset. Changed ~~~~~~~ -- Nothing +- Better logic for saving best model. Based on desired metric rather than last epoch, and intermediate models can be saved. +- Optional normalization in ``utils.io.load_image``. Bugfix ~~~~~~ -- Nothing +- Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS. +- Fix bad train/test split for DiffuserCamMirflickr in unrolled training. + 1.0.5 - (2023-09-05) -------------------- diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 324aa679..1771ff8a 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -8,11 +8,13 @@ input: # File path for raw data data: data/raw_data/thumbs_up_rgb.png dtype: float32 + original: null # ground truth image torch: False torch_device: 'cpu' preprocess: + normalize: True # Downsampling factor along X and Y downsample: 4 # Image shape (height, width) for reconstruction. 
diff --git a/configs/diffusercam_mirflickr_single_admm.yaml b/configs/diffusercam_mirflickr_single_admm.yaml new file mode 100644 index 00000000..5055bf6f --- /dev/null +++ b/configs/diffusercam_mirflickr_single_admm.yaml @@ -0,0 +1,43 @@ +# python scripts/recon/admm.py -cn diffusercam_mirflickr_single_admm +defaults: + - defaults_recon + - _self_ + + +display: + gamma: null + +input: + # File path for recorded PSF + psf: data/DiffuserCam_Test/psf.tiff + # File path for raw data + data: data/DiffuserCam_Test/diffuser/im5.npy + dtype: float32 + original: data/DiffuserCam_Test/lensed/im5.npy + +torch: True +torch_device: 'cuda:0' + +preprocess: + downsample: 8 # factor for PSF, which is 4x resolution of image + normalize: False + +admm: + # Number of iterations + n_iter: 20 + # Hyperparameters + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + #Loading unrolled model + unrolled: True + # checkpoint_fp: pretrained_models/Pre_Unrolled_Post-DiffuserCam/model_weights.pt + checkpoint_fp: outputs/2023-09-11/22-06-49/recon.pt # pre unet and post drunet + pre_process_model: + network : UnetRes # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process_model: + network : DruNet # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + \ No newline at end of file diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index 040dc81b..af55e03a 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -1,119 +1,18 @@ -hydra: - job: - chdir: True # change to output folder - -#Reconstruction algorithm -input: - # File path for recorded PSF - psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff - dtype: float32 - -torch: True -torch_device: 'cuda' - -preprocess: - # Image shape (height, width) for reconstruction. - shape: null - # Whether image is raw bayer data. - bayer: False - blue_gain: null - red_gain: null - # Same PSF for all channels (sum) or unique PSF for RGB. - single_psf: False - # Whether to perform construction in grayscale. - gray: False - - -display: - # How many iterations to wait for intermediate plot. - # Set to negative value for no intermediate plots. - disp: 500 - # Whether to plot results. - plot: True - # Gamma factor for plotting. - gamma: null - -# Whether to save intermediate and final reconstructions. -save: True - -reconstruction: - # Method: unrolled_admm, unrolled_fista - method: unrolled_admm - - # Hyperparameters for each method - unrolled_fista: # for unrolled_fista - # Number of iterations - n_iter: 20 - tk: 1 - learn_tk: True - unrolled_admm: - # Number of iterations - n_iter: 20 - # Hyperparameters - mu1: 1e-4 - mu2: 1e-4 - mu3: 1e-4 - tau: 2e-4 - pre_process: - network : null # UnetRes or DruNet or null - depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet - post_process: - network : null # UnetRes or DruNet or null - depth : 2 # depth of each up/downsampling layer. 
Ignore if network is DruNet +# python scripts/recon/train_unrolled.py -cn fine-tune_PSF +defaults: + - train_unrolledADMM + - _self_ #Trainable Mask trainable_mask: mask_type: TrainablePSF #Null or "TrainablePSF" - initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray" + initial_value: psf mask_lr: 1e-3 L1_strength: 1.0 #False or float - use_mask_in_dataset : False # Work only with simulated dataset - -# Train Dataset -files: - dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - n_files: null # null to use all - -target: "object_plane" # "original" or "object_plane" or "label" - -#for simulated dataset -simulation: - grayscale: False - # random variations - object_height: 0.04 # range for random height or scalar - flip: True # change the orientation of the object (from vertical to horizontal) - random_shift: False - random_vflip: 0.5 - random_hflip: 0.5 - random_rotate: False - # these distance parameters are typically fixed for a given PSF - # for DiffuserCam psf # for tape_rgb psf - scene2mask: 10e-2 # scene2mask: 40e-2 - mask2sensor: 9e-3 # mask2sensor: 4e-3 - # see waveprop.devices - sensor: "rpi_hq" - snr_db: 10 - # simulate different sensor resolution - # output_dim: [24, 32] # [H, W] or null - # Downsampling for PSF - downsample: 8 - # max val in simulated measured (quantized 8 bits) - max_val: 255 #Training - training: - batch_size: 8 - epoch: 10 - #In case of instable training - skip_NAN: True - slow_start: False #float how much to reduce lr for first epoch - + save_every: 5 -optimizer: - type: Adam - lr: 1e-4 - -loss: 'l2' -# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) -lpips: 1.0 \ No newline at end of file +display: + gamma: 2.2 diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml new file mode 100644 index 00000000..f4d6ba98 --- /dev/null +++ b/configs/train_pre-post-processing.yaml @@ -0,0 +1,24 @@ +# python scripts/recon/train_unrolled.py -cn train_pre-post-processing +defaults: + - train_unrolledADMM + - _self_ + +display: + disp: 400 + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: DruNet + depth: 4 + +training: + epoch: 50 + slow_start: 0.01 + +loss: l2 +lpips: 1.0 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml new file mode 100644 index 00000000..b4eef0ed --- /dev/null +++ b/configs/train_psf_from_scratch.yaml @@ -0,0 +1,18 @@ +# python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". 
Measure :"DiffuserCam" + celeba_root: /scratch/bezzam + downsample: 8 + +#Trainable Mask +trainable_mask: + mask_type: TrainablePSF #Null or "TrainablePSF" + initial_value: "random" + +simulation: + grayscale: False diff --git a/configs/unrolled_recon.yaml b/configs/train_unrolledADMM.yaml similarity index 74% rename from configs/unrolled_recon.yaml rename to configs/train_unrolledADMM.yaml index 2673f20c..3871be0d 100644 --- a/configs/unrolled_recon.yaml +++ b/configs/train_unrolledADMM.yaml @@ -1,29 +1,20 @@ +# python scripts/recon/train_unrolled.py hydra: job: chdir: True # change to output folder -#Reconstruction algorithm -input: - # File path for recorded PSF - psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff - dtype: float32 +# Dataset +files: + dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + psf: data/psf.tiff + diffusercam_psf: True + n_files: null # null to use all for both train/test + downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution torch: True torch_device: 'cuda' -preprocess: - # Image shape (height, width) for reconstruction. - shape: null - # Whether image is raw bayer data. - bayer: False - blue_gain: null - red_gain: null - # Same PSF for all channels (sum) or unique PSF for RGB. - single_psf: False - # Whether to perform construction in grayscale. - gray: False - - display: # How many iterations to wait for intermediate plot. # Set to negative value for no intermediate plots. @@ -64,15 +55,11 @@ reconstruction: #Trainable Mask trainable_mask: mask_type: Null #Null or "TrainablePSF" - initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray" + # "random" (with shape of config.files.psf) or "psf" (using config.files.psf) + initial_value: psf + grayscale: False mask_lr: 1e-3 L1_strength: 1.0 #False or float - use_mask_in_dataset : True # Work only with simulated dataset - -# Train Dataset -files: - dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - n_files: null # null to use all target: "object_plane" # "original" or "object_plane" or "label" @@ -98,13 +85,16 @@ simulation: # Downsampling for PSF downsample: 8 # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability max_val: 255 #Training training: batch_size: 8 - epoch: 10 + epoch: 50 + metric_for_best_model: null # e.g. 
LPIPS_Vgg, null does test loss + save_every: null #In case of instable training skip_NAN: True slow_start: False #float how much to reduce lr for first epoch diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index f93b754d..885766f3 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -93,7 +93,20 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): if metric == "ReconstructionError": metrics_values[metric] += model.reconstruction_error().cpu().item() else: - metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + if "LPIPS" in metric: + if prediction.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric] += ( + metrics[metric]( + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() model.reset() diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 593c2360..9bc70bc8 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -3,13 +3,15 @@ # ================== # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import abc import torch +from lensless.utils.image import is_grayscale -class TrainableMask(metaclass=abc.ABCMeta): +class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): """ Abstract class for defining trainable masks. @@ -33,6 +35,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): lr : float, optional Learning rate for the mask parameters, by default 1e-3 """ + super().__init__() self._mask = torch.nn.Parameter(initial_mask) self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) self._counter = 0 @@ -68,18 +71,31 @@ class TrainablePSF(TrainableMask): Parameters ---------- - is_rgb : bool, optional - Whether the mask is RGB or not, by default True. + grayscale : bool, optional + Whether mask should be returned as grayscale when calling :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`. + Otherwise PSF will be returned as RGB. By default False. 
""" - def __init__(self, is_rgb=True, **kwargs): - super().__init__(**kwargs) - self._is_rgb = is_rgb + def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs): + super().__init__(initial_mask, optimizer, lr, **kwargs) + assert ( + len(initial_mask.shape) == 4 + ), "Mask must be of shape (depth, height, width, channels)" + self.grayscale = grayscale + self._is_grayscale = is_grayscale(initial_mask) + if grayscale: + assert self._is_grayscale, "Mask must be grayscale" def get_psf(self): - if self._is_rgb: - return self._mask.expand(-1, -1, -1, 3) + if self._is_grayscale: + if self.grayscale: + # simulation in grayscale + return self._mask + else: + # replicate to 3 channels + return self._mask.expand(-1, -1, -1, 3) else: + # assume RGB return self._mask def project(self): diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index e554f6b0..82fd883d 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -5,7 +5,10 @@ # Yohann PERRON [yohann.perron@gmail.com] # ############################################################################# +import pathlib as plib +from matplotlib import pyplot as plt from lensless.recon.recon import ReconstructionAlgorithm +from lensless.utils.plot import plot_image try: import torch @@ -153,7 +156,15 @@ def batch_call(self, batch): return image_est def apply( - self, disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True + self, + disp_iter=10, + plot_pause=0.2, + plot=True, + save=False, + gamma=None, + ax=None, + reset=True, + output_intermediate=False, ): """ Method for performing iterative reconstruction. Contrary to non-trainable reconstruction @@ -178,6 +189,8 @@ def apply( Gamma correction factor to apply for plots. Default is None. ax : :py:class:`~matplotlib.axes.Axes`, optional `Axes` object to fill for plotting/saving, default is to create one. + output_intermediate : bool, optional + Whether to output intermediate reconstructions after preprocessing and before postprocessing. Returns ------- @@ -188,8 +201,11 @@ def apply( returning if `plot` or `save` is True. 
""" + pre_processed_image = None if self.pre_process is not None: self._data = self.pre_process(self._data, self.pre_process_param) + if output_intermediate: + pre_processed_image = self._data[0, ...].clone() im = super(TrainableReconstructionAlgorithm, self).apply( n_iter=self._n_iter, @@ -201,6 +217,30 @@ def apply( ax=ax, reset=reset, ) + + # remove plot if returned + if plot: + im, _ = im + + # post process data + pre_post_process_image = None if self.post_process is not None: + # apply post process + if output_intermediate: + pre_post_process_image = im.clone() im = self.post_process(im, self.post_process_param) - return im + + if plot: + ax = plot_image(self._get_numpy_data(im[0]), ax=ax, gamma=gamma) + ax.set_title( + "Final reconstruction after {} iterations and post process".format(self._n_iter) + ) + if save: + plt.savefig(plib.Path(save) / "final.png") + + if output_intermediate: + return im, pre_processed_image, pre_post_process_image + elif plot: + return im, ax + else: + return im diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 5f091c53..2409dd80 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -9,25 +9,28 @@ import json import math +import numpy as np +import matplotlib.pyplot as plt import time from hydra.utils import get_original_cwd import os -import matplotlib.pyplot as plt import torch from lensless.eval.benchmark import benchmark from lensless.hardware.trainable_mask import TrainableMask from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image -def load_drunet(model_path, n_channels=3, requires_grad=False): +def load_drunet(model_path=None, n_channels=3, requires_grad=False): """ Load a pre-trained Drunet model. Parameters ---------- - model_path : str - Path to pre-trained model. + model_path : str, optional + Path to pre-trained model. Download if not provided. n_channels : int Number of channels in input image. requires_grad : bool @@ -39,6 +42,25 @@ def load_drunet(model_path, n_channels=3, requires_grad=False): Loaded model. """ + if model_path is None: + model_path = os.path.join(get_original_cwd(), "models", "drunet_color.pth") + if not os.path.exists(model_path): + try: + from torchvision.datasets.utils import download_url + except ImportError: + exit() + msg = "Do you want to download the pretrained DRUNet model (130MB)?" 
+
+            # default to yes if no input is given
+            valid = input("%s (Y/n) " % msg).lower() != "n"
+            output_path = os.path.join(get_original_cwd(), "models")
+            if valid:
+                url = "https://drive.switch.ch/index.php/s/jTdeMHom025RFRQ/download"
+                filename = "drunet_color.pth"
+                download_url(url, output_path, filename=filename)
+
+    assert os.path.exists(model_path), f"Model path {model_path} does not exist"
+
     model = UNetRes(
         in_nc=n_channels + 1,
         out_nc=n_channels,
@@ -192,9 +214,7 @@ def create_process_network(network, depth, device="cpu"):
     if network == "DruNet":
         from lensless.recon.utils import load_drunet

-        process = load_drunet(
-            os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
-        ).to(device)
+        process = load_drunet(requires_grad=True).to(device)
         process_name = "DruNet"
     elif network == "UnetRes":
         from lensless.recon.drunet.network_unet import UNetRes
@@ -223,6 +243,7 @@ def __init__(
         recon,
         train_dataset,
         test_dataset,
+        test_size=0.15,
         mask=None,
         batch_size=4,
         loss="l2",
@@ -233,10 +254,19 @@ def __init__(
         slow_start=None,
         skip_NAN=False,
         algorithm_name="Unknown",
+        metric_for_best_model=None,
+        save_every=None,
+        gamma=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__.

+        The train and test metrics at the end of each epoch can be found in ``self.metrics``,
+        with "LOSS" being the train loss. The test loss can be found in "MSE" (if loss is "l2") or
+        "MAE" (if loss is "l1"). If ``lpips`` is not None, the LPIPS loss is also added
+        to the train loss, such that the test loss can be computed as "MSE" + ``lpips`` * "LPIPS_Vgg"
+        (or "MAE" + ``lpips`` * "LPIPS_Vgg").
+
         Parameters
         ----------
         recon : :py:class:`lensless.TrainableReconstructionAlgorithm`
             Reconstruction algorithm to train.
         train_dataset : :py:class:`torch.utils.data.Dataset`
             Dataset to use for training.
         test_dataset : :py:class:`torch.utils.data.Dataset`
             Dataset to use for testing.
+        test_size : float, optional
+            If test_dataset is None, fraction of the train dataset to use for testing, by default 0.15.
         mask : TrainableMask, optional
             Trainable mask to use for training. If None, training with a fixed PSF, by default None.
         batch_size : int, optional
-            Batch size to use for training, by default 4
+            Batch size to use for training, by default 4.
         loss : str, optional
-            Loss function to use for training "l1" or "l2", by default "l2"
+            Loss function to use for training "l1" or "l2", by default "l2".
         lpips : float, optional
-            the weight of the lpips(VGG) in the total loss. If None ignore. By default None
+            the weight of the LPIPS (VGG) term in the total loss. If None, ignore. By default None.
         l1_mask : float, optional
-            the weight of the l1 norm of the mask in the total loss. If None ignore. By default None
+            the weight of the l1 norm of the mask in the total loss. If None, ignore. By default None.
         optimizer : str, optional
-            Optimizer to use durring training. Available : "Adam". By default "Adam"
+            Optimizer to use during training. Available : "Adam". By default "Adam".
         optimizer_lr : float, optional
-            Learning rate for the optimizer, by default 1e-6
+            Learning rate for the optimizer, by default 1e-6.
         slow_start : float, optional
             Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None.
         skip_NAN : bool, optional
             Whether to skip the update if any gradient is NaN (True) or to throw an error (False), by default False
         algorithm_name : str, optional
             Algorithm name for logging, by default "Unknown".
+        metric_for_best_model : str, optional
+            Metric to use for saving the best model.
If None, will default to evaluation loss. Default is None. + save_every : int, optional + Save model every ``save_every`` epochs. If None, just save best model. + gamma : float, optional + Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None. + """ self.device = recon._psf.device self.recon = recon + assert train_dataset is not None if test_dataset is None: + assert test_size < 1.0 and test_size > 0.0 # split train dataset - train_size = int(0.9 * len(train_dataset)) + train_size = int((1 - test_size) * len(train_dataset)) test_size = len(train_dataset) - train_size train_dataset, test_dataset = torch.utils.data.random_split( train_dataset, [train_size, test_size] ) + print(f"Train size : {train_size}, Test size : {test_size}") self.train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, @@ -297,6 +339,7 @@ def __init__( self.use_mask = False self.l1_mask = l1_mask + self.gamma = gamma # loss if loss == "l2": @@ -345,7 +388,7 @@ def learning_rate_function(epoch): ) self.metrics = { - "LOSS": [], + "LOSS": [], # train loss "MSE": [], "MAE": [], "LPIPS_Vgg": [], @@ -355,7 +398,15 @@ def learning_rate_function(epoch): "ReconstructionError": [], "n_iter": self.recon._n_iter, "algorithm": algorithm_name, + "metric_for_best_model": metric_for_best_model, + "best_epoch": 0, + "best_eval_score": 0 + if metric_for_best_model == "PSNR" or metric_for_best_model == "SSIM" + else np.inf, } + if metric_for_best_model is not None: + assert metric_for_best_model in self.metrics.keys() + self.save_every = save_every # Backward hook that detect NAN in the gradient and print the layer weights if not self.skip_NAN: @@ -430,6 +481,12 @@ def train_epoch(self, data_loader, disp=-1): loss_v = self.Loss(y_pred, y) if self.lpips: + + if y_pred.shape[1] == 1: + # if only one channel, repeat for LPIPS + y_pred = y_pred.repeat(1, 3, 1, 1) + y = y.repeat(1, 3, 1, 1) + # value for LPIPS needs to be in range [-1, 1] loss_v = loss_v + self.lpips * torch.mean( self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) @@ -489,7 +546,18 @@ def evaluate(self, mean_loss, save_pt): with open(os.path.join(save_pt, "metrics.json"), "w") as f: json.dump(self.metrics, f) - def on_epoch_end(self, mean_loss, save_pt): + # check best metric + if self.metrics["metric_for_best_model"] is None: + eval_loss = current_metrics["MSE"] + if self.lpips is not None: + eval_loss += self.lpips * current_metrics["LPIPS_Vgg"] + if self.use_mask and self.l1_mask: + eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy())) + return eval_loss + else: + return current_metrics[self.metrics["metric_for_best_model"]] + + def on_epoch_end(self, mean_loss, save_pt, epoch): """ Called at the end of each epoch. @@ -499,14 +567,35 @@ def on_epoch_end(self, mean_loss, save_pt): Mean loss of the last epoch. save_pt : str Path to save metrics dictionary to. If None, no logging of metrics. + epoch : int + Current epoch. 
""" if save_pt is None: # Use current directory save_pt = os.getcwd() # save model - self.save(path=save_pt, include_optimizer=False) - self.evaluate(mean_loss, save_pt) + # self.save(path=save_pt, include_optimizer=False) + epoch_eval_metric = self.evaluate(mean_loss, save_pt) + new_best = False + if ( + self.metrics["metric_for_best_model"] == "PSNR" + or self.metrics["metric_for_best_model"] == "SSIM" + ): + if epoch_eval_metric > self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + else: + if epoch_eval_metric < self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + + if new_best: + self.metrics["best_epoch"] = epoch + self.save(path=save_pt, include_optimizer=False, epoch="BEST") + + if self.save_every is not None and epoch % self.save_every == 0: + self.save(path=save_pt, include_optimizer=False, epoch=epoch) def train(self, n_epoch=1, save_pt=None, disp=-1): """ @@ -528,26 +617,31 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): for epoch in range(n_epoch): print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) - self.on_epoch_end(mean_loss, save_pt) + # offset because of evaluate before loop + self.on_epoch_end(mean_loss, save_pt, epoch + 1) self.scheduler.step() print(f"Train time : {time.time() - start_time} s") - def save(self, path="recon", include_optimizer=False): + def save(self, epoch, path="recon", include_optimizer=False): # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) # save mask if self.use_mask: - torch.save(self.mask._mask, os.path.join(path, "mask.pt")) - torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) - import matplotlib.pyplot as plt - - plt.imsave( - os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] + torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) + torch.save( + self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt") ) + + psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] + psf_np = psf_np.squeeze() # remove (potential) singleton color channel + save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) + plot_image(psf_np, gamma=self.gamma) + plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + # save optimizer if include_optimizer: - torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) + torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt")) # save recon - torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt")) + torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 67aa7a8d..a5a2e8a9 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -26,7 +26,9 @@ class DualDataset(Dataset): def __init__( self, indices=None, + # psf_path=None, background=None, + # background_pix=(0, 15), downsample=1, flip=False, transform_lensless=None, @@ -38,18 +40,22 @@ def __init__( Parameters ---------- - indices : range or int or None - Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. - background : :py:class:`~torch.Tensor` or None, optional - If not ``None``, background is removed from lensless images, by default ``None``. 
-        downsample : int, optional
-            Downsample factor of the lensless images, by default 1.
-        flip : bool, optional
-            If ``True``, lensless images are flipped, by default ``False``.
-        transform_lensless : PyTorch Transform or None, optional
-            Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
-        transform_lensed : PyTorch Transform or None, optional
-            Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
+        indices : range or int or None
+            Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None.
+        psf_path : str
+            Path to the PSF of the imaging system, by default None.
+        background : :py:class:`~torch.Tensor` or None, optional
+            If not ``None``, background is removed from lensless images, by default ``None``. If PSF is provided, background is estimated from the PSF.
+        background_pix : tuple, optional
+            Pixels to use for background estimation, by default (0, 15).
+        downsample : int, optional
+            Downsample factor of the lensless images, by default 1.
+        flip : bool, optional
+            If ``True``, lensless images are flipped, by default ``False``.
+        transform_lensless : PyTorch Transform or None, optional
+            Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
+        transform_lensed : PyTorch Transform or None, optional
+            Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
         """
         if isinstance(indices, int):
             indices = range(indices)
@@ -60,6 +66,21 @@ def __init__(
         self.transform_lensless = transform_lensless
         self.transform_lensed = transform_lensed

+        # self.psf = None
+        # if psf_path is not None:
+        #     psf, background = load_psf(
+        #         psf_path,
+        #         downsample=downsample,
+        #         return_float=True,
+        #         return_bg=True,
+        #         bg_pix=background_pix,
+        #     )
+        #     if self.background is None:
+        #         self.background = background
+        #     self.psf = torch.from_numpy(psf)
+        #     if self.transform_lensless is not None:
+        #         self.psf = self.transform_lensless(self.psf)
+
     @abstractmethod
     def __len__(self):
         """
@@ -151,7 +172,7 @@ def __init__(
         dataset_is_CHW : bool, optional
             If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``.
         flip : bool, optional
-            If True, images are flipped beffore the simulation, by default ``False``..
+            If True, images are flipped before the simulation, by default ``False``.
""" # we do the flipping before the simualtion @@ -171,6 +192,10 @@ def __init__( assert simulator.fft_shape is not None, "Simulator should have a psf" self.sim = simulator + @property + def psf(self): + return self.sim.get_psf() + def get_image(self, index): return self.dataset[index] @@ -185,7 +210,14 @@ def _get_images_pair(self, index): if self._pre_transform is not None: img = self._pre_transform(img) - lensless, lensed = self.sim.propagate(img, return_object_plane=True) + lensless, lensed = self.sim.propagate_image(img, return_object_plane=True) + + if lensed.shape[-1] == 1 and lensless.shape[-1] == 3: + # copy to 3 channels + lensed = lensed.repeat(1, 1, 3) + assert ( + lensed.shape[-1] == lensless.shape[-1] + ), "Lensed and lensless should have same number of channels" return lensless, lensed @@ -240,6 +272,9 @@ def __init__( self.root_dir = root_dir self.lensless_dir = os.path.join(root_dir, lensless_fn) self.original_dir = os.path.join(root_dir, original_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.original_dir) + self.image_ext = image_ext.lower() self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower() @@ -295,7 +330,7 @@ def _get_images_pair(self, idx): # project original image to lensed space with torch.no_grad(): - lensed = self.sim.propagate() + lensed = self.sim.propagate_image() return lensless, lensed @@ -336,6 +371,9 @@ def __init__( self.root_dir = root_dir self.lensless_dir = os.path.join(root_dir, lensless_fn) self.lensed_dir = os.path.join(root_dir, lensed_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.lensed_dir) + self.image_ext = image_ext.lower() files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) @@ -359,6 +397,7 @@ def _get_images_pair(self, idx): lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) lensless = np.load(lensless_fp) lensed = np.load(lensed_fp) + else: # more standard image formats: png, jpg, tiff, etc. lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) @@ -378,6 +417,59 @@ def _get_images_pair(self, idx): return lensless, lensed +class DiffuserCamMirflickr(MeasuredDataset): + """ + Helper class for DiffuserCam Mirflickr dataset. 
+ + Note that image colors are in BGR format: https://github.com/Waller-Lab/LenslessLearning/blob/master/utils.py#L432 + """ + + def __init__( + self, + dataset_dir, + psf_path, + downsample=2, + **kwargs, + ): + + psf, background = load_psf( + psf_path, + downsample=downsample * 4, # PSF is 4x the resolution of the images + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + self.allowed_idx = np.arange(2, 25001) + + super().__init__( + root_dir=dataset_dir, + background=background, + downsample=downsample, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser_images", + lensed_fn="ground_truth_lensed", + image_ext="npy", + **kwargs, + ) + + def _get_images_pair(self, idx): + + assert idx >= self.allowed_idx.min(), f"idx should be >= {self.allowed_idx.min()}" + assert idx <= self.allowed_idx.max(), f"idx should be <= {self.allowed_idx.max()}" + + fn = f"im{idx}.npy" + lensless_fp = os.path.join(self.lensless_dir, fn) + lensed_fp = os.path.join(self.lensed_dir, fn) + lensless = np.load(lensless_fp) + lensed = np.load(lensed_fp) + + return lensless, lensed + + class DiffuserCamTestDataset(MeasuredDataset): """ Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. @@ -385,8 +477,8 @@ class DiffuserCamTestDataset(MeasuredDataset): def __init__( self, - data_dir="data", - n_files=200, + data_dir=None, + n_files=None, downsample=2, ): """ @@ -396,17 +488,20 @@ def __init__( Parameters ---------- data_dir : str, optional - The path to the folder containing the DiffuserCam_Test dataset, by default "data". + The path to ``DiffuserCam_Test`` dataset, by default looks inside the ``data`` folder. n_files : int, optional - Number of image pairs to load in the dataset , by default 200. + Number of image pairs to load in the dataset , by default use all. downsample : int, optional - Downsample factor of the lensless images, by default 8. + Downsample factor of the lensless images, by default 2. Note that the PSF has a resolution of 4x of the images. 
""" # download dataset if necessary - main_dir = data_dir - data_dir = os.path.join(data_dir, "DiffuserCam_Test") + if data_dir is None: + data_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "data", "DiffuserCam_Test" + ) if not os.path.isdir(data_dir): + main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data") print("No dataset found for benchmarking.") try: from torchvision.datasets.utils import download_and_extract_archive @@ -424,7 +519,7 @@ def __init__( psf_fp = os.path.join(data_dir, "psf.tiff") psf, background = load_psf( psf_fp, - downsample=downsample, + downsample=downsample * 4, # PSF is 4x the resolution of the images return_float=True, return_bg=True, bg_pix=(0, 15), @@ -435,11 +530,16 @@ def __init__( self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + if n_files is None: + indices = None + else: + indices = range(n_files) + super().__init__( root_dir=data_dir, - indices=range(n_files), + indices=indices, background=background, - downsample=downsample / 4, + downsample=downsample, flip=False, transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, @@ -492,7 +592,7 @@ def __init__( def _get_images_pair(self, index): # update psf psf = self._mask.get_psf() - self.sim.set_psf(psf) + self.sim.set_point_spread_function(psf) # return simulated images return super()._get_images_pair(index) diff --git a/lensless/utils/image.py b/lensless/utils/image.py index f3bbe28f..748aaf50 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -77,6 +77,23 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC): return np.clip(resized, min_val, max_val) +def is_grayscale(img): + """ + Check if image is RGB. Assuming image is of shape ([depth,] height, width, color). + + Parameters + ---------- + img : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Image array. + + Returns + ------- + bool + Whether image is RGB. + """ + return img.shape[-1] == 1 + + def rgb2gray(rgb, weights=None, keepchanneldim=True): """ Convert RGB array to grayscale. diff --git a/lensless/utils/io.py b/lensless/utils/io.py index f502719a..1b2b234f 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -34,6 +34,7 @@ def load_image( return_float=False, shape=None, dtype=None, + normalize=True, ): """ Load image as numpy array. @@ -73,6 +74,8 @@ def load_image( Shape (H, W, C) to resize to. dtype : str, optional Data type of returned data. Default is to use that of input. + normalize : bool, default True + If ``return_float``, whether to normalize data to maximum value of 1. Returns ------- @@ -136,7 +139,7 @@ def load_image( if bg is not None: # if bg is float vector, turn into int-valued vector - if bg.max() <= 1: + if bg.max() <= 1 and img.dtype not in [np.float32, np.float64]: bg = bg * get_max_val(img) img = img - bg @@ -160,7 +163,8 @@ def load_image( dtype = np.float32 assert dtype == np.float32 or dtype == np.float64 img = img.astype(dtype) - img /= img.max() + if normalize: + img /= img.max() else: if dtype is None: @@ -336,6 +340,7 @@ def load_psf( def load_data( psf_fp, data_fp, + return_float=True, downsample=None, bg_pix=(5, 25), plot=True, @@ -350,6 +355,7 @@ def load_data( shape=None, torch=False, torch_device="cpu", + normalize=False, ): """ Load data for image reconstruction. @@ -360,6 +366,8 @@ def load_data( Full path to PSF file. data_fp : str Full path to measurement file. + return_float : bool, optional + Whether to return PSF as float array, or unsigned int. downsample : int or float Downsampling factor. 
bg_pix : tuple, optional @@ -386,6 +394,8 @@ def load_data( Whether to sum RGB channels into single PSF, same across channels. Done in "Learned reconstructions for practical mask-based lensless imaging" of Kristina Monakhova et. al. + normalize : bool default True + Whether to normalize data to maximum value of 1. Returns ------- @@ -415,7 +425,7 @@ def load_data( psf, bg = load_psf( psf_fp, downsample=downsample, - return_float=True, + return_float=return_float, bg_pix=bg_pix, return_bg=True, flip=flip, @@ -437,8 +447,9 @@ def load_data( red_gain=red_gain, bg=bg, as_4d=True, - return_float=True, + return_float=return_float, shape=shape, + normalize=normalize, ) if data.shape != psf.shape: diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py index e7f7af3a..b77fabcb 100644 --- a/lensless/utils/simulation.py +++ b/lensless/utils/simulation.py @@ -6,8 +6,8 @@ # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# -import numpy as np from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp +import torch class FarFieldSimulator(FarFieldSimulator_wp): @@ -34,7 +34,7 @@ def __init__( """ Parameters ---------- - psf : np.ndarray, optional. + psf : np.ndarray or torch.Tensor, optional. Point spread function. If not provided, return image at object plane. object_height : float or (float, float) Height of object in meters. Or range of values to randomly sample from. @@ -58,9 +58,15 @@ def __init__( Whether to quantize image, by default True. """ - if psf is not None: - # convert HWC to CHW - psf = psf.squeeze().movedim(-1, 0) + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # drop depth dimension, and convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" super().__init__( object_height, @@ -78,6 +84,13 @@ def __init__( **kwargs ) + if self.is_torch: + assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + assert ( + self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 + ), "PSF must have 1 or 3 channels" + # save all the parameters in a dict self.params = { "object_height": object_height, @@ -94,7 +107,15 @@ def __init__( } self.params.update(kwargs) - def set_psf(self, psf): + def get_psf(self): + if self.is_torch: + # convert CHW to HWC + return self.psf.movedim(0, -1).unsqueeze(0) + else: + return self.psf[None, ...] + + # needs different name from parent class + def set_point_spread_function(self, psf): """ Set point spread function. @@ -103,19 +124,32 @@ def set_psf(self, psf): psf : np.ndarray or torch.Tensor Point spread function. """ - psf = psf.squeeze().movedim(-1, 0) + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" + return super().set_psf(psf) - def propagate(self, obj, return_object_plane=False): + def propagate_image(self, obj, return_object_plane=False): """ Parameters ---------- obj : np.ndarray or torch.Tensor - Single image to propagate at format HWC. + Single image to propagate of format HWC. 
return_object_plane : bool, optional Whether to return object plane, by default False. """ + + assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels" + if self.is_torch: + # channel in first dimension as expected by waveprop for pytorch obj = obj.moveaxis(-1, 0) res = super().propagate(obj, return_object_plane) if isinstance(res, tuple): @@ -124,10 +158,6 @@ def propagate(self, obj, return_object_plane=False): res = res.moveaxis(-3, -1) return res else: - obj = np.moveaxis(obj, -1, 0) + # TODO: not tested, but normally don't need to move dimensions for numpy res = super().propagate(obj, return_object_plane) - if isinstance(res, tuple): - res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1) - else: - res = np.moveaxis(res, -3, -1) return res diff --git a/mask_requirements.txt b/mask_requirements.txt index 699ba552..9e9c28a4 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,3 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.7 \ No newline at end of file +waveprop>=0.0.8 \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index 33e12092..0b90adf2 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -3,9 +3,10 @@ lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 click>=8.0.1 -waveprop>=0.0.7 # for simulation +waveprop>=0.0.8 # for simulation # Library for learning algorithm torch >= 2.0.0 torchvision +torchmetrics lpips \ No newline at end of file diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 6611ceec..89a31309 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -44,9 +44,7 @@ def benchmark_recon(config): device = "cpu" # Benchmark dataset - benchmark_dataset = DiffuserCamTestDataset( - data_dir=os.path.join(get_original_cwd(), "data"), n_files=n_files, downsample=downsample - ) + benchmark_dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) psf = benchmark_dataset.psf.to(device) model_list = [] # list of algoritms to benchmark diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 2a053722..c84d5b92 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -13,8 +13,9 @@ import pathlib as plib import matplotlib.pyplot as plt import numpy as np -from lensless.utils.io import load_data +from lensless.utils.io import load_data, load_image from lensless import ADMM +from lensless.utils.plot import plot_image @hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") @@ -42,6 +43,7 @@ def admm(config): torch=config.torch, torch_device=config.torch_device, bg_pix=config.preprocess.bg_pix, + normalize=config.preprocess.normalize, ) disp = config["display"]["disp"] @@ -52,6 +54,15 @@ def admm(config): if save: save = os.getcwd() + if save: + if config.torch: + org_data = data.cpu().numpy() + else: + org_data = data + ax = plot_image(org_data, gamma=config["display"]["gamma"]) + ax.set_title("Original measurement") + plt.savefig(plib.Path(save) / "lensless.png") + start_time = time.time() if not config.admm.unrolled: recon = ADMM(psf, **config.admm) @@ -60,14 +71,14 @@ def admm(config): from lensless.recon.unrolled_admm import UnrolledADMM import lensless.recon.utils - pre_process = lensless.recon.utils.create_process_network( + pre_process, _ = lensless.recon.utils.create_process_network( network=config.admm.pre_process_model.network, - 
depth=config.admm.pre_process_depth.depth, + depth=config.admm.pre_process_model.depth, device=config.torch_device, ) - post_process = lensless.recon.utils.create_process_network( + post_process, _ = lensless.recon.utils.create_process_network( network=config.admm.post_process_model.network, - depth=config.admm.post_process_depth.depth, + depth=config.admm.post_process_model.depth, device=config.torch_device, ) @@ -76,18 +87,28 @@ def admm(config): print("Loading checkpoint from : ", path) assert os.path.exists(path), "Checkpoint does not exist" recon.load_state_dict(torch.load(path, map_location=config.torch_device)) + recon.set_data(data) print(f"Setup time : {time.time() - start_time} s") start_time = time.time() if config.torch: with torch.no_grad(): - res = recon.apply( - disp_iter=disp, - save=save, - gamma=config["display"]["gamma"], - plot=config["display"]["plot"], - ) + if config.admm.unrolled: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + else: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) else: res = recon.apply( disp_iter=disp, @@ -105,7 +126,78 @@ def admm(config): if config["display"]["plot"]: plt.show() if save: + + if config.admm.unrolled: + # Save intermediate results + if res[1] is not None: + pre_processed_image = res[1].cpu().numpy() + ax = plot_image(pre_processed_image, gamma=config["display"]["gamma"]) + ax.set_title("Image after preprocessing") + plt.savefig(plib.Path(save) / "pre_processed.png") + + if res[2] is not None: + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + ax.set_title("Image prior to post-processing") + plt.savefig(plib.Path(save) / "pre_post_process.png") + np.save(plib.Path(save) / "final_reconstruction.npy", img) + + if config.input.original is not None: + original = load_image( + to_absolute_path(config.input.original), + flip=config["preprocess"]["flip"], + red_gain=config["preprocess"]["red_gain"], + blue_gain=config["preprocess"]["blue_gain"], + shape=img.shape[-3:], + ) + ax = plot_image(original, gamma=config["display"]["gamma"]) + ax.set_title("Ground truth image") + plt.savefig(plib.Path(save) / "original.png") + + # compute metrics + from torchmetrics.image import lpip, psnr + + lpips_func = lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True) + psnr_funct = psnr.PeakSignalNoiseRatio() + + img_torch = torch.from_numpy(img).squeeze(0) + original_torch = torch.from_numpy(original).unsqueeze(0) + + # channel as first dimension + img_torch = img_torch.movedim(-1, -3) + original_torch = original_torch.movedim(-1, -3) + + # normalize, TODO img max value is 14 which seems strange + img_torch = img_torch / torch.amax(img_torch) + + # compute metrics + lpips = lpips_func(img_torch, original_torch) + psnr = psnr_funct(img_torch, original_torch) + print(f"LPIPS : {lpips}") + print(f"PSNR : {psnr}") + + # If the recon algorithm is unrolled and has a preprocessing step, plot result without preprocessing + if config.admm.unrolled and recon.pre_process is not None: + recon.set_data(data) + recon.pre_process = None + with torch.no_grad(): + res = recon.apply( + disp_iter=disp, + save=False, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + + img = res[0].cpu().numpy() + np.save(plib.Path(save) / 
"final_reconstruction_no_preprocessing.npy", img[0]) + ax = plot_image(img, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "final_reconstruction_no_preprocessing.png") + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "pre_post_process_no_preprocessing.png") + print(f"Files saved to : {save}") diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index c669ea2e..5cbee7bf 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# """ @@ -12,13 +13,26 @@ python scripts/recon/train_unrolled.py ``` +By default it uses the configuration from the file `configs/train_unrolledADMM.yaml`. + +To train pre- and post-processing networks, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_pre-post-processing +``` + To fine-tune the DiffuserCam PSF, use the following command: ``` python scripts/recon/train_unrolled.py -cn fine-tune_PSF ``` +To train a PSF from scratch with a simulated dataset, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +``` + """ +import logging import hydra from hydra.utils import get_original_cwd import os @@ -26,20 +40,50 @@ import time from lensless import UnrolledFISTA, UnrolledADMM from lensless.utils.dataset import ( - DiffuserCamTestDataset, + DiffuserCamMirflickr, SimulatedFarFieldDataset, SimulatedDatasetTrainableMask, ) +from torch.utils.data import Subset import lensless.hardware.trainable_mask from lensless.recon.utils import create_process_network -from lensless.utils.image import rgb2gray +from lensless.utils.image import rgb2gray, is_grayscale from lensless.utils.simulation import FarFieldSimulator from lensless.recon.utils import Trainer import torch from torchvision import transforms, datasets +from lensless.utils.io import load_psf +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image +import matplotlib.pyplot as plt + +# A logger for this file +log = logging.getLogger(__name__) + +def simulate_dataset(config): + + if config.torch_device == "cuda" and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # prepare PSF + psf_fp = os.path.join(get_original_cwd(), config.files.psf) + psf, _ = load_psf( + psf_fp, + downsample=config.files.downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + if config.files.diffusercam_psf: + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + psf = transform_BRG2RGB(torch.from_numpy(psf)) + + # drop depth dimension + psf = psf.to(device) -def simulate_dataset(config, psf, mask=None): # load dataset transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") @@ -47,26 +91,38 @@ def simulate_dataset(config, psf, mask=None): transforms_list.append(transforms.Grayscale()) transform = transforms.Compose(transforms_list) if config.files.dataset == "mnist": - ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == 
"fashion_mnist": - ds = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.FashionMNIST( + root=data_path, train=True, download=True, transform=transform + ) + test_ds = datasets.FashionMNIST( + root=data_path, train=False, download=True, transform=transform + ) elif config.files.dataset == "cifar10": - ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "CelebA": - ds = datasets.CelebA(root=data_path, split="train", download=True, transform=transform) + root = config.files.celeba_root + data_path = os.path.join(root, "celeba") + assert os.path.isdir( + data_path + ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" + train_ds = datasets.CelebA(root=root, split="train", download=False, transform=transform) + test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) else: raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") # convert PSF - if config.simulation.grayscale: + if config.simulation.grayscale and not is_grayscale(psf): psf = rgb2gray(psf) - if not isinstance(psf, torch.Tensor): - psf = transforms.ToTensor()(psf) - n_files = config.files.n_files - device_conv = config.torch_device + # prepare mask + mask = prep_trainable_mask(config, psf, grayscale=config.simulation.grayscale) # check if gpu is available + device_conv = config.torch_device if device_conv == "cuda" and torch.cuda.is_available(): device_conv = "cuda" else: @@ -78,55 +134,74 @@ def simulate_dataset(config, psf, mask=None): is_torch=True, **config.simulation, ) + # create Pytorch dataset and dataloader + n_files = config.files.n_files if n_files is not None: - ds = torch.utils.data.Subset(ds, np.arange(n_files)) + train_ds = torch.utils.data.Subset(train_ds, np.arange(n_files)) + test_ds = torch.utils.data.Subset(test_ds, np.arange(n_files)) if mask is None: - ds_prop = SimulatedFarFieldDataset( - dataset=ds, + train_ds_prop = SimulatedFarFieldDataset( + dataset=train_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedFarFieldDataset( + dataset=test_ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv, flip=config.simulation.flip, ) else: - ds_prop = SimulatedDatasetTrainableMask( - dataset=ds, + train_ds_prop = SimulatedDatasetTrainableMask( + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedDatasetTrainableMask( + dataset=test_ds, mask=mask, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv, flip=config.simulation.flip, ) - return ds_prop + return train_ds_prop, test_ds_prop, mask -@hydra.main(version_base=None, config_path="../../configs", config_name="unrolled_recon") -def train_unrolled( - config, -): - if config.torch_device == "cuda" and torch.cuda.is_available(): - print("Using GPU for training.") - device = "cuda" - else: - print("Using CPU for training.") - device = "cpu" - # torch.autograd.set_detect_anomaly(True) +def prep_trainable_mask(config, psf, 
grayscale=False): + mask = None + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - # benchmarking dataset: - path = os.path.join(get_original_cwd(), "data") - benchmark_dataset = DiffuserCamTestDataset( - data_dir=path, downsample=config.simulation.downsample - ) + if config.trainable_mask.initial_value == "random": + initial_mask = torch.rand_like(psf) + elif config.trainable_mask.initial_value == "psf": + initial_mask = psf.clone() + else: + raise ValueError( + f"Initial PSF value {config.trainable_mask.initial_value} not supported" + ) + + if config.trainable_mask.grayscale and not is_grayscale(initial_mask): + initial_mask = rgb2gray(initial_mask) + + mask = mask_class( + initial_mask, optimizer="Adam", lr=config.trainable_mask.mask_lr, grayscale=grayscale + ) - diffusercam_psf = benchmark_dataset.psf.to(device) - background = benchmark_dataset.background + return mask - # convert psf from BGR to RGB - diffusercam_psf = diffusercam_psf[..., [2, 1, 0]] - # if using a portrait dataset rotate the PSF +@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") +def train_unrolled(config): disp = config.display.disp if disp < 0: @@ -136,6 +211,63 @@ def train_unrolled( if save: save = os.getcwd() + if config.torch_device == "cuda" and torch.cuda.is_available(): + print("Using GPU for training.") + device = "cuda" + else: + print("Using CPU for training.") + device = "cpu" + + # load dataset and create dataloader + train_set = None + test_set = None + psf = None + if "DiffuserCam" in config.files.dataset: + + original_path = os.path.join(get_original_cwd(), config.files.dataset) + psf_path = os.path.join(get_original_cwd(), config.files.psf) + dataset = DiffuserCamMirflickr( + dataset_dir=original_path, + psf_path=psf_path, + downsample=config.files.downsample, + ) + dataset.psf = dataset.psf.to(device) + # train-test split as in https://waller-lab.github.io/LenslessLearning/dataset.html + # first 1000 files for test, the rest for training + train_indices = dataset.allowed_idx[dataset.allowed_idx > 1000] + test_indices = dataset.allowed_idx[dataset.allowed_idx <= 1000] + if config.files.n_files is not None: + train_indices = train_indices[: config.files.n_files] + test_indices = test_indices[: config.files.n_files] + + train_set = Subset(dataset, train_indices) + test_set = Subset(dataset, test_indices) + + # -- if learning mask + mask = prep_trainable_mask(config, dataset.psf) + if mask is not None: + # plot initial PSF + psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] 
+ if config.trainable_mask.grayscale: + psf_np = psf_np[:, :, -1] + + save_image(psf_np, os.path.join(save, "psf_initial.png")) + plot_image(psf_np, gamma=config.display.gamma) + plt.savefig(os.path.join(save, "psf_initial_plot.png")) + + psf = dataset.psf + + else: + + train_set, test_set, mask = simulate_dataset(config) + psf = train_set.psf + + assert train_set is not None + assert psf is not None + + print("Train test size : ", len(train_set)) + print("Test test size : ", len(test_set)) + start_time = time.time() # Load pre process model @@ -150,10 +282,11 @@ def train_unrolled( config.reconstruction.post_process.depth, device=device, ) + # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( - diffusercam_psf, + psf, n_iter=config.reconstruction.unrolled_fista.n_iter, tk=config.reconstruction.unrolled_fista.tk, pad=True, @@ -163,7 +296,7 @@ def train_unrolled( ).to(device) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( - diffusercam_psf, + psf, n_iter=config.reconstruction.unrolled_admm.n_iter, mu1=config.reconstruction.unrolled_admm.mu1, mu2=config.reconstruction.unrolled_admm.mu2, @@ -183,67 +316,17 @@ def train_unrolled( algorithm_name += "_" + post_process_name # print number of parameters - print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters") - # transform from BGR to RGB - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - # create mask - if config.trainable_mask.mask_type is not None: - mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - if config.trainable_mask.initial_value == "random": - mask = mask_class( - torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr - ) - elif config.trainable_mask.initial_value == "DiffuserCam": - mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) - elif config.trainable_mask.initial_value == "DiffuserCam_gray": - mask = mask_class( - diffusercam_psf[:, :, :, 0, None], - optimizer="Adam", - lr=config.trainable_mask.mask_lr, - is_rgb=not config.simulation.grayscale, - ) - else: - mask = None - - # load dataset and create dataloader - if config.files.dataset == "DiffuserCam": - # Use a ParallelDataset - from lensless.utils.dataset import MeasuredDataset - - max_indices = 30000 - if config.files.n_files is not None: - max_indices = config.files.n_files + 1000 - - data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") - assert os.path.exists(data_path), "DiffuserCam dataset not found" - dataset = MeasuredDataset( - root_dir=data_path, - indices=range(1000, max_indices), - background=background, - psf=diffusercam_psf, - lensless_fn="diffuser_images", - lensed_fn="ground_truth_lensed", - downsample=config.simulation.downsample / 4, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - ) - else: - # Use a simulated dataset - if config.trainable_mask.use_mask_in_dataset: - dataset = simulate_dataset(config, diffusercam_psf, mask=mask) - # the mask use will differ from the one in the benchmark dataset - print("Trainable Mask will be used in the test dataset") - benchmark_dataset = None - else: - dataset = simulate_dataset(config, diffusercam_psf, mask=None) + n_param = sum(p.numel() for p in recon.parameters()) + if mask is not None: + n_param += sum(p.numel() for p in mask.parameters()) + log.info(f"Training model with {n_param} parameters") print(f"Setup time : {time.time() - 
@@ -183,67 +316,17 @@
         algorithm_name += "_" + post_process_name
 
     # print number of parameters
-    print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters")
-    # transform from BGR to RGB
-    transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])
-
-    # create mask
-    if config.trainable_mask.mask_type is not None:
-        mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)
-        if config.trainable_mask.initial_value == "random":
-            mask = mask_class(
-                torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr
-            )
-        elif config.trainable_mask.initial_value == "DiffuserCam":
-            mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr)
-        elif config.trainable_mask.initial_value == "DiffuserCam_gray":
-            mask = mask_class(
-                diffusercam_psf[:, :, :, 0, None],
-                optimizer="Adam",
-                lr=config.trainable_mask.mask_lr,
-                is_rgb=not config.simulation.grayscale,
-            )
-    else:
-        mask = None
-
-    # load dataset and create dataloader
-    if config.files.dataset == "DiffuserCam":
-        # Use a ParallelDataset
-        from lensless.utils.dataset import MeasuredDataset
-
-        max_indices = 30000
-        if config.files.n_files is not None:
-            max_indices = config.files.n_files + 1000
-
-        data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam")
-        assert os.path.exists(data_path), "DiffuserCam dataset not found"
-        dataset = MeasuredDataset(
-            root_dir=data_path,
-            indices=range(1000, max_indices),
-            background=background,
-            psf=diffusercam_psf,
-            lensless_fn="diffuser_images",
-            lensed_fn="ground_truth_lensed",
-            downsample=config.simulation.downsample / 4,
-            transform_lensless=transform_BRG2RGB,
-            transform_lensed=transform_BRG2RGB,
-        )
-    else:
-        # Use a simulated dataset
-        if config.trainable_mask.use_mask_in_dataset:
-            dataset = simulate_dataset(config, diffusercam_psf, mask=mask)
-            # the mask use will differ from the one in the benchmark dataset
-            print("Trainable Mask will be used in the test dataset")
-            benchmark_dataset = None
-        else:
-            dataset = simulate_dataset(config, diffusercam_psf, mask=None)
+    n_param = sum(p.numel() for p in recon.parameters())
+    if mask is not None:
+        n_param += sum(p.numel() for p in mask.parameters())
+    log.info(f"Training model with {n_param} parameters")
 
     print(f"Setup time : {time.time() - start_time} s")
-    print(f"PSF shape : {diffusercam_psf.shape}")
+    print(f"PSF shape : {psf.shape}")
 
     trainer = Trainer(
-        recon,
-        dataset,
-        benchmark_dataset,
+        recon=recon,
+        train_dataset=train_set,
+        test_dataset=test_set,
         mask=mask,
         batch_size=config.training.batch_size,
         loss=config.loss,
@@ -254,21 +337,13 @@
         slow_start=config.training.slow_start,
         skip_NAN=config.training.skip_NAN,
         algorithm_name=algorithm_name,
+        metric_for_best_model=config.training.metric_for_best_model,
+        save_every=config.training.save_every,
+        gamma=config.display.gamma,
     )
 
     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
 
-    if mask is not None:
-        print("Saving mask")
-        print(f"mask shape: {mask._mask.shape}")
-        torch.save(mask._mask, os.path.join(save, "mask.pt"))
-        # save as image using plt
-        import matplotlib.pyplot as plt
-
-        print(f"mask max: {mask._mask.max()}")
-        print(f"mask min: {mask._mask.min()}")
-        plt.imsave(os.path.join(save, "mask.png"), mask._mask.detach().cpu().numpy()[0, ...])
-
 
 if __name__ == "__main__":
     train_unrolled()
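The block that manually saved the learned mask is removed in the hunk above. If needed, the old behaviour can be reproduced after training along these lines, assuming the trainable mask still stores its tensor in `_mask` as in the removed code:

    import matplotlib.pyplot as plt

    if mask is not None:
        torch.save(mask._mask, os.path.join(save, "mask.pt"))
        plt.imsave(os.path.join(save, "mask.png"), mask._mask.detach().cpu().numpy()[0, ...])
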
""" self.device = recon._psf.device - + self.logger = logger self.recon = recon assert train_dataset is not None @@ -319,7 +325,10 @@ def __init__( train_dataset, test_dataset = torch.utils.data.random_split( train_dataset, [train_size, test_size] ) - print(f"Train size : {train_size}, Test size : {test_size}") + if self.logger is not None: + self.logger.info(f"Train size : {train_size}, Test size : {test_size}") + else: + print(f"Train size : {train_size}, Test size : {test_size}") self.train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, @@ -330,6 +339,7 @@ def __init__( self.test_dataset = test_dataset self.lpips = lpips self.skip_NAN = skip_NAN + self.eval_batch_size = eval_batch_size if mask is not None: assert isinstance(mask, TrainableMask) @@ -413,10 +423,16 @@ def learning_rate_function(epoch): def detect_nan(grad): if torch.isnan(grad).any(): - print(grad, flush=True) + if self.logger: + self.logger.info(grad) + else: + print(grad, flush=True) for name, param in recon.named_parameters(): if param.requires_grad: - print(name, param) + if self.logger: + self.logger.info(name, param) + else: + print(name, param) raise ValueError("Gradient is NaN") return grad @@ -505,7 +521,10 @@ def train_epoch(self, data_loader, disp=-1): is_NAN = True break if is_NAN: - print("NAN detected in gradiant, skipping training step") + if self.logger is not None: + self.logger.info("NAN detected in gradiant, skipping training step") + else: + print("NAN detected in gradiant, skipping training step") i += 1 continue self.optimizer.step() @@ -518,6 +537,9 @@ def train_epoch(self, data_loader, disp=-1): pbar.set_description(f"loss : {mean_loss}") i += 1 + if self.logger is not None: + self.logger.info(f"loss : {mean_loss}") + return mean_loss def evaluate(self, mean_loss, save_pt): @@ -534,7 +556,7 @@ def evaluate(self, mean_loss, save_pt): if self.test_dataset is None: return # benchmarking - current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10) + current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size) # update metrics with current metrics self.metrics["LOSS"].append(mean_loss) @@ -615,13 +637,19 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): self.evaluate(-1, save_pt) for epoch in range(n_epoch): - print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + if self.logger is not None: + self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + else: + print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) # offset because of evaluate before loop self.on_epoch_end(mean_loss, save_pt, epoch + 1) self.scheduler.step() - print(f"Train time : {time.time() - start_time} s") + if self.logger is not None: + self.logger.info(f"Train time : {time.time() - start_time} s") + else: + print(f"Train time : {time.time() - start_time} s") def save(self, epoch, path="recon", include_optimizer=False): # create directory if it does not exist diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 5cbee7bf..c9be1ee4 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -212,10 +212,10 @@ def train_unrolled(config): save = os.getcwd() if config.torch_device == "cuda" and torch.cuda.is_available(): - print("Using GPU for training.") + log.info("Using GPU for training.") device = "cuda" else: - print("Using CPU for training.") + log.info("Using CPU for training.") 
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index 5cbee7bf..c9be1ee4 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -212,10 +212,10 @@ def train_unrolled(config):
     save = os.getcwd()
 
     if config.torch_device == "cuda" and torch.cuda.is_available():
-        print("Using GPU for training.")
+        log.info("Using GPU for training.")
         device = "cuda"
     else:
-        print("Using CPU for training.")
+        log.info("Using CPU for training.")
         device = "cpu"
 
     # load dataset and create dataloader
@@ -265,8 +265,8 @@ def train_unrolled(config):
     assert train_set is not None
     assert psf is not None
 
-    print("Train set size : ", len(train_set))
-    print("Test set size : ", len(test_set))
+    log.info(f"Train set size : {len(train_set)}")
+    log.info(f"Test set size : {len(test_set)}")
 
     start_time = time.time()
 
@@ -321,8 +321,9 @@ def train_unrolled(config):
         n_param += sum(p.numel() for p in mask.parameters())
     log.info(f"Training model with {n_param} parameters")
 
-    print(f"Setup time : {time.time() - start_time} s")
-    print(f"PSF shape : {psf.shape}")
+    log.info(f"Setup time : {time.time() - start_time} s")
+    log.info(f"PSF shape : {psf.shape}")
+    log.info(f"Results saved in {save}")
     trainer = Trainer(
         recon=recon,
         train_dataset=train_set,
@@ -340,10 +341,13 @@ def train_unrolled(config):
         metric_for_best_model=config.training.metric_for_best_model,
         save_every=config.training.save_every,
         gamma=config.display.gamma,
+        logger=log,
     )
 
     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
 
+    log.info(f"Results saved in {save}")
+
 
 if __name__ == "__main__":
     train_unrolled()
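The logger-or-print fallback now appears at several call sites in `Trainer`. A possible follow-up, not part of this series, would be to centralize it in a small helper method, for example:

    def _log(self, msg):
        # Route messages to the configured logger, falling back to stdout.
        if self.logger is not None:
            self.logger.info(msg)
        else:
            print(msg, flush=True)

Each call site would then reduce to a single line, e.g. `self._log(f"Train time : {time.time() - start_time} s")`.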