diff --git a/.github/workflows/python_no_pycsou.yml b/.github/workflows/python_no_pycsou.yml index 87a6bff5..a1c1a617 100644 --- a/.github/workflows/python_no_pycsou.yml +++ b/.github/workflows/python_no_pycsou.yml @@ -58,4 +58,5 @@ jobs: run: | pip install -U pytest pip install -r recon_requirements.txt + pip install -r mask_requirements.txt pytest \ No newline at end of file diff --git a/.github/workflows/python_pycsou.yml b/.github/workflows/python_pycsou.yml index dc64e1d3..d5cf1e91 100644 --- a/.github/workflows/python_pycsou.yml +++ b/.github/workflows/python_pycsou.yml @@ -58,5 +58,6 @@ jobs: run: | pip install -U pytest pip install -r recon_requirements.txt + pip install -r mask_requirements.txt pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev pytest \ No newline at end of file diff --git a/.gitignore b/.gitignore index c2055183..3430ec87 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,4 @@ -scripts/recon/outputs/ .idea/ - *_sol.py scripts/admm_* scripts/gd_* @@ -97,6 +95,17 @@ target/ # Jupyter Notebook .ipynb_checkpoints +*.ipynb + +# Images and NPY files +*.png +*.npy + +# Datasets +data/celeba_mini + +# FlatCam reconstruction +masks/FlatCam/flatcam-authors # IPython profile_default/ @@ -152,3 +161,4 @@ dmypy.json # Pyre type checker .pyre/ + diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4272f7a6..847fa0f7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,10 +18,12 @@ Added - Link and citation for JOSS. - Authors at top of source code files. - Add paramiko as dependency for remote capture and display. +- Mask module, for CodedAperture (FlatCam), PhaseContour (PhlatCam), and FresnelZoneAperture. - Script for measuring arbitrary dataset (from Raspberry Pi). - Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fix postprocessing can be used. - Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``. - Support for unrolled loading and inference in the script ``admm.py``. +- Tikhonov reconstruction for coded aperture measurements (MLS / MURA). Changed @@ -36,7 +38,8 @@ Changed Bugfix ~~~~~~ -- Displaying 3D reconstructions by summing values along axis would produce un-normalized values. +- Fix overwriting of sensor parameters when downsampling. +- Displaying 3D reconstructions by summing values along axis would produce un-normalized values. 1.0.4 - (2023-06-14) -------------------- diff --git a/configs/mask_sim_dataset.yaml b/configs/mask_sim_dataset.yaml new file mode 100644 index 00000000..9f5ad958 --- /dev/null +++ b/configs/mask_sim_dataset.yaml @@ -0,0 +1,31 @@ +defaults: + - mask_sim_single + - _self_ + +seed: 0 +save: True + +files: + dataset: data/celeba_mini + image_ext: jpg + n_files: 10 # null to use all + +simulation: + object_height: [0.25, 0.3] # range for random height, or scalar + random_shift: False + grayscale: False + +# torch for reconstruction +torch: False +torch_device: 'cuda:0' + +recon: + + algo: "tikhonov" # "tikhonov" or "admm" or None to skip + + tikhonov: + reg: 3e-4 + + admm: + # Recommend to not display, ok for small number of files, otherwise many windows will pop up! + disp_iter: null \ No newline at end of file diff --git a/configs/mask_sim_single.yaml b/configs/mask_sim_single.yaml new file mode 100644 index 00000000..f793d302 --- /dev/null +++ b/configs/mask_sim_single.yaml @@ -0,0 +1,61 @@ +hydra: + job: + chdir: True # change to output folder + + +files: + original: data/celeba_mini/000019.jpg + #original: data/original/mnist_3.png + +save: True + +simulation: + object_height: 0.3 + # these distance parameters are typically fixed for a given PSF + scene2mask: 40e-2 + mask2sensor: 4e-3 + # see waveprop.devices + sensor: "rpi_hq" + snr_db: 20 + # Downsampling for PSF + downsample: 8 + + # max val in simulated measured (quantized 8 bits) + max_val: 230 + + image_format: rgb # rgb, grayscale, bayer_rggb, bayer_bggr, bayer_grbg, bayer_gbrg + + flatcam: False # only supported if mask.type is "MURA" or "MLS" + + +mask: + type: "MLS" # "MURA", "MLS", "FZA", "PhaseContour" + + # Coded Aperture (MURA or MLS) + #flatcam_method: 'MLS' + n_bits: 8 # e.g. 8 for MLS, 99 for MURA + + # Phase Contour + noise_period: [16, 16] + refractive_index: 1.2 + phase_mask_iter: 10 + + # Fresnel Zone Aperture + radius: 0.32e-3 + + +recon: + + algo: "admm" # tikhonov or admm + + tikhonov: + reg: 3e-4 + + admm: + n_iter: 20 + disp_iter: 2 + # Hyperparameters + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 7e376152..fc01f75b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,6 +7,8 @@ MOCK_MODULES = [ "scipy", + "scipy.signal", + "scipy.linalg", "pycsou", "matplotlib", "matplotlib.pyplot", @@ -28,7 +30,12 @@ "PIL", "tqdm", "paramiko", - "paramiko.ssh_exception" + "paramiko.ssh_exception", + "perlin_numpy", + "waveprop", + "waveprop.fresnel", + "waveprop.rs", + "waveprop.noise", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/docs/source/index.rst b/docs/source/index.rst index c87b6258..94c236e6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -25,6 +25,7 @@ Contents measurement reconstruction evaluation + mask sensor utilities demo diff --git a/docs/source/mask.rst b/docs/source/mask.rst new file mode 100644 index 00000000..0ad8327e --- /dev/null +++ b/docs/source/mask.rst @@ -0,0 +1,33 @@ +.. automodule:: lensless.hardware.mask + + + Abstract Mask Class + ~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.hardware.mask.Mask + :members: + :special-members: __init__ + + + Coded Aperture (FlatCam) + ~~~~~~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.hardware.mask.CodedAperture + :members: + :special-members: __init__ + + + Phase Contour (PhlatCam) + ~~~~~~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.hardware.mask.PhaseContour + :members: + :special-members: __init__ + + + Fresnel Zone Aperture + ~~~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.hardware.mask.FresnelZoneAperture + :members: + :special-members: __init__ \ No newline at end of file diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index b156145e..27434c40 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -44,6 +44,14 @@ .. autofunction:: lensless.recon.admm.finite_diff_gram + + Tikhonov (Ridge Regression) + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.CodedApertureReconstruction + :special-members: __init__, apply + + Accelerated Proximal Gradient Descent (APGD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/lensless/__init__.py b/lensless/__init__.py index 6341b851..53d94884 100644 --- a/lensless/__init__.py +++ b/lensless/__init__.py @@ -21,6 +21,7 @@ FISTA, GradientDescentUpdate, ) +from .recon.tikhonov import CodedApertureReconstruction from .hardware.sensor import VirtualSensor, SensorOptions try: diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py new file mode 100644 index 00000000..9cde01b2 --- /dev/null +++ b/lensless/hardware/mask.py @@ -0,0 +1,455 @@ +# ############################################################################# +# mask.py +# ================= +# Authors : +# Aaron FARGEON [aa.fargeon@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + +""" +Mask +==== + +This module provides utilities to create different types of masks (:py:class:`~lensless.hardware.mask.CodedAperture`, +:py:class:`~lensless.hardware.mask.PhaseContour`, +:py:class:`~lensless.hardware.mask.FresnelZoneAperture`) and simulate the corresponding PSF. + +""" + + +import abc +import warnings +import numpy as np +import cv2 as cv +from math import sqrt +from perlin_numpy import generate_perlin_noise_2d +from sympy.ntheory import quadratic_residues +from scipy.signal import max_len_seq +from scipy.linalg import circulant +from numpy.linalg import multi_dot +from waveprop.fresnel import fresnel_conv +from waveprop.rs import angular_spectrum +from waveprop.noise import add_shot_noise +from lensless.hardware.sensor import VirtualSensor +from lensless.utils.image import resize +from lensless.utils.image import rgb2bayer, bayer2rgb + + +class Mask(abc.ABC): + """ + Parent ``Mask`` class. Attributes common to each type of mask. + """ + + def __init__( + self, + resolution, + distance_sensor, + size=None, + feature_size=None, + psf_wavelength=[460e-9, 550e-9, 640e-9], + **kwargs + ): + """ + Constructor from parameters of the user's choice. + + Parameters + ---------- + resolution: array_like + Resolution of the mask (px). + distance_sensor: float + Distance between the mask and the sensor (m). + size: array_like + Size of the sensor (m). Only one of ``size`` or ``feature_size`` needs to be specified. + feature_size: float or array_like + Size of the feature (m). Only one of ``size`` or ``feature_size`` needs to be specified. + psf_wavelength: list, optional + List of wavelengths to simulate PSF (m). Default is [460e-9, 550e-9, 640e-9] nm (blue, green, red). + """ + + resolution = np.array(resolution) + assert len(resolution) == 2, "Sensor resolution should be of length 2" + + assert ( + size is not None or feature_size is not None + ), "Either sensor_size or feature_size should be specified" + if size is None: + size = np.array(resolution * feature_size) + else: + size = np.array(size) + assert len(size) == 2, "Sensor size should be of length 2" + if feature_size is None: + feature_size = np.array(size) / np.array(resolution) + else: + if isinstance(feature_size, float): + feature_size = np.array([feature_size, feature_size]) + else: + assert len(feature_size) == 2, "Feature size should be of length 2" + feature_size = np.array(feature_size) + assert np.all(feature_size > 0), "Feature size should be positive" + assert np.all(resolution * feature_size <= size) + + self.phase_mask = None + self.resolution = resolution + self.size = size + if feature_size is None: + self.feature_size = self.size / self.resolution + else: + self.feature_size = feature_size + self.distance_sensor = distance_sensor + + # create mask + self.mask = None + self.create_mask() + self.shape = self.mask.shape + + # PSF + self.psf_wavelength = psf_wavelength + self.psf = None + self.compute_psf() + + @classmethod + def from_sensor(cls, sensor_name, downsample=None, **kwargs): + """ + Constructor from an existing virtual sensor that copies over the sensor parameters + (sensor resolution, sensor size, feature size). + + Parameters + ---------- + sensor_name: str + Name of the sensor. See :py:class:`~lensless.hardware.sensor.SensorOptions`. + downsample: float, optional + Downsampling factor. + **kwargs: + Additional arguments for the mask constructor. See the abstract class :py:class:`~lensless.hardware.mask.Mask` + and the corresponding subclass for more details. + + Example + ------- + + .. code-block:: python + + mask = CodedAperture.from_sensor(sensor_name=SensorOptions.RPI_HQ, downsample=8, ...) + """ + sensor = VirtualSensor.from_name(sensor_name, downsample) + return cls( + resolution=tuple(sensor.resolution.copy()), + size=tuple(sensor.size.copy()), + feature_size=sensor.pixel_size.copy(), + **kwargs + ) + + @abc.abstractmethod + def create_mask(self): + """ + Abstract mask creation method that creates mask with subclass-specific function. + """ + pass + + def compute_psf(self): + """ + Compute the intensity PSF with bandlimited angular spectrum (BLAS) for each wavelength. + Common to all types of masks. + """ + psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64) + for i, wv in enumerate(self.psf_wavelength): + psf[:, :, i] = angular_spectrum( + u_in=self.mask, + wv=wv, + d1=self.feature_size, + dz=self.distance_sensor, + dtype=np.float32, + bandlimit=True, + )[0] + + # intensity PSF + self.psf = np.abs(psf) ** 2 + + +class CodedAperture(Mask): + """ + Coded aperture mask as in `FlatCam `_. + """ + + def __init__(self, method="MLS", n_bits=8, **kwargs): + """ + Coded aperture mask contructor (FlatCam). + + Parameters + ---------- + method: str + Pattern generation method (MURA or MLS). Default is ``MLS``. + n_bits: int, optional + Number of bits for pattern generation. + Size is ``4*n_bits + 1`` for MURA and ``2^n - 1`` for MLS. + Default is 8 (for a 255x255 MLS mask). + **kwargs: + The keyword arguments are passed to the parent class :py:class:`~lensless.hardware.mask.Mask`. + """ + + self.row = None + self.col = None + self.method = method + self.n_bits = n_bits + + super().__init__(**kwargs) + + def create_mask(self): + """ + Creating coded aperture mask using either the MURA of MLS method. + """ + assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'" + + # Generating pattern + if self.method.upper() == "MURA": + self.mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:] + self.row = 2 * self.mask[0, :] - 1 + self.col = 2 * self.mask[:, 0] - 1 + else: + seq = max_len_seq(self.n_bits)[0] * 2 - 1 + h_r = np.r_[seq, seq] + self.row = h_r + self.col = h_r + self.mask = (np.outer(h_r, h_r) + 1) / 2 + + # Upscaling + if np.any(self.resolution != self.mask.shape): + upscaled_mask = resize( + self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,) + ).squeeze() + upscaled_mask = np.clip(upscaled_mask, 0, 1) + self.mask = np.round(upscaled_mask).astype(int) + + def is_prime(self, n): + """ + Assess whether a number is prime or not. + + Parameters + ---------- + n: int + The number we want to check. + """ + if n % 2 == 0 and n > 2: + return False + return all(n % i for i in range(3, int(sqrt(n)) + 1, 2)) + + def squarepattern(self, p): + """ + Generate MURA square pattern. + + Parameters + ---------- + p: int + Number of bits. + """ + if not self.is_prime(p): + raise ValueError("p is not a valid length. It must be prime.") + A = np.zeros((p, p), dtype=int) + q = quadratic_residues(p) + A[1:, 0] = 1 + for j in range(1, p): + for i in range(1, p): + if not ((i - 1 in q) != (j - 1 in q)): + A[i, j] = 1 + return A + + def get_conv_matrices(self, img_shape): + """ + Get theoretical left and right convolution matrices for the separable mask. + + Such that measurement model is given ``P @ img @ Q.T``. + + Parameters + ---------- + img_shape: tuple + Shape of the image to being convolved. + + Returns + ------- + P: :py:class:`~numpy.ndarray` + Left convolution matrix. + Q: :py:class:`~numpy.ndarray` + Right convolution matrix. + + """ + + P = circulant(np.resize(self.col, self.resolution[0]))[:, : img_shape[0]] + Q = circulant(np.resize(self.row, self.resolution[1]))[:, : img_shape[1]] + + return P, Q + + def simulate(self, obj, snr_db=20): + """ + Simulate the mask measurement of an image. Apply left and right convolution matrices, + add noise and return the measurement. + + Parameters + ---------- + obj: :py:class:`~numpy.ndarray` + Image to simulate. + snr_db: float, optional + Signal-to-noise ratio (dB) of the simulated measurement. Default is 20 dB. + """ + assert len(obj.shape) == 3, "Object should be a 3D array (HxWxC) even if grayscale." + + # Get convolution matrices + P, Q = self.get_conv_matrices(obj.shape) + + # Convolve image + n_channels = obj.shape[-1] + meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) + + # Add noise + if snr_db is not None: + meas = add_shot_noise(meas, snr_db=snr_db) + + return meas + + +class PhaseContour(Mask): + """ + Phase contour mask as in `PhlatCam `_. + """ + + def __init__( + self, noise_period=(16, 16), refractive_index=1.2, n_iter=10, design_wv=532e-9, **kwargs + ): + """ + Phase contour mask contructor (PhlatCam). + + Parameters + ---------- + noise_period: array_like, optional + Noise period of the Perlin noise (px). Default is (8, 8). + refractive_index: float, optional + Refractive index of the mask substrate. Default is 1.2. + n_iter: int, optional + Number of iterations for the phase retrieval algorithm. Default is 10. + design_wv: float, optional + Wavelength used to design the mask (m). Default is 532e-9, as in the PhlatCam paper. + **kwargs: + The keyword arguments are passed to the parent class :py:class:`~lensless.hardware.mask.Mask`. + """ + + self.target_psf = None + self.phase_pattern = None + self.height_map = None + self.noise_period = noise_period + self.refractive_index = refractive_index + self.n_iter = n_iter + self.design_wv = design_wv + + super().__init__(**kwargs) + + def create_mask(self): + """ + Creating phase contour from edges of Perlin noise. + """ + + # Creating Perlin noise + proper_dim_1 = (self.resolution[0] // self.noise_period[0]) * self.noise_period[0] + proper_dim_2 = (self.resolution[1] // self.noise_period[1]) * self.noise_period[1] + noise = generate_perlin_noise_2d((proper_dim_1, proper_dim_2), self.noise_period) + + # Upscaling to correspond to sensor size + if np.any(self.resolution != noise.shape): + noise = resize(noise[:, :, np.newaxis], shape=tuple(self.resolution) + (1,)).squeeze() + + # Edge detection + binary = np.clip(np.round(np.interp(noise, (-1, 1), (0, 1))), a_min=0, a_max=1) + self.target_psf = cv.Canny(np.interp(binary, (-1, 1), (0, 255)).astype(np.uint8), 0, 255) + + # Computing mask and height map + phase_mask, height_map = phase_retrieval( + target_psf=self.target_psf, + wv=self.design_wv, + d1=self.feature_size, + dz=self.distance_sensor, + n=self.refractive_index, + n_iter=self.n_iter, + height_map=True, + ) + self.height_map = height_map + self.phase_pattern = phase_mask + self.mask = np.exp(1j * phase_mask) + + +def phase_retrieval(target_psf, wv, d1, dz, n=1.2, n_iter=10, height_map=False): + """ + Iterative phase retrieval algorithm similar to `PhlatCam `_, + using Fresnel propagation. + + Parameters + ---------- + target_psf: array_like + Target PSF to optimize the phase mask for. + wv: float + Wavelength (m). + d1: float + Sample period on the sensor i.e. pixel size (m). + dz: float + Propagation distance between the mask and the sensor. + n: float + Refractive index of the mask substrate. Default is 1.2. + n_iter: int + Number of iterations. Default value is 10. + """ + M_p = np.sqrt(target_psf) + + if hasattr(d1, "__len__"): + if d1[0] != d1[1]: + warnings.warn("Non-square pixel, first dimension taken as feature size.") + d1 = d1[0] + + for _ in range(n_iter): + # back propagate from sensor to mask + M_phi = fresnel_conv(M_p, wv, d1, -dz, dtype=np.float32)[0] + # constrain amplitude at mask to be unity, i.e. phase pattern + M_phi = np.exp(1j * np.angle(M_phi)) + # forward propagate from mask to sensor + M_p = fresnel_conv(M_phi, wv, d1, dz, dtype=np.float32)[0] + # constrain amplitude to be sqrt(PSF) + M_p = np.sqrt(target_psf) * np.exp(1j * np.angle(M_p)) + + phi = (np.angle(M_phi) + 2 * np.pi) % (2 * np.pi) + + if height_map: + return phi, wv * phi / (2 * np.pi * (n - 1)) + else: + return phi + + +class FresnelZoneAperture(Mask): + """ + Fresnel Zone Aperture (FZA) mask as in `this work `_, + namely binarized cosine function. + """ + + def __init__(self, radius=0.32e-3, **kwargs): + """ + Fresnel Zone Aperture mask contructor. + + Parameters + ---------- + radius: float + characteristic radius of the FZA (m) + default value: 5e-4 + **kwargs: + The keyword arguments are passed to the parent class :py:class:`~lensless.hardware.mask.Mask`. + """ + + self.radius = radius + + super().__init__(**kwargs) + + def create_mask(self): + """ + Creating binary Fresnel Zone Aperture mask. + """ + dim = self.resolution + x, y = np.meshgrid( + np.linspace(-dim[1] / 2, dim[1] / 2 - 1, dim[1]), + np.linspace(-dim[0] / 2, dim[0] / 2 - 1, dim[0]), + ) + radius_px = self.radius / self.feature_size[0] + mask = 0.5 * (1 + np.cos(np.pi * (x**2 + y**2) / radius_px**2)) + self.mask = np.round(mask) diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index b8c56120..36a5adda 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -132,7 +132,7 @@ def __init__( Parameters ---------- - pixel_size : array-like + pixel_size : array-like or float 2D pixel size in meters. resolution : array-like 2D resolution in pixels. @@ -147,9 +147,16 @@ def __init__( """ + assert len(resolution) == 2, "Resolution must be 2D" + self.resolution = ( + resolution.copy() + ) # to not overwrite original values when using downsample + + if isinstance(pixel_size, float): + pixel_size = np.array([pixel_size, pixel_size]) assert len(pixel_size) == 2, "Pixel size must be 2D" - self.pixel_size = pixel_size - self.resolution = resolution + self.pixel_size = pixel_size.copy() + self.diagonal = diagonal self.color = color if bit_depth is None: @@ -174,7 +181,7 @@ def __init__( @classmethod def from_name(cls, name, downsample=None): """ - Create a sensor from one of the available options in :py:class:`~lensless.sensor.SensorOptions`. + Create a sensor from one of the available options in :py:class:`~lensless.hardware.sensor.SensorOptions`. Parameters ---------- @@ -189,7 +196,8 @@ def from_name(cls, name, downsample=None): """ if name not in SensorOptions.values(): raise ValueError(f"Sensor {name} not supported.") - return cls(**sensor_dict[name], downsample=downsample) + sensor_specs = sensor_dict[name].copy() + return cls(**sensor_specs, downsample=downsample) def capture(self, scene=None, bit_depth=None, bayer=False): """ @@ -289,7 +297,7 @@ def downsample(self, factor): assert factor > 1, "Downsample factor must be greater than 1." - self.pixel_size *= factor + self.pixel_size = self.pixel_size * factor self.resolution = (self.resolution / factor).astype(int) self.size = self.pixel_size * self.resolution self.image_shape = self.resolution diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index cc13a8b7..58200f2a 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -373,7 +373,7 @@ def _set_initial_estimate(self, image_est): def set_image_estimate(self, image_est): """ - Overwrite current image estimate. + Set initial estimate of image, e.g. to warm-start algorithm. Parameters ---------- diff --git a/lensless/recon/tikhonov.py b/lensless/recon/tikhonov.py new file mode 100644 index 00000000..84a88011 --- /dev/null +++ b/lensless/recon/tikhonov.py @@ -0,0 +1,112 @@ +# ############################################################################# +# tikhonov.py +# ================= +# Authors : +# Aaron FARGEON [aa.fargeon@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + +""" +Tikhonov +======== + +The py:class:`~lensless.recon.tikhonov.CodedApertureReconstruction` class is meant +to recover an image from a py:class:`~lensless.hardware.mask.CodedAperture` lensless +capture, using the analytical solution to the Tikhonov optimization problem +(least squares problem with L2 regularization term), as in the `FlatCam paper +`_ (Eq. 7). +""" + +import numpy as np +from numpy.linalg import multi_dot + + +class CodedApertureReconstruction: + """ + Reconstruction method for the (non-iterative) Tikhonov algorithm presented in the `FlatCam paper `_. + + TODO: operations in float32 + """ + + def __init__(self, mask, image_shape, P=None, Q=None, lmbd=3e-4): + """ + Parameters + ---------- + mask : py:class:`~lensless.hardware.mask.CodedAperture` + Coded aperture mask object. + image_shape : (`array-like` or `tuple`) + The shape of the image to reconstruct. + P : :py:class:`~numpy.ndarray`, optional + Left convolution matrix in measurement operator. Must be of shape (measurement_resolution[0], image_shape[0]). + By default, it is generated from the mask. In practice, it may be useful to measure as in the FlatCam paper. + Q : :py:class:`~numpy.ndarray`, optional + Right convolution matrix in measurement operator. Must be of shape (measurement_resolution[1], image_shape[1]). + By default, it is generated from the mask. In practice, it may be useful to measure as in the FlatCam paper. + lmbd: float: + Regularization parameter. Default value is `3e-4` as in the FlatCam paper `code `_. + """ + + self.lmbd = lmbd + if P is None or Q is None: + self.P, self.Q = mask.get_conv_matrices(image_shape) + else: + self.P = P + self.Q = Q + assert self.P.shape == ( + mask.resolution[0], + image_shape[0], + ), "Left matrix P shape mismatch" + assert self.Q.shape == ( + mask.resolution[1], + image_shape[1], + ), "Right matrix Q shape mismatch" + + def apply(self, img): + """ + Method for performing Tikhinov reconstruction. + + Parameters + ---------- + img : :py:class:`~numpy.ndarray` + Lensless capture measurement. Must be 3D even if grayscale. + + Returns + ------- + :py:class:`~numpy.ndarray` + Reconstructed image, in the same format as the measurement. + """ + assert len(img.shape) == 3, "Object should be a 3D array (HxWxC) even if grayscale." + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + # Applying reconstruction for each channel + for c in range(n_channels): + + # SVD of left matrix + UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) + VL = VLh.T + DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) + singLsq = np.square(SL) + + # SVD of right matrix + UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) + VR = VRh.T + DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) + singRsq = np.square(SR) + + # Applying analytical reconstruction + Yc = img[:, :, c] + inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = x_est.clip(min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + + return x_est diff --git a/lensless/utils/image.py b/lensless/utils/image.py index b267bb75..7d2c65b3 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -168,7 +168,7 @@ def get_max_val(img, nbits=None): return max_val -def bayer2rgb( +def bayer2rgb_cc( img, nbits, blue_gain=None, @@ -276,3 +276,250 @@ def autocorr2d(vals, pad_mode="reflect"): # remove padding return autocorr[shape[0] // 2 : -shape[0] // 2, shape[1] // 2 : -shape[1] // 2] + + +def rgb2bayer(img, pattern): + """ + Converting RGB image to separated Bayer channels. + + Parameters + ---------- + img : :py:class:`~numpy.ndarray` + Image in RGB format. + pattern : str + Bayer pattern: `RGGB`, `BGGR`, `GRBG`, `GBRG`. + + Returns + ------- + :py:class:`~numpy.ndarray` + Image converted to the Bayer format `[R, Gr, Gb, B]`. `Gr` and `Gb` are for the green pixels that are on the same line as the red and blue pixels respectively. + """ + + # Verifying that the pattern is a proper Bayer pattern + pattern = pattern.upper() + assert pattern in [ + "RGGB", + "BGGR", + "GRBG", + "GBRG", + ], "Bayer pattern must be in ['RGGB', 'BGGR', 'GRBG', 'GBRG']" + + # Doubling the size of the image to anticipatie shrinking from Bayer transformation + height, width, _ = img.shape + resized = resize(img, shape=(height * 2, width * 2, 3)) + + # Separating each Bayer channel + + if pattern == "RGGB": + # RGGB pattern *------* + # | R G | + # | G B | + # *------* + r = resized[::2, ::2, 0] + gr = resized[1::2, ::2, 1] + gb = resized[::2, 1::2, 1] + b = resized[1::2, 1::2, 2] + img_bayer = np.dstack((r, gr, gb, b)) + + elif pattern == "BGGR": + # BGGR pattern *------* + # | B G | + # | G R | + # *------* + r = resized[1::2, 1::2, 0] + gr = resized[::2, 1::2, 1] + gb = resized[1::2, ::2, 1] + b = resized[::2, ::2, 2] + img_bayer = np.dstack((b, gb, gr, r)) + + elif pattern == "GRBG": + # GRGB pattern *------* + # | G R | + # | B G | + # *------* + r = resized[1::2, ::2, 0] + gr = resized[::2, ::2, 1] + gb = resized[1::2, 1::2, 1] + b = resized[::2, 1::2, 2] + img_bayer = np.dstack((gr, r, b, gb)) + + else: + # GBRG pattern *------* + # | G B | + # | R G | + # *------* + r = resized[::2, 1::2, 0] + gr = resized[1::2, 1::2, 1] + gb = resized[::2, ::2, 1] + b = resized[1::2, ::2, 2] + img_bayer = np.dstack((gb, b, r, gr)) + + return img_bayer + + +def bayer2rgb(X_bayer, pattern): + """ + Converting 4-channel Bayer image to RGB by averaging the two green channels. + + Parameters + ---------- + X_bayer : :py:class:`~numpy.ndarray` + Image in RGB format. + pattern : str + Bayer pattern: `RGGB`, `BGGR`, `GRBG`, `GBRG`. + + Returns + ------- + :py:class:`~numpy.ndarray` + Image converted to the RGB format. + """ + + # Verifying that the pattern is a proper Bayer pattern + pattern = pattern.upper() + assert pattern in [ + "RGGB", + "BGGR", + "GRBG", + "GBRG", + ], "Bayer pattern must be in ['RGGB', 'BGGR', 'GRBG', 'GBRG']" + + r_channel = [i for i, ltr in enumerate(pattern) if ltr == "R"][0] + b_channel = [i for i, ltr in enumerate(pattern) if ltr == "B"][0] + g_channels = [i for i, ltr in enumerate(pattern) if ltr == "G"] + + X_rgb = np.empty(X_bayer.shape[:-1] + (3,)) + X_rgb[:, :, 0] = X_bayer[:, :, r_channel] + X_rgb[:, :, 1] = np.mean(X_bayer[:, :, g_channels], axis=2) + X_rgb[:, :, 2] = X_bayer[:, :, b_channel] + + return X_rgb + + +def load_drunet(model_path, n_channels=3, requires_grad=False): + """ + Load a pre-trained Drunet model. + + Parameters + ---------- + model_path : str + Path to pre-trained model. + n_channels : int + Number of channels in input image. + requires_grad : bool + Whether to require gradients for model parameters. + + Returns + ------- + model : :py:class:`~torch.nn.Module` + Loaded model. + """ + from lensless.recon.drunet.network_unet import UNetRes + + model = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=4, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ) + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = requires_grad + + return model + + +def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"): + """ + Apply a pre-trained denoising model with input in the format Channel, Height, Width. + An additionnal channel is added for the noise level as done in Drunet. + + Parameters + ---------- + model : :py:class:`~torch.nn.Module` + Drunet compatible model. Its input must concist of 4 channels ( RGB + noise level) and outbut an RGB image both in CHW format. + image : :py:class:`~torch.Tensor` + Input image. + noise_level : float or :py:class:`~torch.Tensor` + Noise level in the image. + device : str + Device to use for computation. Can be "cpu" or "cuda". + mode : str + Mode to use for model. Can be "inference" or "train". + + Returns + ------- + image : :py:class:`~torch.Tensor` + Reconstructed image. + """ + # convert from NDHWC to NCHW + depth = image.shape[-4] + image = image.movedim(-1, -3) + image = image.reshape(-1, *image.shape[-3:]) + # pad image H and W to next multiple of 8 + top = (8 - image.shape[-2] % 8) // 2 + bottom = (8 - image.shape[-2] % 8) - top + left = (8 - image.shape[-1] % 8) // 2 + right = (8 - image.shape[-1] % 8) - left + image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0) + # add noise level as extra channel + image = image.to(device) + if isinstance(noise_level, torch.Tensor): + noise_level = noise_level / 255.0 + else: + noise_level = torch.tensor([noise_level / 255.0]).to(device) + image = torch.cat( + ( + image, + noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]), + ), + dim=1, + ) + + # apply model + if mode == "inference": + with torch.no_grad(): + image = model(image) + elif mode == "train": + image = model(image) + else: + raise ValueError("mode must be 'inference' or 'train'") + + # remove padding + image = image[:, :, top:-bottom, left:-right] + # convert back to NDHWC + image = image.movedim(-3, -1) + image = image.reshape(-1, depth, *image.shape[-3:]) + return image + + +def process_with_DruNet(model, device="cpu", mode="inference"): + """ + Return a porcessing function that applies the DruNet model to an image. + + Parameters + ---------- + model : torch.nn.Module + DruNet like denoiser model + device : str + Device to use for computation. Can be "cpu" or "cuda". + mode : str + Mode to use for model. Can be "inference" or "train". + """ + + def process(image, noise_level): + x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6 + image = apply_denoiser( + model, + image, + noise_level=noise_level, + device=device, + mode="train", + ) + image = torch.clip(image, min=0.0) * x_max + return image + + return process diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 32aa0446..57c4f740 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -6,7 +6,7 @@ from lensless.utils.plot import plot_image from lensless.hardware.constants import RPI_HQ_CAMERA_BLACK_LEVEL, RPI_HQ_CAMERA_CCM_MATRIX -from lensless.utils.image import bayer2rgb, print_image_info, resize, rgb2gray +from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray def load_image( @@ -89,7 +89,7 @@ def load_image( img = img.astype(dtype) if nbits_out is None: nbits_out = n_bits - img = bayer2rgb( + img = bayer2rgb_cc( img, nbits=n_bits, blue_gain=blue_gain, diff --git a/mask_requirements.txt b/mask_requirements.txt new file mode 100644 index 00000000..ee87c51f --- /dev/null +++ b/mask_requirements.txt @@ -0,0 +1,3 @@ +sympy>=1.11.1 +perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 +waveprop>=0.0.4 \ No newline at end of file diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index e7066446..7d58ed61 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -29,7 +29,7 @@ RPI_HQ_CAMERA_BLACK_LEVEL, ) import picamerax.array -from lensless.utils.image import bayer2rgb, resize +from lensless.utils.image import bayer2rgb_cc, resize import cv2 @@ -203,7 +203,7 @@ def collect_dataset(config): output_bayer = np.sum(stream.array, axis=2).astype(np.uint16) # convert to RGB - output = bayer2rgb( + output = bayer2rgb_cc( output_bayer, nbits=12, blue_gain=float(g[1]), diff --git a/scripts/measure/on_device_capture.py b/scripts/measure/on_device_capture.py index 8f3e0c9c..22241807 100644 --- a/scripts/measure/on_device_capture.py +++ b/scripts/measure/on_device_capture.py @@ -19,7 +19,7 @@ from time import sleep from PIL import Image from lensless.hardware.utils import get_distro -from lensless.utils.image import bayer2rgb, rgb2gray, resize +from lensless.utils.image import bayer2rgb_cc, rgb2gray, resize from lensless.hardware.constants import RPI_HQ_CAMERA_CCM_MATRIX, RPI_HQ_CAMERA_BLACK_LEVEL from fractions import Fraction import time @@ -181,7 +181,7 @@ def capture(config): red_gain = config.awb_gains[0] blue_gain = config.awb_gains[1] - output_rgb = bayer2rgb( + output_rgb = bayer2rgb_cc( output, nbits=n_bits, blue_gain=blue_gain, diff --git a/scripts/sim/mask_dataset.py b/scripts/sim/mask_dataset.py new file mode 100644 index 00000000..c8e19153 --- /dev/null +++ b/scripts/sim/mask_dataset.py @@ -0,0 +1,266 @@ +""" + +Simulate a mask, simulate a few measurements with it, and reconstruct the images. + +Procedure is as follows: + +1) Simulate the mask. +2) Simulate measurements with the mask and specified physical parameters. +3) Reconstruct the images from the measurements. + +Example usage: + +Simulate FlatCam with separable simulation and Tikhonov reconstuction (https://arxiv.org/abs/1509.00116, Eq 7): +``` +python scripts/sim/mask_dataset.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation and Tikhonov reconstuction: +``` +python scripts/sim/mask_dataset.py mask.type=MLS simulation.flatcam=False recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation and ADMM reconstruction: +``` +python scripts/sim/mask_dataset.py mask.type=MLS simulation.flatcam=False recon.algo=admm +``` + +Simulate Fresnel Zone Aperture camera with PSF simulation and ADMM reconstuction (https://www.nature.com/articles/s41377-020-0289-9): +``` +python scripts/sim/mask_dataset.py mask.type=FZA recon.algo=admm +``` + +Simulate PhaseContour camera with PSF simulation and ADMM reconstuction (https://ieeexplore.ieee.org/document/9076617): +``` +python scripts/sim/mask_dataset.py mask.type=PhaseContour recon.algo=admm +``` + +If Pytorch is installed, you can use the `torch` flag to use Pytorch for the reconstruction (ADMM only): +``` +python scripts/sim/mask_dataset.py mask.type=PhaseContour recon.algo=admm +``` + +""" + +import hydra +import warnings +from hydra.utils import to_absolute_path +from lensless.utils.io import load_image, save_image +from lensless.utils.image import rgb2gray +import numpy as np +from lensless import ADMM +from lensless.eval.metric import mse, psnr, ssim, lpips +from waveprop.simulation import FarFieldSimulator +import glob +import os +from tqdm import tqdm +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture +from lensless.recon.tikhonov import CodedApertureReconstruction + + +@hydra.main(version_base=None, config_path="../../configs", config_name="mask_sim_dataset") +def simulate(config): + + if config.torch: + try: + import torch + except ImportError: + raise ImportError("Pytorch not found. Please install pytorch to use torch mode.") + + # set seed + np.random.seed(config.seed) + + mask2sensor = config.simulation.mask2sensor + sensor = config.simulation.sensor + snr_db = config.simulation.snr_db + downsample = config.simulation.downsample + + dataset = to_absolute_path(config.files.dataset) + if not os.path.isdir(dataset): + print(f"No dataset found at {dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive, download_url + except ImportError: + exit() + msg = "Do you want to download the sample CelebA dataset (764KB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/Q5OdDQMwhucIlt8/download" + filename = "celeb_mini.zip" + download_and_extract_archive( + url, os.path.dirname(dataset), filename=filename, remove_finished=True + ) + + mask_type = config.mask.type + + # check for flatcam simulation + flatcam_sim = config.simulation.flatcam + if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: + warnings.warn( + "Flatcam simulation only supported for MURA and MLS masks. Using far field simulation with PSF." + ) + flatcam_sim = False + + if config.save: + if flatcam_sim: + save_dir = to_absolute_path(config.files.dataset + "_" + mask_type + "_flatcam_sim") + else: + save_dir = to_absolute_path(config.files.dataset + "_" + mask_type) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + os.makedirs(os.path.join(save_dir, "sensor_plane")) + os.makedirs(os.path.join(save_dir, "object_plane")) + os.makedirs(os.path.join(save_dir, "reconstruction")) + + # simulate mask + mask = None + if mask_type.upper() in ["MURA", "MLS"]: + mask = CodedAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + method=mask_type, + distance_sensor=mask2sensor, + **config.mask, + ) + psf_sim = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type.upper() == "FZA": + mask = FresnelZoneAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + psf_sim = mask.psf / np.linalg.norm(mask.psf.ravel()) + elif mask_type == "PhaseContour": + mask = PhaseContour.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + n_iter=config.mask.phase_mask_iter, + **config.mask, + ) + psf_sim = mask.psf / np.linalg.norm(mask.psf.ravel()) + assert mask is not None, f"Mask type {mask_type} not implemented." + + if config.simulation.grayscale and len(psf_sim.shape) == 3: + psf_sim = rgb2gray(psf_sim) + + if config.simulation.downsample > 1: + print(f"Downsampled to {psf_sim.shape}.") + + # prepare simulator object + simulator = FarFieldSimulator(psf=psf_sim, **config.simulation) + + # loop over files in dataset + print("\nSimulating dataset...") + files = glob.glob(os.path.join(dataset, f"*.{config.files.image_ext}")) + if config.files.n_files is not None: + files = files[: config.files.n_files] + + for fp in tqdm(files): + + # load image as numpy array + image = load_image(fp) / 255 + if config.simulation.grayscale and len(image.shape) == 3: + image = rgb2gray(image) + + # simulate image + image_plane, object_plane = simulator.propagate(image, return_object_plane=True) + if flatcam_sim: + image_plane = mask.simulate(object_plane, snr_db=snr_db) + + if config.save: + + bn = os.path.basename(fp).split(".")[0] + ".png" + + # can serve as ground truth + object_plane_fp = os.path.join(save_dir, "object_plane", bn) + save_image(object_plane, object_plane_fp) # use max range of 255 + + # lensless image + lensless_fp = os.path.join(save_dir, "sensor_plane", bn) + save_image(image_plane, lensless_fp, max_val=config.simulation.max_val) + + # reconstruction + recon_algo = config.recon.algo.lower() + + if config.recon.algo is not None: + + print("\nReconstructing lensless measurements...") + # -- initialize reconstruction object + if recon_algo == "tikhonov": + if config.torch: + raise NotImplementedError("Tikhonov reconstruction not implemented for torch.") + recon = CodedApertureReconstruction( + mask, object_plane.shape, lmbd=config.recon.tikhonov.reg + ) + elif recon_algo == "admm": + psf = psf_sim[np.newaxis, :, :, :] + if config.torch: + psf = torch.from_numpy(psf).to(config.torch_device) + recon = ADMM(psf, **config.recon.admm) + + # -- metrics + mse_vals = [] + psnr_vals = [] + ssim_vals = [] + if not config.simulation.grayscale: + lpips_vals = [] + else: + lpips_vals = None + + # -- loop over files in dataset + files = glob.glob(os.path.join(save_dir, "sensor_plane", "*.png")) + if config.files.n_files is not None: + files = files[: config.files.n_files] + + for fp in tqdm(files): + + lensless = load_image(fp, as_4d=True) + lensless = lensless / np.max(lensless) + if recon_algo == "tikhonov": + recovered = recon.apply(lensless[0]) + elif recon_algo == "admm": + if config.torch: + lensless = torch.from_numpy(lensless).to(config.torch_device) + recon.set_data(lensless) + res, _ = recon.apply(n_iter=config.recon.admm.n_iter) + + # get first depth + if config.torch: + recovered = res[0].cpu().numpy() + else: + recovered = res[0] + + if config.save: + bn = os.path.basename(fp).split(".")[0] + ".png" + lensless_fp = os.path.join(save_dir, "reconstruction", bn) + + save_image(recovered, lensless_fp, max_val=config.simulation.max_val) + + # compute metrics + object_plane_fp = os.path.join(save_dir, "object_plane", os.path.basename(fp)) + object_plane = load_image(object_plane_fp) + + mse_vals.append(mse(object_plane, recovered)) + psnr_vals.append(psnr(object_plane, recovered)) + if config.simulation.grayscale: + ssim_vals.append(ssim(object_plane, recovered, channel_axis=None)) + else: + ssim_vals.append(ssim(object_plane, recovered)) + if lpips_vals is not None: + lpips_vals.append(lpips(object_plane, recovered)) + + print("\nMSE (avg)", np.mean(mse_vals)) + print("PSNR (avg)", np.mean(psnr_vals)) + print("SSIM (avg)", np.mean(ssim_vals)) + if lpips_vals is not None: + print("LPIPS (avg)", np.mean(lpips_vals)) + + print("Results saved to", save_dir) + + +if __name__ == "__main__": + simulate() diff --git a/scripts/sim/mask_single_file.py b/scripts/sim/mask_single_file.py new file mode 100644 index 00000000..e8a741b5 --- /dev/null +++ b/scripts/sim/mask_single_file.py @@ -0,0 +1,233 @@ +""" + +Simulate a mask, simulate a measurement with it, and reconstruct the image. + +Procedure is as follows: + +1) Simulate the mask. +2) Simulate a measurement with the mask and specified physical parameters. +3) Reconstruct the image from the measurement. + +Example usage: + +Simulate FlatCam with separable simulation and Tikhonov reconstuction (https://arxiv.org/abs/1509.00116, Eq 7): +``` +# MLS mask +python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov + +# MURA mask +python scripts/sim/mask_single_file.py mask.type=MURA mask.n_bits=99 simulation.flatcam=True recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation and Tikhonov reconstuction: +``` +python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=False recon.algo=tikhonov +``` + +Simulate FlatCam with PSF simulation and ADMM reconstruction. Doesn't work well. +``` +python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=False recon.algo=admm +``` + +Simulate Fresnel Zone Aperture camera with PSF simulation and ADMM reconstuction (https://www.nature.com/articles/s41377-020-0289-9): +Doesn't work well, maybe need to remove DC offset which hurts reconstructions? +``` +python scripts/sim/mask_single_file.py mask.type=FZA recon.algo=admm recon.admm.n_iter=18 +``` + +Simulate PhaseContour camera with PSF simulation and ADMM reconstuction (https://ieeexplore.ieee.org/document/9076617): +``` +python scripts/sim/mask_single_file.py mask.type=PhaseContour recon.algo=admm recon.admm.n_iter=10 +``` + +""" + +import hydra +import warnings +from hydra.utils import to_absolute_path +from lensless.utils.io import load_image, save_image +from lensless.utils.image import rgb2gray, rgb2bayer, bayer2rgb +import numpy as np +import matplotlib.pyplot as plt +from lensless import ADMM +from lensless.utils.plot import plot_image +from lensless.eval.metric import mse, psnr, ssim, lpips +from waveprop.simulation import FarFieldSimulator +import os +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture +from lensless.recon.tikhonov import CodedApertureReconstruction + + +@hydra.main(version_base=None, config_path="../../configs", config_name="mask_sim_single") +def simulate(config): + + fp = to_absolute_path(config.files.original) + assert os.path.exists(fp), f"File {fp} does not exist." + + # simulation parameters + object_height = config.simulation.object_height + scene2mask = config.simulation.scene2mask + mask2sensor = config.simulation.mask2sensor + sensor = config.simulation.sensor + snr_db = config.simulation.snr_db + downsample = config.simulation.downsample + max_val = config.simulation.max_val + + image_format = config.simulation.image_format.lower() + if image_format not in ["grayscale", "rgb"]: + bayer = True + else: + bayer = False + + # 1) simulate mask + mask_type = config.mask.type + if mask_type.upper() in ["MURA", "MLS"]: + mask = CodedAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + method=mask_type, + distance_sensor=mask2sensor, + **config.mask, + ) + elif mask_type.upper() == "FZA": + mask = FresnelZoneAperture.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + **config.mask, + ) + elif mask_type.lower() == "PhaseContour".lower(): + mask = PhaseContour.from_sensor( + sensor_name=sensor, + downsample=downsample, + distance_sensor=mask2sensor, + n_iter=config.mask.phase_mask_iter, + **config.mask, + ) + + # 2) simulate measurement + image = load_image(fp, verbose=True) / 255 + + flatcam_sim = config.simulation.flatcam + if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: + warnings.warn( + "Flatcam simulation only supported for MURA and MLS masks. Using far field simulation with PSF." + ) + flatcam_sim = False + + # use far field simulator to get correct object plane sizing + simulator = FarFieldSimulator( + psf=mask.psf, + object_height=object_height, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + sensor=sensor, + snr_db=snr_db, + max_val=max_val, + ) + image_plane, object_plane = simulator.propagate(image, return_object_plane=True) + + if image_format == "grayscale": + image_plane = rgb2gray(image_plane) + object_plane = rgb2gray(object_plane) + elif "bayer" in image_format: + image_plane = rgb2bayer(image_plane, pattern=image_format[-4:]) + object_plane = rgb2bayer(object_plane, pattern=image_format[-4:]) + else: + # make sure image is RGB + assert image_plane.shape[-1] == 3, "Image plane must be RGB" + assert object_plane.shape[-1] == 3, "Object plane must be RGB" + + if flatcam_sim: + image_plane = mask.simulate(object_plane, snr_db=snr_db) + + # 3) reconstruct image + save = config["save"] + if save: + save = os.getcwd() + + if config.recon.algo.lower() == "tikhonov": + recon = CodedApertureReconstruction( + mask, object_plane.shape, lmbd=config.recon.tikhonov.reg + ) + recovered = recon.apply(image_plane) + + elif config.recon.algo.lower() == "admm": + + if bayer: + raise ValueError("ADMM reconstruction not supported for Bayer images.") + + # prep PSF + if image_format == "grayscale": + psf = rgb2gray(mask.psf) + else: + psf = mask.psf + psf = psf[np.newaxis, :, :, :] / np.linalg.norm(mask.psf.ravel()) + + # prep recon + recon = ADMM(psf, **config.recon.admm) + + # add depth channel + recon.set_data(image_plane[None, :, :, :]) + res = recon.apply( + n_iter=config.recon.admm.n_iter, disp_iter=config.recon.admm.disp_iter, save=save + )[0] + + # remove depth channel + recovered = res[0] + else: + raise ValueError(f"Reconstruction algorithm {config.recon.algo} not recognized.") + + # 4) evaluate + if image_format == "grayscale": + object_plane = object_plane[:, :, 0] + recovered = recovered[:, :, 0] + + print("\nEvaluation:") + print("MSE", mse(object_plane, recovered)) + print("PSNR", psnr(object_plane, recovered)) + if image_format == "grayscale": + print("SSIM", ssim(object_plane, recovered, channel_axis=None)) + else: + print("SSIM", ssim(object_plane, recovered)) + if image_format == "rgb": + print("LPIPS", lpips(object_plane, recovered)) + + # -- plot + if bayer: + print("Converting to RGB for plotting and saving...") + image_plane = bayer2rgb(image_plane, pattern=image_format[-4:]) + object_plane = bayer2rgb(object_plane, pattern=image_format[-4:]) + recovered = bayer2rgb(recovered, pattern=image_format[-4:]) + + _, ax = plt.subplots(ncols=5, nrows=1, figsize=(24, 5)) + plot_image(object_plane, ax=ax[0]) + ax[0].set_title("Object plane") + if np.iscomplexobj(mask.mask): + # plot phase + plot_image(np.angle(mask.mask), ax=ax[1]) + ax[1].set_title("Phase mask") + else: + plot_image(mask.mask, ax=ax[1]) + ax[1].set_title("Amplitude mask") + plot_image(mask.psf, ax=ax[2], gamma=2.2) + ax[2].set_title("PSF") + plot_image(image_plane, ax=ax[3]) + ax[3].set_title("Raw data") + plot_image(recovered, ax=ax[4]) + ax[4].set_title("Reconstruction") + + for a in ax: + a.set_xticks([]), a.set_yticks([]) + + plt.tight_layout() + plt.savefig("result.png") + + if config.save: + save_image(recovered, "reconstruction.png") + + plt.show() + + +if __name__ == "__main__": + simulate() diff --git a/test/test_masks.py b/test/test_masks.py new file mode 100644 index 00000000..a16659d6 --- /dev/null +++ b/test/test_masks.py @@ -0,0 +1,98 @@ +import numpy as np +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture +from lensless.eval.metric import mse, psnr, ssim +from waveprop.fresnel import fresnel_conv + + +resolution = np.array([380, 507]) +d1 = 3e-6 +dz = 4e-3 + + +def test_flatcam(): + + mask1 = CodedAperture( + method="MURA", + n_bits=25, + resolution=resolution, + feature_size=d1, + distance_sensor=dz, + ) + assert np.all(mask1.mask.shape == resolution) + + desired_psf_shape = np.array(tuple(resolution) + (len(mask1.psf_wavelength),)) + assert np.all(mask1.psf.shape == desired_psf_shape) + + mask2 = CodedAperture( + method="MLS", + n_bits=5, + resolution=resolution, + feature_size=d1, + distance_sensor=dz, + ) + assert np.all(mask2.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask2.psf_wavelength),)) + assert np.all(mask2.psf.shape == desired_psf_shape) + + +def test_phlatcam(): + + mask = PhaseContour( + noise_period=(8, 8), + resolution=resolution, + feature_size=d1, + distance_sensor=dz, + ) + assert np.all(mask.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask.psf_wavelength),)) + assert np.all(mask.psf.shape == desired_psf_shape) + + Mp = np.sqrt(mask.target_psf) * np.exp( + 1j * np.angle(fresnel_conv(mask.mask, mask.design_wv, d1, dz, dtype=np.float32)[0]) + ) + assert mse(abs(Mp), np.sqrt(mask.target_psf)) < 0.1 + assert psnr(abs(Mp), np.sqrt(mask.target_psf)) > 30 + assert abs(1 - ssim(abs(Mp), np.sqrt(mask.target_psf), channel_axis=None)) < 0.1 + + +def test_fza(): + + mask = FresnelZoneAperture( + radius=30.0, resolution=resolution, feature_size=d1, distance_sensor=dz + ) + assert np.all(mask.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask.psf_wavelength),)) + assert np.all(mask.psf.shape == desired_psf_shape) + + +def test_classmethod(): + + downsample = 8 + + mask1 = CodedAperture.from_sensor( + sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz + ) + assert np.all(mask1.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask1.psf_wavelength),)) + assert np.all(mask1.psf.shape == desired_psf_shape) + + mask2 = PhaseContour.from_sensor( + sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz + ) + assert np.all(mask2.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask2.psf_wavelength),)) + assert np.all(mask2.psf.shape == desired_psf_shape) + + mask3 = FresnelZoneAperture.from_sensor( + sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz + ) + assert np.all(mask3.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask3.psf_wavelength),)) + assert np.all(mask3.psf.shape == desired_psf_shape) + + +if __name__ == "__main__": + test_flatcam() + test_phlatcam() + test_fza() + test_classmethod()