From 063257a4bf15fcbe9842fa19668c09ff43e9f9cf Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 1 Aug 2023 17:30:33 +0200 Subject: [PATCH 01/12] Fix path to display script. (#73) * Add checks to pixel levels. * Add desired levels in config. * Clean script. * Move file to new subfolder in scripts. * Add script to analyze measured dataset. * Move measurement scripts. * Default to image in repo. * Move scripts into subfolders. * Update CHANGELOG. * Fix path script. * Fix path to display script. --- scripts/measure/collect_dataset_on_device.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 7d58ed61..4de5dc4a 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -1,7 +1,7 @@ """ To be run on the Raspberry Pi! ``` -python scripts/collect_dataset_on_device.py +python scripts/measure/collect_dataset_on_device.py ``` Note that the script is configured for the Raspberry Pi HQ camera @@ -173,7 +173,7 @@ def collect_dataset(config): display_image_path = config.display.output_fp rot90 = config.display.rot90 os.system( - f"python scripts/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" ) time.sleep(config.capture.delay) @@ -241,7 +241,7 @@ def collect_dataset(config): display_image_path = config.display.output_fp rot90 = config.display.rot90 os.system( - f"python scripts/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" ) print(f"decreasing screen brightness to {current_screen_brightness}") From 6546bcf621d39389e7f80747dbcf8ea25407f134 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 3 Aug 2023 14:58:09 +0200 Subject: [PATCH 02/12] Digicam (set LCD pattern and mask-sensor distance) and classification (#66) * Fix ADMM ordering * Return client when checking SSH credentials. * Add function to set pattern remotely. * Add author info. * Add aperture module. * Add script ot configure digicam. * Add function to set mask-sensor distance. * Fix max_distance check. * Add option to set pattern from file. * Add config for measuring with adafruit. * Update config. * Add option to set all white screen. * Add option for virtual config. * Fix APGD defaults. * Fix torch requirement to be compatible with pycsou. * Fix typo. * Add authors. * Update gitignore. * Add missing form image. * Add downsampling support to APGD. * add resizing back to original shape * Move some preprocessing into load_image. * Remove unused import. * Add script to reconstruct measured dataset. * Remove unused import. * Clean up dataset reconstruction. * Add files for training CelebA classifier from ViT. * Add link to 10K dataset. * Add script to simulate digicam PSF. * Add utility for getting dtype. * Clean up config digicam. * Fix dtype parsing. * Clean up SLM simulation. * Add option to compare wih measured PSF. --------- Co-authored-by: Yohann PERRON --- .gitignore | 1 + configs/adafruit.yaml | 9 + configs/apgd_l1.yaml | 4 + configs/apgd_l2.yaml | 4 + configs/defaults_recon.yaml | 1 + configs/demo.yaml | 2 + configs/digicam.yaml | 23 ++ configs/recon_dataset.yaml | 47 +++ configs/sim_digicam_psf.yaml | 38 +++ configs/train_celeba_classifier.yaml | 38 +++ digicam_requirements.txt | 1 + docs/source/data.rst | 14 + lensless/hardware/aperture.py | 379 +++++++++++++++++++++++ lensless/hardware/mask.py | 1 - lensless/hardware/sensor.py | 3 + lensless/hardware/slm.py | 298 ++++++++++++++++++ lensless/hardware/utils.py | 62 +++- lensless/recon/apgd.py | 73 ++++- lensless/recon/recon.py | 16 +- lensless/utils/image.py | 2 +- lensless/utils/io.py | 133 +++++++- recon_requirements.txt | 2 +- scripts/classify/train_celeba_vit.py | 330 ++++++++++++++++++++ scripts/demo.py | 4 +- scripts/hardware/config_digicam.py | 101 ++++++ scripts/hardware/digicam_measure_psfs.py | 60 ++++ scripts/measure/remote_capture.py | 4 +- scripts/measure/remote_display.py | 13 +- scripts/recon/admm.py | 1 + scripts/recon/apgd_pycsou.py | 4 +- scripts/recon/dataset.py | 202 ++++++++++++ scripts/sim/dataset.py | 2 +- scripts/sim/digicam_psf.py | 154 +++++++++ 33 files changed, 1977 insertions(+), 49 deletions(-) create mode 100644 configs/adafruit.yaml create mode 100644 configs/digicam.yaml create mode 100644 configs/recon_dataset.yaml create mode 100644 configs/sim_digicam_psf.yaml create mode 100644 configs/train_celeba_classifier.yaml create mode 100644 digicam_requirements.txt create mode 100644 lensless/hardware/aperture.py create mode 100644 lensless/hardware/slm.py create mode 100644 scripts/classify/train_celeba_vit.py create mode 100644 scripts/hardware/config_digicam.py create mode 100644 scripts/hardware/digicam_measure_psfs.py create mode 100644 scripts/recon/dataset.py create mode 100644 scripts/sim/digicam_psf.py diff --git a/.gitignore b/.gitignore index 3430ec87..ca9b0dfa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ data/* models/* *.png *.jpg +*.npy configs/telegram_demo_secret.yaml diff --git a/configs/adafruit.yaml b/configs/adafruit.yaml new file mode 100644 index 00000000..5399204e --- /dev/null +++ b/configs/adafruit.yaml @@ -0,0 +1,9 @@ +defaults: + - demo + - _self_ + +plot: True + +capture: + exp: 5.0 + awb_gains: [1, 1] diff --git a/configs/apgd_l1.yaml b/configs/apgd_l1.yaml index 5d0621cc..006b72aa 100644 --- a/configs/apgd_l1.yaml +++ b/configs/apgd_l1.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: # Proximal prior / regularization: nonneg, l1, null prox_penalty: l1 diff --git a/configs/apgd_l2.yaml b/configs/apgd_l2.yaml index 65a16405..0b50ba73 100644 --- a/configs/apgd_l2.yaml +++ b/configs/apgd_l2.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: diff_penalty: l2 diff_lambda: 0.0001 diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 5cd05d6c..324aa679 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -27,6 +27,7 @@ preprocess: single_psf: False # Whether to perform construction in grayscale. gray: False + bg_pix: [5, 25] # null to skip display: diff --git a/configs/demo.yaml b/configs/demo.yaml index c769d1a2..ddc0c528 100644 --- a/configs/demo.yaml +++ b/configs/demo.yaml @@ -26,6 +26,8 @@ display: psf: null # all black screen black: False + # all white screen + white: False capture: gamma: null # for visualization diff --git a/configs/digicam.yaml b/configs/digicam.yaml new file mode 100644 index 00000000..d84b3a89 --- /dev/null +++ b/configs/digicam.yaml @@ -0,0 +1,23 @@ +rpi: + username: null + hostname: null + +device: adafruit +virtual: False +save: True + +# pattern: data/psf/adafruit_random_pattern_20230719.npy +pattern: random +# pattern: rect +# pattern: circ +min_val: 0 # if pattern: random, min for range(0,1) +rect_shape: [20, 10] # if pattern: rect +radius: 20 # if pattern: circ +center: [0, 0] + + +aperture: + center: [59,76] + shape: [19,26] + +z: 4 # mask to sensor distance diff --git a/configs/recon_dataset.yaml b/configs/recon_dataset.yaml new file mode 100644 index 00000000..f474aed5 --- /dev/null +++ b/configs/recon_dataset.yaml @@ -0,0 +1,47 @@ +# python scripts/recon/dataset.py +defaults: + - defaults_recon + - _self_ + +torch: True +torch_device: 'cuda:0' + +input: + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + # https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + raw_data: data/celeba_adafruit_random_2mm_20230720_1K + +n_files: 25 # null for all files +output_folder: data/celeba_adafruit_recon + +# extraction region of interest +roi: null # top, left, bottom, right +# -- values for `data/celeba_adafruit_random_2mm_20230720_1K` +# roi: [10, 300, 560, 705] # down 4 +# roi: [6, 200, 373, 470] # down 6 +# roi: [5, 150, 280, 352] # down 8 + +preprocess: + flip: True + downsample: 6 + + # to have different data shape than PSF + data_dim: null + # data_dim: [48, 64] # down 64 + # data_dim: [506, 676] # down 6 + +display: + disp: -1 + plot: False + +algo: admm # "admm", "apgd", "null" to just copy over (resized) raw data + +apgd: + n_jobs: 1 # run in parallel as algo is slow + max_iter: 500 + +admm: + n_iter: 10 + +save: False \ No newline at end of file diff --git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml new file mode 100644 index 00000000..216455cd --- /dev/null +++ b/configs/sim_digicam_psf.yaml @@ -0,0 +1,38 @@ +# python scripts/sim/digicam_psf.py +hydra: + job: + chdir: True # change to output folder + +use_torch: False +dtype: float32 +torch_device: cuda +requires_grad: True + +digicam: + + slm: adafruit + sensor: rpi_hq + + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + pattern: data/psf/adafruit_random_pattern_20230719.npy + ap_center: [59, 76] + ap_shape: [19, 26] + rotate: -0.8 # rotation in degrees + + # optionally provide measured PSF for side-by-side comparison + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + gamma: 2 # for plotting measured + +sim: + + # whether SLM is fliped + flipud: True + + # in practice found waveprop=True or False doesn't make difference + waveprop: False + + # below are ignored if waveprop=False + scene2mask: 0.03 # [m] + mask2sensor: 0.002 # [m] + \ No newline at end of file diff --git a/configs/train_celeba_classifier.yaml b/configs/train_celeba_classifier.yaml new file mode 100644 index 00000000..11a391c8 --- /dev/null +++ b/configs/train_celeba_classifier.yaml @@ -0,0 +1,38 @@ +hydra: + job: + chdir: True # change to output folder + +seed: 0 + +data: + # -- path to original CelebA (parent directory) + original: /scratch/bezzam + + output_dir: "./vit-celeba" # basename for model output + + # -- raw + # https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + measured: data/celeba_adafruit_random_2mm_20230720_10K + raw: True + + # # -- reconstructed + # # run `python scripts/recon/dataset.py` to get a reconstructed dataset + # measured: null + # raw: False + + n_files: null # null to use all in measured_folder + test_size: 0.15 + attr: Male # "Male", "Smiling", etc + +augmentation: + + random_resize_crop: False + horizontal_flip: True # cannot be used with raw measurement! + +train: + + prev: null # path to previously trained model + n_epochs: 4 + dropout: 0.1 + batch_size: 16 + learning_rate: 2e-4 diff --git a/digicam_requirements.txt b/digicam_requirements.txt new file mode 100644 index 00000000..fbbcaa30 --- /dev/null +++ b/digicam_requirements.txt @@ -0,0 +1 @@ +slm_controller @ git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file diff --git a/docs/source/data.rst b/docs/source/data.rst index 768b46fb..50b323c6 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -39,6 +39,20 @@ use the correct PSF file for the data you're using! input.psf=data/psf/tape_rgb.png +Measured CelebA Dataset +----------------------- + +You can download 1K measurements of the CelebA dataset done with +our lensless camera and a random pattern on the Adafruit LCD +`here (1.2 GB) `__, +and a dataset with 10K measurements +`here (13.1 GB) `__. +They both correspond to the PSF which can be found `here `__ +(``adafruit_random_2mm_20231907.png`` which is the PSF of +``adafruit_random_pattern_20230719.npy`` measured with a mask to sensor +distance of 2 mm). + + DiffuserCam Lensless Mirflickr Dataset (DLMD) --------------------------------------------- diff --git a/lensless/hardware/aperture.py b/lensless/hardware/aperture.py new file mode 100644 index 00000000..37e8e37b --- /dev/null +++ b/lensless/hardware/aperture.py @@ -0,0 +1,379 @@ +# ############################################################################# +# aperture.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +from enum import Enum + +import numpy as np +from lensless.utils.image import rgb2gray + + +class ApertureOptions(Enum): + RECT = "rect" + SQUARE = "square" + LINE = "line" + CIRC = "circ" + + @staticmethod + def values(): + return [shape.value for shape in ApertureOptions] + + +class Aperture: + def __init__(self, shape, pixel_pitch): + """ + Class for defining VirtualSLM. + + :param shape: (height, width) in number of cell. + :type shape: tuple(int) + :param pixel_pitch: Pixel pitch (height, width) in meters. + :type pixel_pitch: tuple(float) + """ + assert np.all(shape) > 0 + assert np.all(pixel_pitch) > 0 + self._shape = shape + self._pixel_pitch = pixel_pitch + self._values = np.zeros((3,) + shape, dtype=np.uint8) + + @property + def size(self): + return np.prod(self._shape) + + @property + def shape(self): + return self._shape + + @property + def pixel_pitch(self): + return self._pixel_pitch + + @property + def center(self): + return np.array([self.height / 2, self.width / 2]) + + @property + def dim(self): + return np.array(self._shape) * np.array(self._pixel_pitch) + + @property + def height(self): + return self.dim[0] + + @property + def width(self): + return self.dim[1] + + @property + def values(self): + return self._values + + @property + def grayscale_values(self): + return rgb2gray(self._values) + + def at(self, physical_coord, value=None): + """ + Get/set values of VirtualSLM at physical coordinate in meters. + + :param physical_coord: Physical coordinates to get/set VirtualSLM values. + :type physical_coord: int, float, slice tuples + :param value: [Optional] values to set, otherwise return values at + specified coordinates. Defaults to None + :type value: int, float, :py:class:`~numpy.ndarray`, optional + :return: If getter is used, values at those coordinates + :rtype: ndarray + """ + idx = prepare_index_vals(physical_coord, self._pixel_pitch) + if value is None: + # getter + return self._values[idx] + else: + # setter + self._values[idx] = value + + def __getitem__(self, key): + return self._values[key] + + def __setitem__(self, key, value): + self._values[key] = value + + def plot(self, show_tick_labels=False): + """ + Plot Aperture. + + :param show_tick_labels: Whether to show cell number along x- and y-axis, defaults to False + :type show_tick_labels: bool, optional + :return: The axes of the plot. + :rtype: Axes + """ + # prepare mask data for `imshow`, expects the input data array size to be (width, height, 3) + Z = self.values.transpose(1, 2, 0) + + # plot + import matplotlib.pyplot as plt + + _, ax = plt.subplots() + extent = [ + -0.5 * self._pixel_pitch[1], + (self._shape[1] - 0.5) * self._pixel_pitch[1], + (self._shape[0] - 0.5) * self._pixel_pitch[0], + -0.5 * self._pixel_pitch[0], + ] + ax.imshow(Z, extent=extent) + ax.grid(which="major", axis="both", linestyle="-", color="0.5", linewidth=0.25) + + x_ticks = np.arange(-0.5, self._shape[1], 1) * self._pixel_pitch[1] + ax.set_xticks(x_ticks) + if show_tick_labels: + x_tick_labels = (np.arange(-0.5, self._shape[1], 1) + 0.5).astype(int) + else: + x_tick_labels = [None] * len(x_ticks) + ax.set_xticklabels(x_tick_labels) + + y_ticks = np.arange(-0.5, self._shape[0], 1) * self._pixel_pitch[0] + ax.set_yticks(y_ticks) + if show_tick_labels: + y_tick_labels = (np.arange(-0.5, self._shape[0], 1) + 0.5).astype(int) + else: + y_tick_labels = [None] * len(y_ticks) + ax.set_yticklabels(y_tick_labels) + return ax + + +def rect_aperture(slm_shape, pixel_pitch, apert_dim, center=None): + """ + Create and return VirtualSLM object with rectangular aperture of desired dimensions. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param apert_dim: Dimensions (height, width) of aperture in meters. + :type apert_dim: tuple(float) + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :raises ValueError: If aperture does extend over the boarder of the SLM. + :return: VirtualSLM object with cells programmed to desired rectangular aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert np.all(apert_dim) > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + apert_dim = np.array(apert_dim) + top_left = center - apert_dim / 2 + bottom_right = top_left + apert_dim + if ( + top_left[0] < 0 + or top_left[1] < 0 + or bottom_right[0] >= slm.dim[0] + or bottom_right[1] >= slm.dim[1] + ): + raise ValueError( + f"Aperture ({top_left[0]}:{bottom_right[0]}, " + f"{top_left[1]}:{bottom_right[1]}) extends past valid " + f"VirtualSLM dimensions {slm.dim}" + ) + slm.at( + physical_coord=np.s_[top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]], + value=255, + ) + + return slm + + +def line_aperture(slm_shape, pixel_pitch, length, vertical=True, center=None): + """ + Create and return VirtualSLM object with a line aperture of desired length. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param length: Length of aperture in meters. + :type length: float + :param vertical: Orient line vertically, defaults to True. + :type vertical: bool, optional + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired line aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # call `create_rect_aperture` + apert_dim = (length, pixel_pitch[1]) if vertical else (pixel_pitch[0], length) + return rect_aperture(slm_shape, pixel_pitch, apert_dim, center) + + +def square_aperture(slm_shape, pixel_pitch, side, center=None): + """ + Create and return VirtualSLM object with a square aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param side: Side length of square aperture in meters. + :type side: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired square aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + return rect_aperture(slm_shape, pixel_pitch, (side, side), center) + + +def circ_aperture(slm_shape, pixel_pitch, radius, center=None): + """ + Create and return VirtualSLM object with a circle aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param radius: Radius of aperture in meters. + :type radius: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired circle aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert radius > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + i, j = np.meshgrid( + np.arange(slm.dim[0], step=slm.pixel_pitch[0]), + np.arange(slm.dim[1], step=slm.pixel_pitch[1]), + sparse=True, + indexing="ij", + ) + x2 = (i - center[0]) ** 2 + y2 = (j - center[1]) ** 2 + slm[:] = 255 * (x2 + y2 < radius**2) + return slm + + +def _cell_slice(_slice, cell_m): + """ + Convert slice indexing in meters to slice indexing in cells. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param _slice: Original slice in meters. + :type _slice: slice + :param cell_m: Dimension of cell in meters. + :type cell_m: float + :return: The new slice + :rtype: slice + """ + start = None if _slice.start is None else _m_to_cell_idx(_slice.start, cell_m) + stop = _m_to_cell_idx(_slice.stop, cell_m) if _slice.stop is not None else None + step = _m_to_cell_idx(_slice.step, cell_m) if _slice.step is not None else None + return slice(start, stop, step) + + +def _m_to_cell_idx(val, cell_m): + """ + Convert location to cell index. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param val: Location in meters. + :type val: float + :param cell_m: Dimension of cell in meters. + :type cell_m: float + :return: The cell index. + :rtype: int + """ + return int(val / cell_m) + + +def prepare_index_vals(key, pixel_pitch): + """ + Convert indexing object in meters to indexing object in cell indices. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param key: Indexing operation in meters. + :type key: int, float, slice, or list + :param pixel_pitch: Pixel pitch (height, width) in meters. + :type pixel_pitch: tuple(float) + :raises ValueError: If the key is of the wrong type. + :raises NotImplementedError: If key is of size 3, individual channels can't + be indexed. + :raises ValueError: If the key has the wrong dimensions. + :return: The new indexing object. + :rtype: tuple[slice, int] | tuple[slice, slice] | tuple[slice, ...] + """ + if isinstance(key, (float, int)): + idx = slice(None), _m_to_cell_idx(key, pixel_pitch[0]) + + elif isinstance(key, slice): + idx = slice(None), _cell_slice(key, pixel_pitch[0]) + + elif len(key) == 2: + idx = [slice(None)] + for k, _slice in enumerate(key): + + if isinstance(_slice, slice): + idx.append(_cell_slice(_slice, pixel_pitch[k])) + + elif isinstance(_slice, (float, int)): + idx.append(_m_to_cell_idx(_slice, pixel_pitch[k])) + + else: + raise ValueError("Invalid key.") + idx = tuple(idx) + + elif len(key) == 3: + raise NotImplementedError("Cannot index individual channels.") + + else: + raise ValueError("Invalid key.") + return idx diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index 9cde01b2..126d21f1 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -32,7 +32,6 @@ from waveprop.noise import add_shot_noise from lensless.hardware.sensor import VirtualSensor from lensless.utils.image import resize -from lensless.utils.image import rgb2bayer, bayer2rgb class Mask(abc.ABC): diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index 36a5adda..08a00a05 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -170,6 +170,8 @@ def __init__( else: self.size = self.pixel_size * self.resolution + self.pitch = self.size / self.resolution + self.image_shape = self.resolution if self.color: self.image_shape = np.append(self.image_shape, 3) @@ -298,6 +300,7 @@ def downsample(self, factor): assert factor > 1, "Downsample factor must be greater than 1." self.pixel_size = self.pixel_size * factor + self.pitch = self.pitch * factor self.resolution = (self.resolution / factor).astype(int) self.size = self.pixel_size * self.resolution self.image_shape = self.resolution diff --git a/lensless/hardware/slm.py b/lensless/hardware/slm.py new file mode 100644 index 00000000..572ae4a7 --- /dev/null +++ b/lensless/hardware/slm.py @@ -0,0 +1,298 @@ +# ############################################################################# +# slm.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import os +import numpy as np +from lensless.hardware.utils import check_username_hostname +from lensless.utils.io import get_dtype, get_ctypes +from slm_controller.hardware import SLMParam, slm_devices +from waveprop.spherical import spherical_prop +from waveprop.color import ColorSystem +from waveprop.rs import angular_spectrum +from waveprop.slm import get_centers, get_color_filter +from waveprop.devices import SLMParam as SLMParam_wp +from scipy.ndimage import rotate as rotate_func + + +try: + import torch + from torchvision import transforms + + torch_available = True +except ImportError: + torch_available = False + + +SUPPORTED_DEVICE = { + "adafruit": "~/slm-controller/examples/adafruit_slm.py", + "nokia": "~/slm-controller/examples/nokia_slm.py", + "holoeye": "~/slm-controller/examples/holoeye_slm.py", +} + + +def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): + """ + Set LCD pattern on Raspberry Pi. + + This function assumes that `slm-controller `_ + is installed on the Raspberry Pi. + + Parameters + ---------- + pattern : :py:class:`~numpy.ndarray` + Pattern to set on programmable mask. + device : str + Name of device to set pattern on. Supported devices: "adafruit", "nokia", "holoeye". + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. + + """ + + client = check_username_hostname(rpi_username, rpi_hostname) + + # get path to python executable on Raspberry Pi + rpi_python = "~/slm-controller/slm_controller_env/bin/python" + assert ( + device in SUPPORTED_DEVICE.keys() + ), f"Device {device} not supported. Supported devices: {SUPPORTED_DEVICE.keys()}" + script = SUPPORTED_DEVICE[device] + + # check that pattern is correct shape + expected_shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not slm_devices[device][SLMParam.MONOCHROME]: + expected_shape = (3, *expected_shape) + assert ( + pattern.shape == expected_shape + ), f"Pattern shape {pattern.shape} does not match expected shape {expected_shape}" + + # save pattern + pattern_fn = "tmp_pattern.npy" + local_path = os.path.join(os.getcwd(), pattern_fn) + np.save(local_path, pattern) + + # copy pattern to Raspberry Pi + remote_path = f"~/{pattern_fn}" + print(f"PUTTING {local_path} to {remote_path}") + + os.system('scp %s "%s@%s:%s" ' % (local_path, rpi_username, rpi_hostname, remote_path)) + # # -- not sure why this doesn't work... permission denied + # sftp = client.open_sftp() + # sftp.put(local_path, remote_path, confirm=True) + # sftp.close() + + # run script on Raspberry Pi to set mask pattern + command = f"{rpi_python} {script} --file_path {remote_path}" + print(f"COMMAND : {command}") + _stdin, _stdout, _stderr = client.exec_command(command) + print(_stdout.read().decode()) + client.close() + + os.remove(local_path) + + +def get_programmable_mask( + vals, + sensor, + slm_param, + rotate=None, + flipud=False, + nbits=8, +): + """ + Get mask as a numpy or torch array. Return same type. + + Parameters + ---------- + vals : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Values to set on programmable mask. + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. + slm_param : dict + SLM parameters. + rotate : float, optional + Rotation angle in degrees. + flipud : bool, optional + Flip mask vertically. + nbits : int, optional + Number of bits/levels to quantize mask to. + + """ + + use_torch = False + if torch_available: + use_torch = isinstance(vals, torch.Tensor) + dtype = vals.dtype + + # -- prepare SLM mask + n_active_slm_pixels = vals.shape + n_color_filter = np.prod(slm_param["color_filter"].shape[:2]) + pixel_pitch = slm_param[SLMParam_wp.PITCH] + centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch) + + if SLMParam_wp.COLOR_FILTER in slm_param.keys(): + color_filter = slm_param[SLMParam_wp.COLOR_FILTER] + if flipud: + color_filter = np.flipud(color_filter) + + cf = get_color_filter( + slm_dim=n_active_slm_pixels, + color_filter=color_filter, + shift=0, + flat=True, + ) + + else: + + # monochrome + cf = None + + d1 = sensor.pitch + _height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int) + + if use_torch: + mask = torch.zeros((n_color_filter,) + tuple(sensor.resolution)).to(vals) + slm_vals_flat = vals.flatten() + else: + mask = np.zeros((n_color_filter,) + tuple(sensor.resolution), dtype=dtype) + slm_vals_flat = vals.reshape(-1) + + for i, _center in enumerate(centers): + + _center_pixel = (_center / d1 + sensor.resolution / 2).astype(int) + _center_top_left_pixel = ( + _center_pixel[0] - np.floor(_height_pixel / 2).astype(int), + _center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int), + ) + + if cf is not None: + _rect = np.tile(cf[i][:, np.newaxis, np.newaxis], (1, _height_pixel, _width_pixel)) + else: + _rect = np.ones((1, _height_pixel, _width_pixel)) + + if use_torch: + _rect = torch.tensor(_rect).to(slm_vals_flat) + + mask[ + :, + _center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel, + _center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel, + ] = ( + slm_vals_flat[i] * _rect + ) + + # quantize mask + if use_torch: + mask = mask / torch.max(mask) + mask = torch.round(mask * (2**nbits - 1)) / (2**nbits - 1) + else: + mask = mask / np.max(mask) + mask = np.round(mask * (2**nbits - 1)) / (2**nbits - 1) + + # rotate + if rotate is not None: + if use_torch: + mask = transforms.functional.rotate(mask, angle=rotate) + else: + mask = rotate_func(mask, axes=(2, 1), angle=rotate, reshape=False) + + return mask + + +def get_intensity_psf( + mask, + waveprop=False, + sensor=None, + scene2mask=None, + mask2sensor=None, + color_system=None, +): + """ + Get intensity PSF from mask pattern. Return same type of data. + + Parameters + ---------- + mask : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Mask pattern. + waveprop : bool, optional + Whether to use wave propagation to compute PSF. Default is False, + namely to return squared intensity of mask pattern as the PSF (i.e., + no wave propagation and just shadow of pattern). + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. Not used if ``waveprop=False``. + scene2mask : float + Distance from scene to mask. Not used if ``waveprop=False``. + mask2sensor : float + Distance from mask to sensor. Not used if ``waveprop=False``. + color_system : :py:class:`~waveprop.color.ColorSystem`, optional + Color system. Not used if ``waveprop=False``. + + """ + if color_system is None: + color_system = ColorSystem.rgb() + + is_torch = False + device = None + if torch_available: + is_torch = isinstance(mask, torch.Tensor) + device = mask.device + + dtype = mask.dtype + ctype, _ = get_ctypes(dtype, is_torch) + + if is_torch: + psfs = torch.zeros(mask.shape, dtype=ctype, device=device) + else: + psfs = np.zeros(mask.shape, dtype=ctype) + + if waveprop: + + assert sensor is not None, "sensor must be specified" + assert scene2mask is not None, "scene2mask must be specified" + assert mask2sensor is not None, "mask2sensor must be specified" + + assert ( + len(color_system.wv) == mask.shape[0] + ), "Number of wavelengths must match number of color channels" + + # spherical wavefronts to mask + spherical_wavefront = spherical_prop( + in_shape=sensor.resolution, + d1=sensor.pitch, + wv=color_system.wv, + dz=scene2mask, + return_psf=True, + is_torch=True, + device=device, + dtype=dtype, + ) + u_in = spherical_wavefront * mask + + # free space propagation to sensor + for i, wv in enumerate(color_system.wv): + psfs[i], _, _ = angular_spectrum( + u_in=u_in[i], + wv=wv, + d1=sensor.pitch, + dz=mask2sensor, + dtype=dtype, + device=device, + ) + + else: + + psfs = mask + + # -- intensity PSF + if is_torch: + psf_in = torch.square(torch.abs(psfs)) + else: + psf_in = np.square(np.abs(psfs)) + + return psf_in diff --git a/lensless/hardware/utils.py b/lensless/hardware/utils.py index a0c0d573..97b384f6 100644 --- a/lensless/hardware/utils.py +++ b/lensless/hardware/utils.py @@ -2,6 +2,7 @@ import os import socket import subprocess +import time import paramiko from paramiko.ssh_exception import AuthenticationException, BadHostKeyException, SSHException @@ -65,7 +66,7 @@ def check_username_hostname(username, hostname, timeout=10): except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: raise ValueError(f"Could not connect to {username}@{hostname}\n{e}") - return username, hostname + return client def get_distro(): @@ -92,3 +93,62 @@ def get_distro(): # Just major version shown, replace it with the full version RELEASE_DATA["VERSION"] = " ".join([DEBIAN_VERSION] + version_split[1:]) return f"{RELEASE_DATA['NAME']} {RELEASE_DATA['VERSION']}" + + +def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): + """ + Set the distance between the mask and sensor. + + This functions assumes that `StepperDriver `_ is installed. + is downloaded on the Raspberry Pi. + + Parameters + ---------- + distance : float + Distance in mm. Positive values move the mask away from the sensor. + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. + """ + + MAX_DISTANCE = 16 # mm + timeout = 5 + + client = check_username_hostname(rpi_username, rpi_hostname) + assert motor in [0, 1] + assert distance >= 0, "Distance must be non-negative" + assert distance < MAX_DISTANCE, f"Distance must be less than {MAX_DISTANCE} mm" + + # assumes that `StepperDriver` is in home directory + rpi_python = "python3" + script = "~/StepperDriver/Python/serial_motors.py" + + # reset to zero + print("Resetting to zero distance...") + try: + command = f"{rpi_python} {script} {motor} REV {MAX_DISTANCE * 1000}" + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + except socket.timeout: # socket.timeout + pass + + client.close() + time.sleep(5) # TODO reduce this time + client = check_username_hostname(rpi_username, rpi_hostname) + + # set to desired distance + if distance != 0: + print(f"Setting distance to {distance} mm...") + distance_um = distance * 1000 + if distance_um >= 0: + command = f"{rpi_python} {script} {motor} FWD {distance_um}" + else: + command = f"{rpi_python} {script} {motor} REV {-1 * distance_um}" + print(f"COMMAND : {command}") + try: + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + print(_stdout.read().decode()) + except socket.timeout: # socket.timeout + client.close() + + client.close() diff --git a/lensless/recon/apgd.py b/lensless/recon/apgd.py index 2ae5a69d..327c32de 100644 --- a/lensless/recon/apgd.py +++ b/lensless/recon/apgd.py @@ -11,7 +11,9 @@ import inspect import numpy as np from typing import Optional +from lensless.utils.image import resize from lensless.recon.rfft_convolve import RealFFTConvolve2D as Convolver +import cv2 import pycsou.abc as pyca import pycsou.operator.func as func @@ -20,6 +22,7 @@ import pycsou.runtime as pycrt import pycsou.util as pycu import pycsou.util.ptype as pyct +import pycsou.operator.linop as pycl class APGDPriors: @@ -95,6 +98,7 @@ def __init__( rel_error=None, lipschitz_tight=True, lipschitz_tol=1.0, + img_shape=None, **kwargs ): """ @@ -132,27 +136,52 @@ def __init__( Whether to use tight Lipschitz constant or not. Default is True. lipschitz_tol : float, optional Tolerance to compute Lipschitz constant. Default is 1. + img_shape : tuple, optional + Shape of measurement (H, W, C). If None, assume shape of PSF. """ assert isinstance(psf, np.ndarray), "PSF must be a numpy array" - # PSF and data are the same size / shape self._original_shape = psf.shape - self._original_size = psf.size - self._apgd = None - self._gen = None - - super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) self._stop_crit = stop.MaxIter(max_iter) if rel_error is not None: self._stop_crit = self._stop_crit | stop.RelError(eps=rel_error) self._disp = disp - # Convolution operator + # Convolution (and optional downsampling) operator + if img_shape is not None: + + meas_shape = np.array(img_shape[:2]) + rec_shape = np.array(self._original_shape[1:3]) + assert np.all(meas_shape <= rec_shape), "Image shape must be smaller than PSF shape" + self.downsampling_factor = np.round(rec_shape / meas_shape).astype(int) + + # new PSF shape, must be integer multiple of image shape + new_shape = tuple(np.array(meas_shape) * self.downsampling_factor) + (psf.shape[-1],) + psf_re = resize(psf.copy(), shape=new_shape, interpolation=cv2.INTER_CUBIC) + + # combine operations + conv = RealFFTConvolve2D(psf_re, dtype=dtype) + ds = pycl.SubSample( + psf_re.shape, + slice(None), + slice(0, -1, self.downsampling_factor[0]), + slice(0, -1, self.downsampling_factor[1]), + slice(None), + ) + + self._H = ds * conv + + super(APGD, self).__init__(psf_re, dtype, n_iter=max_iter, **kwargs) + + else: + self.downsampling_factor = 1 + self._H = RealFFTConvolve2D(psf, dtype=dtype) + + super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) - self._H = RealFFTConvolve2D(self._psf, dtype=dtype) self._H.lipschitz(tol=lipschitz_tol, tight=lipschitz_tight) # initialize solvers which will be created when data is set @@ -192,9 +221,25 @@ def set_data(self, data): 3D (RGB). """ - super(APGD, self).set_data( - np.repeat(data, self._original_shape[-4], axis=0) - ) # we repeat the data for each depth to match the size of the PSF + + # super(APGD, self).set_data( + # np.repeat(data, self._original_shape[-4], axis=0) + # ) # we repeat the data for each depth to match the size of the PSF + + data = np.repeat(data, self._original_shape[-4], axis=0) # repeat for each depth + assert isinstance(data, np.ndarray) + assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." + + assert np.all( + self._psf_shape[-3:-1] == (np.array(data.shape)[-3:-1] * self.downsampling_factor) + ), "PSF and data shape mismatch" + + if len(data.shape) == 3: + self._data = data[None, None, ...] + elif len(data.shape) == 4: + self._data = data[None, ...] + else: + self._data = data """ Set up problem """ # Cost function @@ -220,13 +265,15 @@ def reset(self): if self._initial_est is not None: self._image_est = self._initial_est else: - self._image_est = np.zeros(self._original_size, dtype=self._dtype) + self._image_est = np.zeros(np.prod(self._psf_shape), dtype=self._dtype) def _update(self, iter): res = next(self._apgd.steps()) self._image_est[:] = res["x"] def _form_image(self): - image = self._image_est.reshape(self._original_shape) + image = self._image_est.reshape(self._psf_shape) image[image < 0] = 0 + if np.any(self._psf_shape != self._original_shape): + image = resize(image, shape=self._original_shape) return image diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 58200f2a..c035f4cd 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -154,6 +154,7 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction import pathlib as plib import matplotlib.pyplot as plt from lensless.utils.plot import plot_image +from lensless.utils.io import get_dtype from lensless.recon.rfft_convolve import RealFFTConvolve2D try: @@ -232,16 +233,7 @@ def __init__( self._psf_shape = np.array(self._psf.shape) # set dtype - if dtype is None: - if self.is_torch: - dtype = torch.float32 - else: - dtype = np.float32 - else: - if self.is_torch: - dtype = torch.float32 if dtype == "float32" else torch.float64 - else: - dtype = np.float32 if dtype == "float32" else np.float64 + dtype = get_dtype(dtype, self.is_torch) if self.is_torch: if dtype: @@ -491,7 +483,9 @@ def apply( if (plot or save) and disp_iter is not None: if ax is None: - ax = plot_image(self._get_numpy_data(self._image_est[0]), gamma=gamma) + img = self._form_image() + ax = plot_image(self._get_numpy_data(img[0]), gamma=gamma) + else: ax = None disp_iter = n_iter + 1 diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 7d2c65b3..19c977e2 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -1,5 +1,5 @@ # ############################################################################# -# image_utils.py +# image.py # ================= # Authors : # Eric BEZZAM [ebezzam@gmail.com] diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 57c4f740..f502719a 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -1,3 +1,11 @@ +# ############################################################################# +# io.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + import warnings from PIL import Image import cv2 @@ -6,7 +14,7 @@ from lensless.utils.plot import plot_image from lensless.hardware.constants import RPI_HQ_CAMERA_BLACK_LEVEL, RPI_HQ_CAMERA_CCM_MATRIX -from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray +from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val def load_image( @@ -22,6 +30,10 @@ def load_image( nbits_out=None, as_4d=False, downsample=None, + bg=None, + return_float=False, + shape=None, + dtype=None, ): """ Load image as numpy array. @@ -53,6 +65,15 @@ def load_image( height, width, color). downsample : int, optional Downsampling factor. Recommended for image reconstruction. + bg : array_like + Background level to subtract. + return_float : bool + Whether to return image as float array, or unsigned int. + shape : tuple, optional + Shape (H, W, C) to resize to. + dtype : str, optional + Data type of returned data. Default is to use that of input. + Returns ------- img : :py:class:`~numpy.ndarray` @@ -103,6 +124,8 @@ def load_image( if len(img.shape) == 3: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + original_dtype = img.dtype + if flip: img = np.flipud(img) img = np.fliplr(img) @@ -110,14 +133,39 @@ def load_image( if verbose: print_image_info(img) + if bg is not None: + + # if bg is float vector, turn into int-valued vector + if bg.max() <= 1: + bg = bg * get_max_val(img) + + img = img - bg + img = np.clip(img, a_min=0, a_max=img.max()) + if as_4d: if len(img.shape) == 3: img = img[np.newaxis, :, :, :] elif len(img.shape) == 2: img = img[np.newaxis, :, :, np.newaxis] - if downsample is not None: - img = resize(img, factor=1 / downsample) + if downsample is not None or shape is not None: + if downsample is not None: + factor = 1 / downsample + else: + factor = None + img = resize(img, factor=factor, shape=shape) + + if return_float: + if dtype is None: + dtype = np.float32 + assert dtype == np.float32 or dtype == np.float64 + img = img.astype(dtype) + img /= img.max() + + else: + if dtype is None: + dtype = original_dtype + img = img.astype(dtype) return img @@ -212,6 +260,7 @@ def load_psf( ) original_dtype = psf.dtype + max_val = get_max_val(psf) psf = np.array(psf, dtype=dtype) if use_3d: @@ -274,6 +323,7 @@ def load_psf( if return_float: # psf /= psf.max() psf /= np.linalg.norm(psf.ravel()) + bg /= max_val else: psf = psf.astype(original_dtype) @@ -379,21 +429,21 @@ def load_data( ) # load and process raw measurement - data = load_image(data_fp, flip=flip, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain) - data = np.array(data, dtype=dtype) - - data -= bg - data = np.clip(data, a_min=0, a_max=data.max()) - - if len(data.shape) == 3: - data = data[np.newaxis, :, :, :] - elif len(data.shape) == 2: - data = data[np.newaxis, :, :, np.newaxis] + data = load_image( + data_fp, + flip=flip, + bayer=bayer, + blue_gain=blue_gain, + red_gain=red_gain, + bg=bg, + as_4d=True, + return_float=True, + shape=shape, + ) if data.shape != psf.shape: # in DiffuserCam dataset, images are already reshaped data = resize(data, shape=psf.shape) - data /= np.linalg.norm(data.ravel()) if data.shape[3] > 1 and psf.shape[3] == 1: warnings.warn( @@ -454,3 +504,58 @@ def save_image(img, fp, max_val=255): img = Image.fromarray(img) img.save(fp) + + +def get_dtype(dtype=None, is_torch=False): + """ + Get dtype for numpy or torch. + + Parameters + ---------- + dtype : str, optional + "float32" or "float64", Default is "float32". + is_torch : bool, optional + Whether to return torch dtype. + """ + if dtype is None: + dtype = "float32" + assert dtype == "float32" or dtype == "float64" + + if is_torch: + import torch + + if dtype is None: + if is_torch: + dtype = torch.float32 + else: + dtype = np.float32 + else: + if is_torch: + dtype = torch.float32 if dtype == "float32" else torch.float64 + else: + dtype = np.float32 if dtype == "float32" else np.float64 + + return dtype + + +def get_ctypes(dtype, is_torch): + if not is_torch: + if dtype == np.float32 or dtype == np.complex64: + return np.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return np.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) + else: + import torch + + if dtype == np.float32 or dtype == np.complex64: + return torch.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return torch.complex128, np.complex128 + elif dtype == torch.float32 or dtype == torch.complex64: + return torch.complex64, np.complex64 + elif dtype == torch.float64 or dtype == torch.complex128: + return torch.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) diff --git a/recon_requirements.txt b/recon_requirements.txt index 5d142936..4ebe4412 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -7,6 +7,6 @@ click>=8.0.1 waveprop>=0.0.3 # for simulation # Library for learning algorithm -torch >= 1.8.0 +torch >= 2.0.0 torchvision lpips \ No newline at end of file diff --git a/scripts/classify/train_celeba_vit.py b/scripts/classify/train_celeba_vit.py new file mode 100644 index 00000000..79a32e44 --- /dev/null +++ b/scripts/classify/train_celeba_vit.py @@ -0,0 +1,330 @@ +""" +Fine-tune ViT on CelebA dataset measured with lensless camera. +Original tutorial: https://huggingface.co/blog/fine-tune-vit + +First, set-up HuggingFace libraries: +``` +pip install datasets transformers +``` + +Raw measurement datasets can be download from SwitchDrive. +This will be done by the script if the dataset is not found. +``` +# 10K measurements (13.1 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_10K + +# 1K measurements (1.2 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_1K +``` + +Note that the CelebA dataset also needs to be available locally! +It can be download here: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + +In order to classify on reconstructed outputs, the following +script needs to be run to create the dataset of reconstructed +images: +``` +# reconstruct with ADMM +python scripts/recon/dataset.py algo=admm \ +input.raw_data=path/to/raw/data +``` + +To classify on raw downsampled images, the same script can be +used, e.g. with the following command (`algo=null` for no reconstruction): +``` +python scripts/recon/dataset.py algo=null \ +input.raw_data=path/to/raw/data \ +preprocess.data_dim=[48,64] +``` + +Other hyperparameters for classification can be found in +`configs/train_celeba_classifier.yaml`. + +""" + +import warnings +from transformers import ViTImageProcessor, ViTForImageClassification +from transformers import TrainingArguments, Trainer, TrainerCallback +import numpy as np +import torch +import os +from hydra.utils import to_absolute_path +import glob +import hydra +import random +from datasets import load_metric +from PIL import Image +import pandas as pd +import time +import torchvision.transforms as transforms +import torchvision.datasets as dset +from datasets import Dataset +from copy import deepcopy +from torchvision.transforms import ( + CenterCrop, + Compose, + Normalize, + RandomHorizontalFlip, + RandomResizedCrop, + Resize, + ToTensor, +) + + +class CustomCallback(TrainerCallback): + def __init__(self, trainer) -> None: + super().__init__() + self._trainer = trainer + + def on_epoch_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_step_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_train_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + +@hydra.main(version_base=None, config_path="../../configs", config_name="train_celeba_classifier") +def train_celeba_classifier(config): + + seed = config.seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + # check how many measured files + measured_dataset = to_absolute_path(config.data.measured) + if not os.path.isdir(measured_dataset): + print(f"No dataset found at {measured_dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the CelebA dataset measured with a random Adafruit LCD pattern (13.1 GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/9NNGCJs3DoBDGlY/download" + filename = "celeba_adafruit_random_2mm_20230720_10K.zip" + download_and_extract_archive( + url, os.path.dirname(measured_dataset), filename=filename, remove_finished=True + ) + measured_files = sorted(glob.glob(os.path.join(measured_dataset, "*.png"))) + print(f"Found {len(measured_files)} files in {measured_dataset}") + + if config.data.n_files is not None: + n_files = config.data.n_files + measured_files = measured_files[: config.data.n_files] + print(f"Using {len(measured_files)} files") + n_files = len(measured_files) + + # create dataset split + attr = config.data.attr + ds = dset.CelebA( + root=config.data.original, + split="all", + download=False, + transform=transforms.ToTensor(), + ) + label_idx = ds.attr_names.index(attr) + labels = ds.attr[:, label_idx][:n_files] + + # make dataset with measured data and corresponding labels + df = pd.DataFrame( + { + "labels": labels, + "image_file_path": measured_files, + } + ) + ds = Dataset.from_pandas(df) + ds = ds.class_encode_column("labels") + + # -- train / test split + test_size = config.data.test_size + ds = ds.train_test_split( + test_size=test_size, stratify_by_column="labels", seed=seed, shuffle=True + ) + + # prepare dataset + model_name_or_path = "google/vit-base-patch16-224-in21k" + processor = ViTImageProcessor.from_pretrained(model_name_or_path) + + # -- processors for train and val + image_mean, image_std = processor.image_mean, processor.image_std + size = processor.size["height"] + + normalize = Normalize(mean=image_mean, std=image_std) + # _train_transforms = Compose( + # [ + # # RandomResizedCrop( + # # size, + # # scale=(0.9, 1.0), + # # ratio=(0.9, 1.1), + # # ), + # Resize(size), + # CenterCrop(size), + # RandomHorizontalFlip(), + # ToTensor(), + # normalize, + # ] + # ) + _train_transforms = [] + if config.augmentation.random_resize_crop: + _train_transforms.append( + RandomResizedCrop( + size, + scale=(0.9, 1.0), + ratio=(0.9, 1.1), + ) + ) + _train_transforms.append( + Resize(size), + CenterCrop(size), + ) + if config.augmentation.horizontal_flip: + if config.data.raw: + warnings.warn("Horizontal flip is not supported for raw data, Skipping!") + else: + _train_transforms.append(RandomHorizontalFlip()) + _train_transforms.append( + ToTensor(), + normalize, + ) + _train_transforms = Compose(_train_transforms) + + _val_transforms = Compose( + [ + Resize(size), + CenterCrop(size), + ToTensor(), + normalize, + ] + ) + + def train_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _train_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + def val_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _val_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + # transform dataset + ds["train"].set_transform(train_transforms) + ds["test"].set_transform(val_transforms) + + # data collator + def collate_fn(batch): + return { + "pixel_values": torch.stack([x["pixel_values"] for x in batch]), + "labels": torch.tensor([x["labels"] for x in batch]), + } + + # evaluation metric + metric = load_metric("accuracy") + + def compute_metrics(p): + return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids) + + # load model + if config.train.prev is not None: + model_path = to_absolute_path(config.train.prev) + else: + model_path = model_name_or_path + + labels = ds["train"].features["labels"].names + model = ViTForImageClassification.from_pretrained( + model_path, + num_labels=len(labels), + id2label={str(i): c for i, c in enumerate(labels)}, + label2id={c: str(i) for i, c in enumerate(labels)}, + hidden_dropout_prob=config.train.dropout, + attention_probs_dropout_prob=config.train.dropout, + ) + + # configure training + output_dir = ( + config.data.output_dir + f"-{config.data.attr}" + os.path.basename(measured_dataset) + ) + + training_args = TrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=config.train.batch_size, + evaluation_strategy="steps", + eval_steps=100, + save_steps=100, + num_train_epochs=config.train.n_epochs, + fp16=True, + logging_steps=10, + learning_rate=config.train.learning_rate, + save_total_limit=2, + remove_unused_columns=False, # important to keep False + push_to_hub=False, + report_to="tensorboard", + load_best_model_at_end=True, + ) + + trainer = Trainer( + model=model, + args=training_args, + data_collator=collate_fn, + compute_metrics=compute_metrics, + tokenizer=processor, + train_dataset=ds["train"], + eval_dataset=ds["test"], + ) + trainer.add_callback(CustomCallback(trainer)) # add accuracy on train set + + # train + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + + start_time = time.time() + train_results = trainer.train() + trainer.save_model() + trainer.log_metrics("train", train_results.metrics) + trainer.save_metrics("train", train_results.metrics) + trainer.save_state() + + # evaluate + metrics = trainer.evaluate(ds["train"]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + metrics = trainer.evaluate(ds["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + print(f"Training took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + train_celeba_classifier() diff --git a/scripts/demo.py b/scripts/demo.py index 32b26e42..760b663a 100644 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -18,7 +18,9 @@ @hydra.main(version_base=None, config_path="../configs", config_name="demo") def demo(config): - RPI_USERNAME, RPI_HOSTNAME = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + RPI_USERNAME = config.rpi.username + RPI_HOSTNAME = config.rpi.hostname display_fp = to_absolute_path(config.fp) if config.save: diff --git a/scripts/hardware/config_digicam.py b/scripts/hardware/config_digicam.py new file mode 100644 index 00000000..cd8cab86 --- /dev/null +++ b/scripts/hardware/config_digicam.py @@ -0,0 +1,101 @@ +import warnings +import hydra +from datetime import datetime +import numpy as np +from slm_controller import slm +from slm_controller.hardware import SLMParam, slm_devices +import matplotlib.pyplot as plt + +from lensless.hardware.slm import set_programmable_mask +from lensless.hardware.aperture import rect_aperture, circ_aperture +from lensless.hardware.utils import set_mask_sensor_distance + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + device = config.device + + shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not slm_devices[device][SLMParam.MONOCHROME]: + shape = (3, *shape) + pixel_pitch = slm_devices[device][SLMParam.PIXEL_PITCH] + + # set mask to sensor distance + if config.z is not None and not config.virtual: + set_mask_sensor_distance(config.z, rpi_username, rpi_hostname) + + center = np.array(config.center) * pixel_pitch + + # create random pattern + pattern = None + if config.pattern.endswith(".npy"): + pattern = np.load(config.pattern) + elif config.pattern == "random": + rng = np.random.RandomState(1) + # pattern = rng.randint(low=0, high=np.iinfo(np.uint8).max, size=shape, dtype=np.uint8) + pattern = rng.uniform(low=config.min_val, high=1, size=shape) + pattern = (pattern * np.iinfo(np.uint8).max).astype(np.uint8) + + elif config.pattern == "rect": + rect_shape = config.rect_shape + apert_dim = rect_shape[0] * pixel_pitch[0], rect_shape[1] * pixel_pitch[1] + ap = rect_aperture( + apert_dim=apert_dim, + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + elif config.pattern == "circ": + ap = circ_aperture( + radius=config.radius * pixel_pitch[0], + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + else: + raise ValueError(f"Pattern {config.pattern} not supported") + + # save pattern + if not config.pattern.endswith(".npy") and config.save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pattern_fn = f"{device}_{config.pattern}_pattern_{timestamp}.npy" + np.save(pattern_fn, pattern) + print(f"Saved pattern to {pattern_fn}") + + print("Pattern shape : ", pattern.shape) + print("Pattern dtype : ", pattern.dtype) + print("Pattern min : ", pattern.min()) + print("Pattern max : ", pattern.max()) + + # apply aperture + if config.aperture is not None: + + aperture = np.zeros(shape, dtype=np.uint8) + top_left = np.array(config.aperture.center) - np.array(config.aperture.shape) // 2 + bottom_right = top_left + np.array(config.aperture.shape) + aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + pattern = pattern * aperture + + assert pattern is not None + + n_nonzero = np.count_nonzero(pattern) + print(f"Nonzero pixels: {n_nonzero}") + + if not config.virtual: + set_programmable_mask(pattern, device, rpi_username, rpi_hostname) + + # preview mask + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = slm.create(device) + s._show_preview(pattern) + plt.savefig("preview.png") + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/hardware/digicam_measure_psfs.py b/scripts/hardware/digicam_measure_psfs.py new file mode 100644 index 00000000..901d24cb --- /dev/null +++ b/scripts/hardware/digicam_measure_psfs.py @@ -0,0 +1,60 @@ +import numpy as np +from lensless.hardware.utils import set_mask_sensor_distance +import hydra +import os +from datetime import datetime +from PIL import Image + +SATURATION_THRESHOLD = 0.01 + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + + mask_sensor_distances = np.arange(9) * 0.1 + exposure_time = 5 + + timestamp = datetime.now().strftime("%Y%m%d") + + for i in range(len(mask_sensor_distances)): + + print(f"Mask sensor distance: {mask_sensor_distances[i]}mm") + mask_sensor_distance = mask_sensor_distances[i] + + # set the mask sensor distance + set_mask_sensor_distance(mask_sensor_distance, rpi_username, rpi_hostname) + + good_exposure = False + while not good_exposure: + + # measure PSF + output_folder = f"adafruit_psf_{mask_sensor_distance}mm__{timestamp}" + os.system( + f"python scripts/remote_capture.py -cn capture_bayer output={output_folder} rpi.username={rpi_username} rpi.hostname={rpi_hostname} capture.exp={exposure_time}" + ) + + # check for saturation + OUTPUT_FP = os.path.join(output_folder, "raw_data.png") + # -- load picture to check for saturation + img = np.array(Image.open(OUTPUT_FP)) + ratio = np.sum(img == 4095) / np.prod(img.shape) + print(f"Saturation ratio: {ratio}") + if ratio > SATURATION_THRESHOLD or ratio == 0: + + if ratio == 0: + print("Need to increase exposure time.") + else: + print("Need to decrease exposure time.") + + # enter new exposure time from keyboard + exposure_time = float(input("Enter new exposure time: ")) + + else: + good_exposure = True + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/measure/remote_capture.py b/scripts/measure/remote_capture.py index 92f2033e..66210a86 100644 --- a/scripts/measure/remote_capture.py +++ b/scripts/measure/remote_capture.py @@ -32,7 +32,9 @@ def liveview(config): rgb = config.capture.rgb gray = config.capture.gray - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname legacy = config.capture.legacy nbits_out = config.capture.nbits_out fn = config.capture.raw_data_fn diff --git a/scripts/measure/remote_display.py b/scripts/measure/remote_display.py index f9ab3ed2..1be931a3 100644 --- a/scripts/measure/remote_display.py +++ b/scripts/measure/remote_display.py @@ -35,12 +35,15 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="demo") def remote_display(config): - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname fp = config.fp shape = np.array(config.display.screen_res) psf = config.display.psf black = config.display.black + white = config.display.white if psf: point_source = np.zeros(tuple(shape) + (3,)) @@ -58,12 +61,18 @@ def remote_display(config): im = Image.fromarray(point_source.astype("uint8"), "RGB") im.save(fp) + elif white: + point_source = np.ones(tuple(shape) + (3,)) * 255 + fp = "tmp_display.png" + im = Image.fromarray(point_source.astype("uint8"), "RGB") + im.save(fp) + """ processing on remote machine, less issues with copying """ # copy picture to Raspberry Pi print("\nCopying over picture...") display(fp=fp, rpi_username=username, rpi_hostname=hostname, **config.display) - if psf or black: + if psf or black or white: os.remove(fp) diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 17a88461..3ba3de1f 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -41,6 +41,7 @@ def admm(config): shape=config["preprocess"]["shape"], torch=config.torch, torch_device=config.torch_device, + bg_pix=config.preprocess.bg_pix, ) disp = config["display"]["disp"] diff --git a/scripts/recon/apgd_pycsou.py b/scripts/recon/apgd_pycsou.py index 878b378f..0bf236d0 100644 --- a/scripts/recon/apgd_pycsou.py +++ b/scripts/recon/apgd_pycsou.py @@ -17,7 +17,7 @@ import time import matplotlib.pyplot as plt from lensless.utils.io import load_data -from lensless import APGD +from lensless.recon.apgd import APGD import os import pathlib as plib @@ -28,7 +28,7 @@ log = logging.getLogger(__name__) -@hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") +@hydra.main(version_base=None, config_path="../../configs", config_name="apgd_l1") def apgd( config, ): diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py new file mode 100644 index 00000000..4c4192c5 --- /dev/null +++ b/scripts/recon/dataset.py @@ -0,0 +1,202 @@ +""" +Apply ADMM reconstruction to folder. + +``` +python scripts/recon/dataset.py +``` + +To run APGD, use the following command: +``` +python scripts/recon/dataset.py algo=apgd +``` + +To just copy resized raw data, use the following command: +``` +python scripts/recon/dataset.py algo=null preprocess.data_dim=[48,64] +``` + +""" + +import hydra +from hydra.utils import to_absolute_path +import os +import time +import numpy as np +from lensless.utils.io import load_psf, load_image, save_image +from lensless import ADMM +import torch +import glob +from tqdm import tqdm +from lensless.recon.apgd import APGD +from joblib import Parallel, delayed + + +@hydra.main(version_base=None, config_path="../../configs", config_name="recon_dataset") +def admm_dataset(config): + + algo = config.algo + + # get raw data file paths + dataset = to_absolute_path(config.input.raw_data) + if not os.path.isdir(dataset): + print(f"No dataset found at {dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the sample CelebA dataset measured with a random Adafruit LCD pattern (1.2 GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/m89D1tFEfktQueS/download" + filename = "celeba_adafruit_random_2mm_20230720_1K.zip" + download_and_extract_archive( + url, os.path.dirname(dataset), filename=filename, remove_finished=True + ) + data_fps = sorted(glob.glob(os.path.join(dataset, "*.png"))) + if config.n_files is not None: + data_fps = data_fps[: config.n_files] + n_files = len(data_fps) + + # load PSF + psf_fp = to_absolute_path(config.input.psf) + flip = config.preprocess.flip + dtype = config.input.dtype + print("\nPSF:") + psf, bg = load_psf( + psf_fp, + verbose=True, + downsample=config.preprocess.downsample, + return_bg=True, + flip=flip, + dtype=dtype, + ) + print(f"Downsampled PSF shape: {psf.shape}") + + data_dim = None + if config.preprocess.data_dim is not None: + data_dim = tuple(config.preprocess.data_dim) + (psf.shape[-1],) + else: + data_dim = psf.shape + + # -- create output folder + output_folder = to_absolute_path(config.output_folder) + if algo == "apgd": + output_folder = output_folder + f"_apgd{config.apgd.max_iter}" + elif algo == "admm": + output_folder = output_folder + f"_admm{config.admm.n_iter}" + else: + output_folder = output_folder + "_raw" + output_folder = output_folder + f"_{data_dim[-3]}x{data_dim[-2]}" + os.makedirs(output_folder, exist_ok=True) + + # -- apply reconstruction + if algo == "apgd": + + start_time = time.time() + + def recover(i): + + # reconstruction object + recon = APGD(psf=psf, **config.apgd) + + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + data = data[0] # first depth + + # apply reconstruction + recon.set_data(data) + img = recon.apply( + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + # -- extract region of interest and save + if config.roi is not None: + roi = config.roi + img = img[roi[0] : roi[2], roi[1] : roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + n_jobs = config.apgd.n_jobs + if n_jobs > 1: + Parallel(n_jobs=n_jobs)(delayed(recover)(i) for i in range(n_files)) + else: + for i in tqdm(range(n_files)): + recover(i) + + else: + + if config.torch: + torch_dtype = torch.float32 + torch_device = config.torch_device + psf = torch.from_numpy(psf).type(torch_dtype).to(torch_device) + + # create reconstruction object + recon = None + if config.algo == "admm": + recon = ADMM(psf, **config.admm) + + # loop over files and apply reconstruction + start_time = time.time() + + for i in tqdm(range(n_files)): + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + + if config.torch: + data = torch.from_numpy(data).type(torch_dtype).to(torch_device) + + if recon is not None: + + # set data + recon.set_data(data) + + # apply reconstruction + res = recon.apply( + n_iter=config.admm.n_iter, + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + else: + + # copy resized raw data + res = data + + # save reconstruction as PNG + # -- take first depth + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + # -- extract region of interest + if config.roi is not None: + img = img[config.roi[0] : config.roi[2], config.roi[1] : config.roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + print(f"Processing time : {time.time() - start_time} s") + # time per file + print(f"Time per file : {(time.time() - start_time) / n_files} s") + print("Files saved to: ", output_folder) + + +if __name__ == "__main__": + admm_dataset() diff --git a/scripts/sim/dataset.py b/scripts/sim/dataset.py index 2c08ba71..263d01f2 100644 --- a/scripts/sim/dataset.py +++ b/scripts/sim/dataset.py @@ -32,7 +32,7 @@ def simulate(config): if not os.path.isdir(dataset): print(f"No dataset found at {dataset}") try: - from torchvision.datasets.utils import download_and_extract_archive, download_url + from torchvision.datasets.utils import download_and_extract_archive except ImportError: exit() msg = "Do you want to download the sample CelebA dataset (764KB)?" diff --git a/scripts/sim/digicam_psf.py b/scripts/sim/digicam_psf.py new file mode 100644 index 00000000..d0e0636b --- /dev/null +++ b/scripts/sim/digicam_psf.py @@ -0,0 +1,154 @@ +import numpy as np +import os +import time +import hydra +import torch +from hydra.utils import to_absolute_path +import matplotlib.pyplot as plt +from slm_controller import slm +from lensless.utils.io import save_image, get_dtype, load_psf +from lensless.utils.plot import plot_image +from lensless.hardware.sensor import VirtualSensor +from lensless.hardware.slm import get_programmable_mask, get_intensity_psf +from waveprop.devices import slm_dict +from PIL import Image + + +@hydra.main(version_base=None, config_path="../../configs", config_name="sim_digicam_psf") +def digicam_psf(config): + + output_folder = os.getcwd() + + fp = to_absolute_path(config.digicam.pattern) + bn = os.path.basename(fp).split(".")[0] + + # digicam config + ap_center = np.array(config.digicam.ap_center) + ap_shape = np.array(config.digicam.ap_shape) + rotate_angle = config.digicam.rotate + slm_param = slm_dict[config.digicam.slm] + sensor = VirtualSensor.from_name(config.digicam.sensor) + + # simulation parameters + scene2mask = config.sim.scene2mask + mask2sensor = config.sim.mask2sensor + + torch_device = config.torch_device + dtype = get_dtype(config.dtype, config.use_torch) + + """ + Load pattern + """ + pattern = np.load(fp) + + # -- apply aperture + aperture = np.zeros(pattern.shape, dtype=np.uint8) + top_left = np.array(ap_center) - np.array(ap_shape) // 2 + bottom_right = top_left + np.array(ap_shape) + aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + pattern = pattern * aperture + + # -- extract aperture region + idx_1 = ap_center[0] - ap_shape[0] // 2 + idx_2 = ap_center[1] - ap_shape[1] // 2 + + pattern_sub = pattern[ + :, + idx_1 : idx_1 + ap_shape[0], + idx_2 : idx_2 + ap_shape[1], + ] + print("Controllable region shape: ", pattern_sub.shape) + print("Total number of pixels: ", np.prod(pattern_sub.shape)) + + # -- plot full + s = slm.create(config.digicam.slm) + s.set_preview(True) + s.imshow(pattern) + plt.savefig(os.path.join(output_folder, "pattern.png")) + + # -- plot sub pattern + plt.imshow(pattern_sub.transpose(1, 2, 0)) + plt.savefig(os.path.join(output_folder, "pattern_sub.png")) + + """ + Simulate PSF + """ + start_time = time.time() + slm_vals = pattern_sub / 255.0 + + if config.digicam.slm == "adafruit": + # flatten color channel along rows + slm_vals = slm_vals.reshape((-1, slm_vals.shape[-1]), order="F") + + if config.use_torch: + slm_vals = torch.from_numpy(slm_vals).to(device=torch_device, dtype=dtype) + else: + slm_vals = slm_vals.astype(dtype) + + mask = get_programmable_mask( + vals=slm_vals, + sensor=sensor, + slm_param=slm_param, + rotate=rotate_angle, + flipud=config.sim.flipud, + ) + + # -- plot mask + if config.use_torch: + mask_np = mask.cpu().detach().numpy() + else: + mask_np = mask.copy() + mask_np = np.transpose(mask_np, (1, 2, 0)) + plt.imshow(mask_np) + plt.savefig(os.path.join(output_folder, "mask.png")) + + # -- propagate to sensor + psf_in = get_intensity_psf( + mask=mask, + sensor=sensor, + waveprop=config.sim.waveprop, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + ) + + # -- plot PSF + if config.use_torch: + psf_in_np = psf_in.cpu().detach().numpy() + else: + psf_in_np = psf_in.copy() + psf_in_np = np.transpose(psf_in_np, (1, 2, 0)) + + # plot + psf_meas = None + if config.digicam.psf is not None: + fp_psf = to_absolute_path(config.digicam.psf) + if os.path.exists(fp_psf): + psf_meas = load_psf(fp_psf) + else: + print("Could not load PSF image from: ", fp_psf) + + fp = os.path.join(output_folder, "psf_plot.png") + if psf_meas is not None: + _, ax = plt.subplots(1, 2) + ax[0].imshow(psf_in_np) + ax[0].set_title("Simulated") + plot_image(psf_meas, gamma=config.digicam.gamma, normalize=True, ax=ax[1]) + # ax[1].imshow(psf_meas) + ax[1].set_title("Measured") + plt.savefig(fp) + else: + plt.imshow(psf_in_np) + plt.savefig(fp) + + # save PSF as png + fp = os.path.join(output_folder, f"{bn}_SIM_psf.png") + save_image(psf_in_np, fp) + + proc_time = time.time() - start_time + print(f"\nProcessing time: {proc_time:.2f} seconds") + + print(f"\nFiles saved to : {output_folder}") + + +if __name__ == "__main__": + digicam_psf() From f226b97b6e28df6160d1400601af20d2e9101653 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 16 Aug 2023 15:33:10 +0200 Subject: [PATCH 03/12] Unrolled docs (#72) * Updating docs. * Update docs --- lensless/recon/recon.py | 10 +++++++++- lensless/recon/trainable_recon.py | 2 -- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index c035f4cd..1124c289 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -10,7 +10,7 @@ ============== The core algorithmic component of ``LenslessPiCam`` is the abstract -class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction +class :py:class:`~lensless.ReconstructionAlgorithm`. The five reconstruction strategies available in ``LenslessPiCam`` derive from this class: - :py:class:`~lensless.GradientDescent`: projected gradient descent with a @@ -25,6 +25,14 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction long as it is compatible with Pycsou, namely derives from one of `DiffFunc `_ or `ProxFunc `_. +- :py:class:`~lensless.UnrolledFISTA`: unrolled FISTA with a non-negativity constraint. +- :py:class:`~lensless.UnrolledADMM`: unrolled ADMM with a non-negativity constraint and a total variation (TV) regularizer [1]_. + +Note that the unrolled algorithms derive from the abstract class +:py:class:`~lensless.TrainableReconstructionAlgorithm`, which itself derives from +:py:class:`~lensless.ReconstructionAlgorithm` while adding functionality +for training on batches and adding trainable pre- and post-processing +blocks. New reconstruction algorithms can be conveniently implemented by deriving from the abstract class and defining the following abstract diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index c7129a3b..e554f6b0 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -5,7 +5,6 @@ # Yohann PERRON [yohann.perron@gmail.com] # ############################################################################# -import abc from lensless.recon.recon import ReconstructionAlgorithm try: @@ -24,7 +23,6 @@ class TrainableReconstructionAlgorithm(ReconstructionAlgorithm, torch.nn.Module) * ``_update``: updating state variables at each iterations. * ``reset``: reset state variables. * ``_form_image``: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping. - * ``batch_call``: method for performing iterative reconstruction on a batch of images. One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the From bba42c0ce588e2d6484328f26cad1391d6e34b54 Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Tue, 29 Aug 2023 22:17:49 +0200 Subject: [PATCH 04/12] Improved dataset (#68) * New simulated dataset (moved old dataset) * Move dataset to utils * Added parent class DualDataset * Use new dataset structure for training * Fix doc and bugs * New dataset for lensless only * Fixes for downscaling * Update change * Disclaimer for LenslessDataset * Added header * Updated documentation * Fix typos and wording. * Move dataset docs to data section. * Fixed docstring * Fix for flip in simulated dataset * Add wrapper arounf FarFieldSimulator * Fix import error * Fix docstrings * FIx typos. * Fix doc rendering of FarFieldSimulator. * Refactor. * Refactor. * Fix import. * Refactor and rephrase for clearer dataset diff * Fixed no attribute psf * add new simulation to training script * Remove print. * Update changelog. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 3 + docs/requirements.txt | 3 +- docs/source/conf.py | 8 +- docs/source/dataset.rst | 27 ++ docs/source/evaluation.rst | 4 - docs/source/index.rst | 1 + docs/source/simulation.rst | 12 + lensless/eval/benchmark.py | 209 +-------------- lensless/utils/dataset.py | 448 ++++++++++++++++++++++++++++++++ lensless/utils/simulation.py | 100 +++++++ scripts/eval/benchmark_recon.py | 3 +- scripts/recon/train_unrolled.py | 36 +-- 12 files changed, 620 insertions(+), 234 deletions(-) create mode 100644 docs/source/dataset.rst create mode 100644 lensless/utils/dataset.py create mode 100644 lensless/utils/simulation.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 847fa0f7..99be0bb1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,9 @@ Added - Script for measuring arbitrary dataset (from Raspberry Pi). - Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fix postprocessing can be used. - Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``. +- Unified interface for dataset. See ``utils.dataset.DualDataset``. +- New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedFarFieldDataset``. +- New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA). diff --git a/docs/requirements.txt b/docs/requirements.txt index f8146fac..e105c9f7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,4 +4,5 @@ docutils==0.16 # >0.17 doesn't render bullets numpy>=1.22 # so that default dtype are correctly rendered torch>=1.10 torchvision>=0.15.2 -torchmetrics>=0.11.4 \ No newline at end of file +torchmetrics>=0.11.4 +waveprop>=0.0.5 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index fc01f75b..02d3e0b0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,14 +28,14 @@ "pycsou.util", "pycsou.util.ptype", "PIL", + "PIL.Image", "tqdm", "paramiko", "paramiko.ssh_exception", "perlin_numpy", - "waveprop", - "waveprop.fresnel", - "waveprop.rs", - "waveprop.noise", + "scipy.special", + "matplotlib.cm", + "pyffs", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst new file mode 100644 index 00000000..1312e1cc --- /dev/null +++ b/docs/source/dataset.rst @@ -0,0 +1,27 @@ +Dataset objects (for training and testing) +========================================== + +The software below provides functionality (with PyTorch) to load +datasets for training and testing. + +.. automodule:: lensless.utils.dataset + +.. autoclass:: lensless.utils.dataset.DualDataset + :members: _get_images_pair + :special-members: __init__, __len__ + +.. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ diff --git a/docs/source/evaluation.rst b/docs/source/evaluation.rst index f3f381d2..0f2c9d93 100644 --- a/docs/source/evaluation.rst +++ b/docs/source/evaluation.rst @@ -23,8 +23,4 @@ .. automodule:: lensless.eval.benchmark - .. autoclass:: lensless.eval.benchmark.ParallelDataset - :members: - :special-members: __init__ - .. autofunction:: lensless.eval.benchmark.benchmark \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 94c236e6..3fba13d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,7 @@ Contents simulation data + dataset .. toctree:: :hidden: diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index d5ecaa34..12739ad2 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -16,6 +16,18 @@ library is used with the following simulation steps: PyTorch support is available to speed up simulation on GPU, and to create Dataset and DataLoader objects for training and testing! +FarFieldSimulator +------------------ + +A wrapper around `waveprop.simulation.FarFieldSimulator `__ +is implemented as :py:class:`lensless.utils.simulation.FarFieldSimulator`. +It handles the conversion between the HWC and CHW dimension orderings so that the convention of LenslessPiCam can be maintained (namely HWC). + +.. autoclass:: lensless.utils.simulation.FarFieldSimulator + :members: + :special-members: __init__ + + Simulating 3D data ------------------ diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index b4aa6b79..2f78f402 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -7,18 +7,14 @@ # ############################################################################# -import glob -import os -from lensless.utils.io import load_psf -from lensless.utils.image import resize -import numpy as np +from lensless.utils.dataset import DiffuserCamTestDataset from tqdm import tqdm from lensless.utils.io import load_image try: import torch - from torch.utils.data import Dataset, DataLoader + from torch.utils.data import DataLoader from torch.nn import MSELoss, L1Loss from torchmetrics import StructuralSimilarityIndexMeasure from torchmetrics.image import lpip, psnr @@ -28,207 +24,6 @@ ) -class ParallelDataset(Dataset): - """ - Dataset consisting of lensless and corresponding lensed image. - - It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images. - - """ - - def __init__( - self, - root_dir, - n_files=False, - background=None, - downsample=4, - flip=False, - transform_lensless=None, - transform_lensed=None, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - **kwargs, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - - root_dir : str - Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images. - n_files : int or None, optional - Metrics will be computed only on the first ``n_files`` images. If None, all images are used, by default False - background : :py:class:`~torch.Tensor` or None, optional - If not ``None``, background is removed from lensless images, by default ``None``. - downsample : int, optional - Downsample factor of the lensless images, by default 4. - flip : bool, optional - If ``True``, lensless images are flipped, by default ``False``. - transform_lensless : PyTorch Transform or None, optional - Transform to apply to the lensless images, by default None - transform_lensed : PyTorch Transform or None, optional - Transform to apply to the lensed images, by default None - lensless_fn : str, optional - Name of the folder containing the lensless images, by default "diffuser". - lensed_fn : str, optional - Name of the folder containing the lensed images, by default "lensed". - image_ext : str, optional - Extension of the images, by default "npy". - """ - - self.root_dir = root_dir - self.lensless_dir = os.path.join(root_dir, lensless_fn) - self.lensed_dir = os.path.join(root_dir, lensed_fn) - self.image_ext = image_ext.lower() - - files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) - if n_files: - files = files[:n_files] - self.files = [os.path.basename(fn) for fn in files] - - if len(self.files) == 0: - raise FileNotFoundError( - f"No files found in {self.lensless_dir} with extension {image_ext}" - ) - - self.background = background - self.downsample = downsample / 4 - self.flip = flip - self.transform_lensless = transform_lensless - self.transform_lensed = transform_lensed - - def __len__(self): - return len(self.files) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - - if self.image_ext == "npy": - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = np.load(lensless_fp) - lensed = np.load(lensed_fp) - else: - # more standard image formats: png, jpg, tiff, etc. - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = load_image(lensless_fp) - lensed = load_image(lensed_fp) - - # convert to float - if lensless.dtype == np.uint8: - lensless = lensless.astype(np.float32) / 255 - lensed = lensed.astype(np.float32) / 255 - else: - # 16 bit - lensless = lensless.astype(np.float32) / 65535 - lensed = lensed.astype(np.float32) / 65535 - - if self.downsample != 1.0: - lensless = resize(lensless, factor=1 / self.downsample) - lensed = resize(lensed, factor=1 / self.downsample) - - lensless = torch.from_numpy(lensless) - lensed = torch.from_numpy(lensed) - - # If [H, W, C] -> [D, H, W, C] - if len(lensless.shape) == 3: - lensless = lensless.unsqueeze(0) - if len(lensed.shape) == 3: - lensed = lensed.unsqueeze(0) - - if self.background is not None: - lensless = lensless - self.background - - # flip image x and y if needed - if self.flip: - lensless = torch.rot90(lensless, dims=(-3, -2)) - lensed = torch.rot90(lensed, dims=(-3, -2)) - if self.transform_lensless: - lensless = self.transform_lensless(lensless) - - if self.transform_lensed: - lensed = self.transform_lensed(lensed) - - return lensless, lensed - - -class DiffuserCamTestDataset(ParallelDataset): - """ - Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. - """ - - def __init__( - self, - data_dir="data", - n_files=200, - downsample=8, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - data_dir : str, optional - The path to the folder containing the DiffuserCam_Test dataset, by default "data" - n_files : int, optional - Number of image pair to load in the dataset , by default 200 - downsample : int, optional - Downsample factor of the lensless images, by default 8 - """ - # download dataset if necessary - main_dir = data_dir - data_dir = os.path.join(data_dir, "DiffuserCam_Test") - if not os.path.isdir(data_dir): - print("No dataset found for benchmarking.") - try: - from torchvision.datasets.utils import download_and_extract_archive - except ImportError: - exit() - msg = "Do you want to download the sample dataset (3.5GB)?" - - # default to yes if no input is given - valid = input("%s (Y/n) " % msg).lower() != "n" - if valid: - url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" - filename = "DiffuserCam_Test.zip" - download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) - - psf_fp = os.path.join(data_dir, "psf.tiff") - psf, background = load_psf( - psf_fp, - downsample=downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - - # transform from BGR to RGB - from torchvision import transforms - - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - self.psf = transform_BRG2RGB(torch.from_numpy(psf)) - - super().__init__( - data_dir, - n_files, - background, - downsample, - flip=False, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - ) - - def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): """ Compute multiple metrics for a reconstruction algorithm. diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py new file mode 100644 index 00000000..2634cb7c --- /dev/null +++ b/lensless/utils/dataset.py @@ -0,0 +1,448 @@ +# ############################################################################# +# dataset.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# ############################################################################# + +import numpy as np +import glob +import os +import torch +from abc import abstractmethod +from torch.utils.data import Dataset +from torchvision import transforms +from lensless.utils.simulation import FarFieldSimulator +from lensless.utils.io import load_image, load_psf +from lensless.utils.image import resize + + +class DualDataset(Dataset): + """ + Abstract class for defining a dataset of paired lensed and lensless images. + """ + + def __init__( + self, + indices=None, + background=None, + downsample=1, + flip=False, + transform_lensless=None, + transform_lensed=None, + **kwargs, + ): + """ + Dataset consisting of lensless and corresponding lensed image. + + Parameters + ---------- + indices : range or int or None + Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. + background : :py:class:`~torch.Tensor` or None, optional + If not ``None``, background is removed from lensless images, by default ``None``. + downsample : int, optional + Downsample factor of the lensless images, by default 1. + flip : bool, optional + If ``True``, lensless images are flipped, by default ``False``. + transform_lensless : PyTorch Transform or None, optional + Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + transform_lensed : PyTorch Transform or None, optional + Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + """ + if isinstance(indices, int): + indices = range(indices) + self.indices = indices + self.background = background + self.downsample = downsample + self.flip = flip + self.transform_lensless = transform_lensless + self.transform_lensed = transform_lensed + + @abstractmethod + def __len__(self): + """ + Abstract method to get the length of the dataset. It should take into account the indices parameter. + """ + raise NotImplementedError + + @abstractmethod + def _get_images_pair(self, idx): + """ + Abstract method to get the lensed and lensless images. Should return a pair (lensless, lensed) of numpy arrays with values in [0,1]. + + Parameters + ---------- + idx : int + images index + """ + raise NotImplementedError + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.item() + + if self.indices is not None: + idx = self.indices[idx] + lensless, lensed = self._get_images_pair(idx) + + if isinstance(lensless, np.ndarray): + # expected case + if self.downsample != 1.0: + lensless = resize(lensless, factor=1 / self.downsample) + lensed = resize(lensed, factor=1 / self.downsample) + + lensless = torch.from_numpy(lensless) + lensed = torch.from_numpy(lensed) + else: + # torch tensor + # This mean get_images_pair returned a torch tensor. This isn't recommended, if possible get_images_pair should return a numpy array + # In this case it should also have applied the downsampling + pass + + # If [H, W, C] -> [D, H, W, C] + if len(lensless.shape) == 3: + lensless = lensless.unsqueeze(0) + if len(lensed.shape) == 3: + lensed = lensed.unsqueeze(0) + + if self.background is not None: + lensless = lensless - self.background + + # flip image x and y if needed + if self.flip: + lensless = torch.rot90(lensless, dims=(-3, -2)) + lensed = torch.rot90(lensed, dims=(-3, -2)) + if self.transform_lensless: + lensless = self.transform_lensless(lensless) + if self.transform_lensed: + lensed = self.transform_lensed(lensed) + + return lensless, lensed + + +class SimulatedFarFieldDataset(DualDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset. :py:class:`lensless.utils.simulation.FarFieldSimulator` is used for simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + + """ + + def __init__( + self, + dataset, + simulator, + pre_transform=None, + dataset_is_CHW=False, + flip=False, + **kwargs, + ): + """ + Parameters + ---------- + + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``.Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + pre_transform : PyTorch Transform or None, optional + Transform to apply to the images before simulation, by default ``None``. Note that this transform is applied on HCW images (different from torchvision). + dataset_is_CHW : bool, optional + If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``. + flip : bool, optional + If True, images are flipped beffore the simulation, by default ``False``.. + """ + + # we do the flipping before the simualtion + super(SimulatedFarFieldDataset, self).__init__(flip=False, **kwargs) + + assert isinstance(dataset, Dataset) + self.dataset = dataset + self.n_files = len(dataset) + + self.dataset_is_CHW = dataset_is_CHW + self._pre_transform = pre_transform + self.flip_pre_sim = flip + + # check simulator + assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator" + assert simulator.is_torch, "Simulator should be a pytorch simulator" + assert simulator.fft_shape is not None, "Simulator should have a psf" + self.sim = simulator + + def get_image(self, index): + return self.dataset[index] + + def _get_images_pair(self, index): + # load image + img, _ = self.get_image(index) + # convert to CHW for simulator and transform + if self.dataset_is_CHW: + img = img.moveaxis(-3, -1) + if self.flip_pre_sim: + img = torch.rot90(img, dims=(-3, -2)) + if self._pre_transform is not None: + img = self._pre_transform(img) + + lensless, lensed = self.sim.propagate(img, return_object_plane=True) + + return lensless, lensed + + def __len__(self): + if self.indices is None: + return self.n_files + else: + return len([x for x in self.indices if x < self.n_files]) + + +class MeasuredDatasetSimulatedOriginal(DualDataset): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on the screen. + Unlike :py:class:`lensless.utils.dataset.MeasuredDataset`, the ground-truth lensed image is simulated using a :py:class:`lensless.utils.simulation.FarFieldSimulator` + object rather than measured with a lensed camera. + """ + + def __init__( + self, + root_dir, + simulator, + lensless_fn="diffuser", + original_fn="lensed", + image_ext="npy", + original_ext=None, + downsample=1, + **kwargs, + ): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on screen. + + Parameters + ---------- + root_dir : str + Path to the test dataset. It is expected to contain two folders: one of lensless images and one of original images. + simulator : :py:class:`lensless.utils.simulatorFarFieldSimulator` + Simulator to use for the projection of the original image to object space. The PSF **should not** be specified, and it is expect to have ``is_torch = True``. + lensless_fn : str, optional + Name of the folder containing the lensless images, by default "diffuser". + lensed_fn : str, optional + Name of the folder containing the lensed images, by default "lensed". + image_ext : str, optional + Extension of the images, by default "npy". + original_ext : str, optional + Extension of the original image if different from lenless, by default None. + downsample : int, optional + Downsample factor of the lensless images, by default 1. + """ + super(MeasuredDatasetSimulatedOriginal, self).__init__(downsample=1, **kwargs) + self.pre_downsample = downsample + + self.root_dir = root_dir + self.lensless_dir = os.path.join(root_dir, lensless_fn) + self.original_dir = os.path.join(root_dir, original_fn) + self.image_ext = image_ext.lower() + self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower() + + files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) + files.sort() + self.files = [os.path.basename(fn) for fn in files] + + if len(self.files) == 0: + raise FileNotFoundError( + f"No files found in {self.lensless_dir} with extension {image_ext}" + ) + + # check simulator + assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator" + assert simulator.is_torch, "Simulator should be a pytorch simulator" + assert simulator.fft_shape is None, "Simulator should not have a psf" + self.sim = simulator + + def __len__(self): + if self.indices is None: + return len(self.files) + else: + return len([i for i in self.indices if i < len(self.files)]) + + def _get_images_pair(self, idx): + if self.image_ext == "npy" or self.image_ext == "npz": + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + original_fp = os.path.join(self.original_dir, self.files[idx]) + lensless = np.load(lensless_fp) + lensless = resize(lensless, factor=1 / self.downsample) + original = np.load(original_fp[:-3] + self.original_ext) + else: + # more standard image formats: png, jpg, tiff, etc. + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + original_fp = os.path.join(self.original_dir, self.files[idx]) + lensless = load_image(lensless_fp, downsample=self.pre_downsample) + original = load_image( + original_fp[:-3] + self.original_ext, downsample=self.pre_downsample + ) + + # convert to float + if lensless.dtype == np.uint8: + lensless = lensless.astype(np.float32) / 255 + original = original.astype(np.float32) / 255 + else: + # 16 bit + lensless = lensless.astype(np.float32) / 65535 + original = original.astype(np.float32) / 65535 + + # convert to torch + lensless = torch.from_numpy(lensless) + original = torch.from_numpy(original) + + # project original image to lensed space + with torch.no_grad(): + lensed = self.sim.propagate() + + return lensless, lensed + + +class MeasuredDataset(DualDataset): + """ + Dataset consisting of lensless and corresponding lensed image. + It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images. + Unless the setup is perfectly calibrated, one should expect to have to use ``transform_lensed`` to adjust the alignment and rotation. + """ + + def __init__( + self, + root_dir, + lensless_fn="diffuser", + lensed_fn="lensed", + image_ext="npy", + **kwargs, + ): + """ + Dataset consisting of lensless and corresponding lensed image. Default parameters are for the + `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_. + + Parameters + ---------- + root_dir : str + Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images. + lensless_fn : str, optional + Name of the folder containing the lensless images, by default "diffuser". + lensed_fn : str, optional + Name of the folder containing the lensed images, by default "lensed". + image_ext : str, optional + Extension of the images, by default "npy". + """ + + super(MeasuredDataset, self).__init__(**kwargs) + + self.root_dir = root_dir + self.lensless_dir = os.path.join(root_dir, lensless_fn) + self.lensed_dir = os.path.join(root_dir, lensed_fn) + self.image_ext = image_ext.lower() + + files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) + files.sort() + self.files = [os.path.basename(fn) for fn in files] + + if len(self.files) == 0: + raise FileNotFoundError( + f"No files found in {self.lensless_dir} with extension {image_ext}" + ) + + def __len__(self): + if self.indices is None: + return len(self.files) + else: + return len([i for i in self.indices if i < len(self.files)]) + + def _get_images_pair(self, idx): + if self.image_ext == "npy" or self.image_ext == "npz": + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) + lensless = np.load(lensless_fp) + lensed = np.load(lensed_fp) + else: + # more standard image formats: png, jpg, tiff, etc. + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) + lensless = load_image(lensless_fp) + lensed = load_image(lensed_fp) + + # convert to float + if lensless.dtype == np.uint8: + lensless = lensless.astype(np.float32) / 255 + lensed = lensed.astype(np.float32) / 255 + else: + # 16 bit + lensless = lensless.astype(np.float32) / 65535 + lensed = lensed.astype(np.float32) / 65535 + + return lensless, lensed + + +class DiffuserCamTestDataset(MeasuredDataset): + """ + Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. + """ + + def __init__( + self, + data_dir="data", + n_files=200, + downsample=2, + ): + """ + Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of + `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_. + + Parameters + ---------- + data_dir : str, optional + The path to the folder containing the DiffuserCam_Test dataset, by default "data". + n_files : int, optional + Number of image pairs to load in the dataset , by default 200. + downsample : int, optional + Downsample factor of the lensless images, by default 8. + """ + + # download dataset if necessary + main_dir = data_dir + data_dir = os.path.join(data_dir, "DiffuserCam_Test") + if not os.path.isdir(data_dir): + print("No dataset found for benchmarking.") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the sample dataset (3.5GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" + filename = "DiffuserCam_Test.zip" + download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) + + psf_fp = os.path.join(data_dir, "psf.tiff") + psf, background = load_psf( + psf_fp, + downsample=downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + + # transform from BGR to RGB + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + + super().__init__( + root_dir=data_dir, + indices=range(n_files), + background=background, + downsample=downsample / 4, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser", + lensed_fn="lensed", + image_ext="npy", + ) diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py new file mode 100644 index 00000000..36aac243 --- /dev/null +++ b/lensless/utils/simulation.py @@ -0,0 +1,100 @@ +# ############################################################################# +# simulation.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# ############################################################################# + +import numpy as np +from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp + + +class FarFieldSimulator(FarFieldSimulator_wp): + """ + LenslessPiCam-compatible wrapper for :py:class:`~waveprop.simulation.FarFieldSimulator` (source code on `GitHub `__). + """ + + def __init__( + self, + object_height, + scene2mask, + mask2sensor, + sensor, + psf=None, + output_dim=None, + snr_db=None, + max_val=255, + device_conv="cpu", + random_shift=False, + is_torch=False, + **kwargs + ): + """ + Parameters + ---------- + psf : np.ndarray, optional. + Point spread function. If not provided, return image at object plane. + object_height : float or (float, float) + Height of object in meters. Or range of values to randomly sample from. + scene2mask : float + Distance from scene to mask in meters. + mask2sensor : float + Distance from mask to sensor in meters. + sensor : str + Sensor name. + snr_db : float, optional + Signal-to-noise ratio in dB, by default None. + max_val : int, optional + Maximum value of image, by default 255. + device_conv : str, optional + Device to use for convolution (when using pytorch), by default "cpu". + random_shift : bool, optional + Whether to randomly shift the image, by default False. + is_torch : bool, optional + Whether to use pytorch, by default False. + """ + + if psf is not None: + # convert HWC to CHW + psf = psf.squeeze().movedim(-1, 0) + + super().__init__( + object_height, + scene2mask, + mask2sensor, + sensor, + psf, + output_dim, + snr_db, + max_val, + device_conv, + random_shift, + is_torch, + **kwargs + ) + + def propagate(self, obj, return_object_plane=False): + """ + Parameters + ---------- + obj : np.ndarray or torch.Tensor + Single image to propagate at format HWC. + return_object_plane : bool, optional + Whether to return object plane, by default False. + """ + if self.is_torch: + obj = obj.moveaxis(-1, 0) + res = super().propagate(obj, return_object_plane) + if isinstance(res, tuple): + res = res[0].moveaxis(-3, -1), res[1].moveaxis(-3, -1) + else: + res = res.moveaxis(-3, -1) + return res + else: + obj = np.moveaxis(obj, -1, 0) + res = super().propagate(obj, return_object_plane) + if isinstance(res, tuple): + res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1) + else: + res = np.moveaxis(res, -3, -1) + return res diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index de6a1c68..6611ceec 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -20,9 +20,10 @@ import json import os import pathlib as plib -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless.utils.dataset import DiffuserCamTestDataset try: import torch diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 883f1819..a608ce97 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -22,9 +22,10 @@ import time import matplotlib.pyplot as plt from lensless import UnrolledFISTA, UnrolledADMM -from waveprop.dataset_util import SimulatedPytorchDataset +from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset from lensless.utils.image import rgb2gray -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.utils.simulation import FarFieldSimulator +from lensless.eval.benchmark import benchmark import torch from torchvision import transforms, datasets from tqdm import tqdm @@ -58,15 +59,11 @@ def simulate_dataset(config, psf): psf = rgb2gray(psf) if not isinstance(psf, torch.Tensor): psf = transforms.ToTensor()(psf) - elif psf.shape[-1] == 3: - # Waveprop syntetic dataset expect C H W - psf = psf.permute(2, 0, 1) # batch_size = config.files.batch_size batch_size = config.training.batch_size n_files = config.files.n_files device_conv = config.torch_device - target = config.target # check if gpu is available if device_conv == "cuda" and torch.cuda.is_available(): @@ -74,11 +71,17 @@ def simulate_dataset(config, psf): else: device_conv = "cpu" + # create simulator + simulator = FarFieldSimulator( + psf=psf, + is_torch=True, + **config.simulation, + ) # create Pytorch dataset and dataloader if n_files is not None: ds = torch.utils.data.Subset(ds, np.arange(n_files)) - ds_prop = SimulatedPytorchDataset( - dataset=ds, psf=psf, device_conv=device_conv, target=target, **config.simulation + ds_prop = SimulatedFarFieldDataset( + dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv ) ds_loader = torch.utils.data.DataLoader( dataset=ds_prop, batch_size=batch_size, shuffle=True, pin_memory=(psf.device != "cpu") @@ -138,9 +141,6 @@ def train_unrolled( # torch.autograd.set_detect_anomaly(True) - # if using a portrait dataset rotate the PSF - flip = config.files.dataset in ["CelebA"] - # benchmarking dataset: path = os.path.join(get_original_cwd(), "data") benchmark_dataset = DiffuserCamTestDataset( @@ -155,8 +155,6 @@ def train_unrolled( psf = psf[..., [2, 1, 0]] # if using a portrait dataset rotate the PSF - if flip: - psf = torch.rot90(psf, dims=[0, 1]) disp = config.display.disp if disp < 0: @@ -222,17 +220,21 @@ def train_unrolled( # load dataset and create dataloader if config.files.dataset == "DiffuserCam": # Use a ParallelDataset - from lensless.eval.benchmark import ParallelDataset + from lensless.utils.dataset import MeasuredDataset + + max_indices = 30000 + if config.files.n_files is not None: + max_indices = config.files.n_files + 1000 data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") - dataset = ParallelDataset( + dataset = MeasuredDataset( root_dir=data_path, - n_files=config.files.n_files, + indices=range(1000, max_indices), background=background, psf=psf, lensless_fn="diffuser_images", lensed_fn="ground_truth_lensed", - downsample=config.simulation.downsample, + downsample=config.simulation.downsample / 4, transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, ) From 58f747adeba5d0007f6c4484acbc60342387cede Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Wed, 30 Aug 2023 01:04:18 +0200 Subject: [PATCH 05/12] Streamlined training with new Trainer class (#77) * move utility function outside of script * New trainer class for training reconstruction * Update docstring * Update changelog * Update to trainer save * Fix partial mask support bug * Fix docstrings. * Fix APGD rendering. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 1 + docs/requirements.txt | 1 + docs/source/conf.py | 4 + docs/source/reconstruction.rst | 22 +- lensless/eval/benchmark.py | 2 - lensless/recon/utils.py | 410 +++++++++++++++++++++++++++++++- recon_requirements.txt | 1 - scripts/recon/admm.py | 6 +- scripts/recon/train_unrolled.py | 238 ++---------------- setup.py | 1 + 10 files changed, 453 insertions(+), 233 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 99be0bb1..6ca03bd3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,7 @@ Added - New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA). +- New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. Changed diff --git a/docs/requirements.txt b/docs/requirements.txt index e105c9f7..484c5d20 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,4 +5,5 @@ numpy>=1.22 # so that default dtype are correctly rendered torch>=1.10 torchvision>=0.15.2 torchmetrics>=0.11.4 +pyFFS>=2.2.3 # for waveprop waveprop>=0.0.5 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 02d3e0b0..60ee9e96 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,9 @@ "torchmetrics.image", "scipy.ndimage", "pycsou.abc", + "pycsou.operator", "pycsou.operator.func", + "pycsou.operator.linop", "pycsou.opt.solver", "pycsou.opt.stop", "pycsou.runtime", @@ -33,6 +35,8 @@ "paramiko", "paramiko.ssh_exception", "perlin_numpy", + "hydra", + "hydra.utils", "scipy.special", "matplotlib.cm", "pyffs", diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index 27434c40..e5b927f4 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -55,7 +55,7 @@ Accelerated Proximal Gradient Descent (APGD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - .. autoclass:: lensless.APGD + .. autoclass:: lensless.recon.apgd.APGD :special-members: __init__ @@ -88,4 +88,22 @@ .. autoclass:: lensless.UnrolledADMM :members: batch_call :special-members: __init__ - :show-inheritance: \ No newline at end of file + :show-inheritance: + + + Reconstruction Utilities + ------------------------ + + .. autoclass:: lensless.recon.utils.Trainer + :members: + :special-members: __init__ + + .. autofunction:: lensless.recon.utils.load_drunet + + .. autofunction:: lensless.recon.utils.apply_denoiser + + .. autofunction:: lensless.recon.utils.get_drunet_function + + .. autofunction:: lensless.recon.utils.measure_gradient + + .. autofunction:: lensless.recon.utils.create_process_network diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 2f78f402..f93b754d 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -10,8 +10,6 @@ from lensless.utils.dataset import DiffuserCamTestDataset from tqdm import tqdm -from lensless.utils.io import load_image - try: import torch from torch.utils.data import DataLoader diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 7fad0400..54d23a1d 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -1,4 +1,21 @@ +# ############################################################################# +# dataset.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import json +import math +import time +from hydra.utils import get_original_cwd +import os +import matplotlib.pyplot as plt import torch +from lensless.eval.benchmark import benchmark +from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes @@ -17,7 +34,7 @@ def load_drunet(model_path, n_channels=3, requires_grad=False): Returns ------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Loaded model. """ @@ -45,11 +62,11 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Parameters ---------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Drunet compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image both in CHW format. - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Input image. - noise_level : float or :py:class:`~torch.Tensor` + noise_level : float or :py:class:`torch.Tensor` Noise level in the image. device : str Device to use for computation. Can be "cpu" or "cuda". @@ -58,7 +75,7 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Returns ------- - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Reconstructed image. """ # convert from NDHWC to NCHW @@ -108,7 +125,7 @@ def get_drunet_function(model, device="cpu", mode="inference"): Parameters ---------- - model : torch.nn.Module + model : :py:class:`torch.nn.Module` DruNet like denoiser model device : str Device to use for computation. Can be "cpu" or "cuda". @@ -129,3 +146,384 @@ def process(image, noise_level): return image return process + + +def measure_gradient(model): + """ + Helper function to measure L2 norm of the gradient of a model. + + Parameters + ---------- + model : :py:class:`torch.nn.Module` + Model to measure gradient of. + + Returns + ------- + Float + L2 norm of the gradient of the model. + """ + total_norm = 0.0 + for p in model.parameters(): + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm**0.5 + return total_norm + + +def create_process_network(network, depth, device="cpu"): + """ + Helper function to create a process network. + + Parameters + ---------- + network : str + Name of network to use. Can be "DruNet" or "UnetRes". + depth : int + Depth of network. + device : str + Device to use for computation. Can be "cpu" or "cuda". Defaults to "cpu". + + Returns + ------- + :py:class:`torch.nn.Module` + New process network. Already trained for Drunet. + """ + if network == "DruNet": + from lensless.recon.utils import load_drunet + + process = load_drunet( + os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True + ).to(device) + process_name = "DruNet" + elif network == "UnetRes": + from lensless.recon.drunet.network_unet import UNetRes + + n_channels = 3 + process = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=depth, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ).to(device) + process_name = "UnetRes_d" + str(depth) + else: + process = None + process_name = None + + return (process, process_name) + + +class Trainer: + def __init__( + self, + recon, + train_dataset, + test_dataset, + batch_size=4, + loss="l2", + lpips=None, + optimizer="Adam", + optimizer_lr=1e-6, + slow_start=None, + skip_NAN=False, + algorithm_name="Unknown", + ): + """ + Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__. + + Parameters + ---------- + recon : :py:class:`lensless.TrainableReconstructionAlgorithm` + Reconstruction algorithm to train. + train_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for training. + test_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for testing. + batch_size : int, optional + Batch size to use for training, by default 4 + loss : str, optional + Loss function to use for training "l1" or "l2", by default "l2" + lpips : float, optional + the weight of the lpips(VGG) in the total loss. If None ignore. By default None + optimizer : str, optional + Optimizer to use durring training. Available : "Adam". By default "Adam" + optimizer_lr : float, optional + Learning rate for the optimizer, by default 1e-6 + slow_start : float, optional + Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None. + skip_NAN : bool, optional + Whether to skip update if any gradiant are NAN (True) or to throw an error(False), by default False + algorithm_name : str, optional + Algorithm name for logging, by default "Unknown". + + """ + self.device = recon._psf.device + + self.recon = recon + self.train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=(self.device != "cpu"), + ) + self.test_dataset = test_dataset + self.lpips = lpips + self.skip_NAN = skip_NAN + + # loss + if loss == "l2": + self.Loss = torch.nn.MSELoss() + elif loss == "l1": + self.Loss = torch.nn.L1Loss() + else: + raise ValueError(f"Unsuported loss : {loss}") + + # Lpips loss + if lpips: + try: + import lpips + + self.Loss_lpips = lpips.LPIPS(net="vgg").to(self.device) + except ImportError: + return ImportError( + "lpips package is need for LPIPS loss. Install using : pip install lpips" + ) + + # optimizer + if optimizer == "Adam": + # the parameters of the base model and non torch.Module process must be added separatly + parameters = [{"params": recon.parameters()}] + self.optimizer = torch.optim.Adam(parameters, lr=optimizer_lr) + else: + raise ValueError(f"Unsuported optimizer : {optimizer}") + # Scheduler + if slow_start: + + def learning_rate_function(epoch): + if epoch == 0: + return slow_start + elif epoch == 1: + return math.sqrt(slow_start) + else: + return 1 + + else: + + def learning_rate_function(epoch): + return 1 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=learning_rate_function + ) + + self.metrics = { + "LOSS": [], + "MSE": [], + "MAE": [], + "LPIPS_Vgg": [], + "LPIPS_Alex": [], + "PSNR": [], + "SSIM": [], + "ReconstructionError": [], + "n_iter": self.recon._n_iter, + "algorithm": algorithm_name, + } + + # Backward hook that detect NAN in the gradient and print the layer weights + if not self.skip_NAN: + + def detect_nan(grad): + if torch.isnan(grad).any(): + print(grad, flush=True) + for name, param in recon.named_parameters(): + if param.requires_grad: + print(name, param) + raise ValueError("Gradient is NaN") + return grad + + for param in recon.parameters(): + if param.requires_grad: + param.register_hook(detect_nan) + if param.requires_grad: + param.register_hook(detect_nan) + + def train_epoch(self, data_loader, disp=-1): + """ + Train for one epoch. + + Parameters + ---------- + data_loader : :py:class:`torch.utils.data.DataLoader` + Data loader to use for training. + disp : int, optional + Display interval, if -1, no display, by default -1 + + Returns + ------- + float + Mean loss of the epoch. + """ + mean_loss = 0.0 + i = 1.0 + pbar = tqdm(data_loader) + for X, y in pbar: + # send to device + X = X.to(self.device) + y = y.to(self.device) + + y_pred = self.recon.batch_call(X.to(self.device)) + # normalizing each output + eps = 1e-12 + y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps + y_pred = y_pred / y_pred_max + + # normalizing y + y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps + y = y / y_max + + if i % disp == 1: + img_pred = y_pred[0, 0].cpu().detach().numpy() + img_truth = y[0, 0].cpu().detach().numpy() + + plt.imshow(img_pred) + plt.savefig(f"y_pred_{i-1}.png") + plt.imshow(img_truth) + plt.savefig(f"y_{i-1}.png") + + self.optimizer.zero_grad(set_to_none=True) + # convert to CHW for loss and remove depth + y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) + y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) + + loss_v = self.Loss(y_pred, y) + if self.lpips: + # value for LPIPS needs to be in range [-1, 1] + loss_v = loss_v + self.lpips * torch.mean( + self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) + ) + loss_v.backward() + + torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0) + + # if any gradient is NaN, skip training step + if self.skip_NAN: + is_NAN = False + for param in self.recon.parameters(): + if torch.isnan(param.grad).any(): + is_NAN = True + break + if is_NAN: + print("NAN detected in gradiant, skipping training step") + i += 1 + continue + self.optimizer.step() + + mean_loss += (loss_v.item() - mean_loss) * (1 / i) + pbar.set_description(f"loss : {mean_loss}") + i += 1 + + return mean_loss + + def evaluate(self, mean_loss, save_pt): + """ + Evaluate the reconstruction algorithm on the test dataset. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. + save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + """ + if self.test_dataset is None: + return + # benchmarking + current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10) + + # update metrics with current metrics + self.metrics["LOSS"].append(mean_loss) + for key in current_metrics: + self.metrics[key].append(current_metrics[key]) + + if save_pt: + # save dictionary metrics to file with json + with open(os.path.join(save_pt, "metrics.json"), "w") as f: + json.dump(self.metrics, f) + + def on_epoch_end(self, mean_loss, save_pt): + """ + Called at the end of each epoch. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. + save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + """ + if save_pt is None: + # Use current directory + save_pt = os.getcwd() + + # save model + self.save(path=save_pt, include_optimizer=False) + self.evaluate(mean_loss, save_pt) + + def train(self, n_epoch=1, save_pt=None, disp=-1): + """ + Train the reconstruction algorithm. + + Parameters + ---------- + n_epoch : int, optional + Number of epochs to train for, by default 1 + save_pt : str, optional + Path to save metrics dictionary to. If None, use current directory, by default None + disp : int, optional + Display interval, if -1, no display. Default is -1. + """ + + start_time = time.time() + + for epoch in range(n_epoch): + print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + mean_loss = self.train_epoch(self.train_dataloader, disp=disp) + self.on_epoch_end(mean_loss, save_pt) + self.scheduler.step() + + print(f"Train time : {time.time() - start_time} s") + + def save(self, path="recon", include_optimizer=False): + """ + Save state of reconstruction algorithm. + + Parameters + ---------- + path : str, optional + Path to save model to, by default "recon" + include_optimizer : bool, optional + Whether to include optimizer state, by default False + + """ + # create directory if it does not exist + if not os.path.exists(path): + os.makedirs(path) + + # TODO : ADD mask support + # # save mask + # if self.use_mask: + # torch.save(self.mask._mask, os.path.join(path, "mask.pt")) + # torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) + # import matplotlib.pyplot as plt + + # plt.imsave( + # os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] + # ) + # save optimizer + if include_optimizer: + torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) + # save recon + torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt")) diff --git a/recon_requirements.txt b/recon_requirements.txt index 4ebe4412..b9e9f324 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -2,7 +2,6 @@ jedi==0.18.0 lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 -hydra-core click>=8.0.1 waveprop>=0.0.3 # for simulation diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 3ba3de1f..2a053722 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -58,14 +58,14 @@ def admm(config): else: assert config.torch, "Unrolled ADMM only works with torch" from lensless.recon.unrolled_admm import UnrolledADMM - import train_unrolled + import lensless.recon.utils - pre_process = train_unrolled.create_process_network( + pre_process = lensless.recon.utils.create_process_network( network=config.admm.pre_process_model.network, depth=config.admm.pre_process_depth.depth, device=config.torch_device, ) - post_process = train_unrolled.create_process_network( + post_process = lensless.recon.utils.create_process_network( network=config.admm.post_process_model.network, depth=config.admm.post_process_depth.depth, device=config.torch_device, diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index a608ce97..7d0a31e1 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -14,26 +14,19 @@ """ -import math import hydra from hydra.utils import get_original_cwd import os import numpy as np import time -import matplotlib.pyplot as plt from lensless import UnrolledFISTA, UnrolledADMM from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset +from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray from lensless.utils.simulation import FarFieldSimulator -from lensless.eval.benchmark import benchmark +from lensless.recon.utils import Trainer import torch from torchvision import transforms, datasets -from tqdm import tqdm - -try: - import json -except ImportError: - print("json package not found, metrics will not be saved") def simulate_dataset(config, psf): @@ -60,8 +53,6 @@ def simulate_dataset(config, psf): if not isinstance(psf, torch.Tensor): psf = transforms.ToTensor()(psf) - # batch_size = config.files.batch_size - batch_size = config.training.batch_size n_files = config.files.n_files device_conv = config.torch_device @@ -83,49 +74,7 @@ def simulate_dataset(config, psf): ds_prop = SimulatedFarFieldDataset( dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv ) - ds_loader = torch.utils.data.DataLoader( - dataset=ds_prop, batch_size=batch_size, shuffle=True, pin_memory=(psf.device != "cpu") - ) - return ds_loader - - -def create_process_network(network, depth, device="cpu"): - if network == "DruNet": - from lensless.recon.utils import load_drunet - - process = load_drunet( - os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True - ).to(device) - process_name = "DruNet" - elif network == "UnetRes": - from lensless.recon.drunet.network_unet import UNetRes - - n_channels = 3 - process = UNetRes( - in_nc=n_channels + 1, - out_nc=n_channels, - nc=[64, 128, 256, 512], - nb=depth, - act_mode="R", - downsample_mode="strideconv", - upsample_mode="convtranspose", - ).to(device) - process_name = "UnetRes_d" + str(depth) - else: - process = None - process_name = None - - return (process, process_name) - - -def measure_gradient(model): - # return the L2 norm of the gradient - total_norm = 0.0 - for p in model.parameters(): - param_norm = p.grad.detach().data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm**0.5 - return total_norm + return ds_prop @hydra.main(version_base=None, config_path="../../configs", config_name="unrolled_recon") @@ -189,7 +138,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_fista.n_iter elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( psf, @@ -201,7 +149,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_admm.n_iter else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") @@ -238,175 +185,28 @@ def train_unrolled( transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config.training.batch_size, - shuffle=True, - pin_memory=(device != "cpu"), - ) else: # Use a simulated dataset - data_loader = simulate_dataset(config, psf) + dataset = simulate_dataset(config, psf) print(f"Setup time : {time.time() - start_time} s") - start_time = time.time() - - # loss - if config.loss == "l2": - Loss = torch.nn.MSELoss() - elif config.loss == "l1": - Loss = torch.nn.L1Loss() - else: - raise ValueError(f"Unsuported loss : {config.loss}") - - # Lpips loss - if config.lpips: - try: - import lpips - - loss_lpips = lpips.LPIPS(net="vgg").to(device) - except ImportError: - return ImportError( - "lpips package is need for LPIPS loss. Install using : pip install lpips" - ) - - # optimizer - if config.optimizer.type == "Adam": - # the parameters of the base model and non torch.Module process must be added separatly - parameters = [{"params": recon.parameters()}] - optimizer = torch.optim.Adam(parameters, lr=config.optimizer.lr) - else: - raise ValueError(f"Unsuported optimizer : {config.optimizer.type}") - # Scheduler - if config.training.slow_start: - - def learning_rate_function(epoch): - if epoch == 0: - return config.training.slow_start - elif epoch == 1: - return math.sqrt(config.training.slow_start) - else: - return 1 - - else: - - def learning_rate_function(epoch): - return 1 - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=learning_rate_function) - - metrics = { - "LOSS": [], - "MSE": [], - "MAE": [], - "LPIPS_Vgg": [], - "LPIPS_Alex": [], - "PSNR": [], - "SSIM": [], - "ReconstructionError": [], - "n_iter": n_iter, - "algorithm": algorithm_name, - } - - # Backward hook that detect NAN in the gradient and print the layer weights - if not config.training.skip_NAN: - - def detect_nan(grad): - if torch.isnan(grad).any(): - print(grad, flush=True) - for name, param in recon.named_parameters(): - if param.requires_grad: - print(name, param) - raise ValueError("Gradient is NaN") - return grad - - for param in recon.parameters(): - if param.requires_grad: - param.register_hook(detect_nan) - if param.requires_grad: - param.register_hook(detect_nan) - - # Training loop - for epoch in range(config.training.epoch): - print(f"Epoch {epoch} with learning rate {scheduler.get_last_lr()}") - mean_loss = 0.0 - i = 1.0 - pbar = tqdm(data_loader) - for X, y in pbar: - # send to device - X = X.to(device) - y = y.to(device) - if X.shape[3] == 3: - X = X - y = y - - y_pred = recon.batch_call(X.to(device)) - # normalizing each output - eps = 1e-12 - y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps - y_pred = y_pred / y_pred_max - - # normalizing y - y = y.to(device) - y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps - y = y / y_max - - if i % disp == 1 and config.display.plot: - img_pred = y_pred[0, 0].cpu().detach().numpy() - img_truth = y[0, 0].cpu().detach().numpy() - - plt.imshow(img_pred) - plt.savefig(f"y_pred_{i-1}.png") - plt.imshow(img_truth) - plt.savefig(f"y_{i-1}.png") - - optimizer.zero_grad(set_to_none=True) - # convert to CHW for loss and remove depth - y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) - y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) - - loss_v = Loss(y_pred, y) - if config.lpips: - # value for LPIPS needs to be in range [-1, 1] - loss_v = loss_v + config.lpips * torch.mean(loss_lpips(2 * y_pred - 1, 2 * y - 1)) - loss_v.backward() - torch.nn.utils.clip_grad_norm_(recon.parameters(), 1.0) - - # if any gradient is NaN, skip training step - is_NAN = False - for param in recon.parameters(): - if torch.isnan(param.grad).any(): - is_NAN = True - break - if is_NAN: - print("NAN detected in gradiant, skipping training step") - i += 1 - continue - optimizer.step() - - mean_loss += (loss_v.item() - mean_loss) * (1 / i) - pbar.set_description(f"loss : {mean_loss}") - i += 1 - - # benchmarking - current_metrics = benchmark(recon, benchmark_dataset, batchsize=10) - # update metrics with current metrics - metrics["LOSS"].append(mean_loss) - for key in current_metrics: - metrics[key].append(current_metrics[key]) - - # Update learning rate - scheduler.step() - - print(f"Train time : {time.time() - start_time} s") - - # save dictionary metrics to file with json - with open(os.path.join(save, "metrics.json"), "w") as f: - json.dump(metrics, f) + trainer = Trainer( + recon, + dataset, + benchmark_dataset, + batch_size=config.training.batch_size, + loss=config.loss, + lpips=config.lpips, + optimizer=config.optimizer.type, + optimizer_lr=config.optimizer.lr, + slow_start=config.training.slow_start, + skip_NAN=config.training.skip_NAN, + algorithm_name=algorithm_name, + ) - # save pytorch model recon - torch.save(recon.state_dict(), "recon.pt") + trainer.train(n_epoch=config.training.epoch, save_pt=save) + trainer.save(path=os.path.join(save, "recon.pt")) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 20c07f7a..79468810 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "matplotlib>=3.4.2", "rawpy>=0.16.0", "paramiko>=3.2.0", + "hydra-core", ], extra_requires={"dev": ["pudb", "black"]}, ) From ff86fb27fb86078ee0d279d12e0e95f355e90ec8 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 30 Aug 2023 10:31:37 -0700 Subject: [PATCH 06/12] Torch support for Coded Aperture reconstruction (#79) * added pytorch version of tikhonov reconstruction * replaced type(...) == ... by isintance (..., ...) * Fix torch support for tikhonov. * Change docstring. * Change docstring. * Update changelog. * Added "try" before SSIM computation * removed 'try' for SSIM --------- Co-authored-by: Aaron Fargeon --- CHANGELOG.rst | 2 +- configs/mask_sim_single.yaml | 1 + lensless/hardware/mask.py | 20 ++++- lensless/recon/tikhonov.py | 132 +++++++++++++++++++++++--------- scripts/sim/mask_single_file.py | 31 +++++++- 5 files changed, 145 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6ca03bd3..b1657928 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,7 +26,7 @@ Added - New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedFarFieldDataset``. - New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. -- Tikhonov reconstruction for coded aperture measurements (MLS / MURA). +- Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. diff --git a/configs/mask_sim_single.yaml b/configs/mask_sim_single.yaml index f793d302..0d20efa5 100644 --- a/configs/mask_sim_single.yaml +++ b/configs/mask_sim_single.yaml @@ -8,6 +8,7 @@ files: #original: data/original/mnist_3.png save: True +use_torch: False simulation: object_height: 0.3 diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index 126d21f1..f9597bf5 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -33,6 +33,13 @@ from lensless.hardware.sensor import VirtualSensor from lensless.utils.image import resize +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + class Mask(abc.ABC): """ @@ -295,12 +302,23 @@ def simulate(self, obj, snr_db=20): # Convolve image n_channels = obj.shape[-1] - meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) + + if torch_available and isinstance(obj, torch.Tensor): + P = torch.from_numpy(P).float() + Q = torch.from_numpy(Q).float() + meas = torch.dstack( + [torch.linalg.multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)] + ).float() + else: + meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) # Add noise if snr_db is not None: meas = add_shot_noise(meas, snr_db=snr_db) + if torch_available and isinstance(obj, torch.Tensor): + meas = meas.to(obj) + return meas diff --git a/lensless/recon/tikhonov.py b/lensless/recon/tikhonov.py index 84a88011..fb9a182d 100644 --- a/lensless/recon/tikhonov.py +++ b/lensless/recon/tikhonov.py @@ -2,8 +2,8 @@ # tikhonov.py # ================= # Authors : -# Aaron FARGEON [aa.fargeon@gmail.com] # Eric BEZZAM [ebezzam@gmail.com] +# Aaron FARGEON [aa.fargeon@gmail.com] # ############################################################################# """ @@ -20,6 +20,13 @@ import numpy as np from numpy.linalg import multi_dot +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + class CodedApertureReconstruction: """ @@ -32,7 +39,7 @@ def __init__(self, mask, image_shape, P=None, Q=None, lmbd=3e-4): """ Parameters ---------- - mask : py:class:`~lensless.hardware.mask.CodedAperture` + mask : py:class:`lensless.hardware.mask.CodedAperture` Coded aperture mask object. image_shape : (`array-like` or `tuple`) The shape of the image to reconstruct. @@ -67,46 +74,97 @@ def apply(self, img): Parameters ---------- - img : :py:class:`~numpy.ndarray` + img : :py:class:`~numpy.ndarray` or :py:class:`torch.Tensor` Lensless capture measurement. Must be 3D even if grayscale. Returns ------- - :py:class:`~numpy.ndarray` + :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Reconstructed image, in the same format as the measurement. """ - assert len(img.shape) == 3, "Object should be a 3D array (HxWxC) even if grayscale." - - # Empty matrix for reconstruction - n_channels = img.shape[-1] - x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) - - # Applying reconstruction for each channel - for c in range(n_channels): - - # SVD of left matrix - UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) - VL = VLh.T - DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) - singLsq = np.square(SL) - - # SVD of right matrix - UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) - VR = VRh.T - DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) - singRsq = np.square(SR) - - # Applying analytical reconstruction - Yc = img[:, :, c] - inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( - np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) - ) - x_est[:, :, c] = multi_dot([VL, inner, VR.T]) - - # Non-negativity constraint: setting all negative values to 0 - x_est = x_est.clip(min=0) - - # Normalizing the image - x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + assert ( + len(img.shape) == 3 + ), "Object should be a 3D array or tensor (HxWxC) even if grayscale." + + if torch_available and isinstance(img, torch.Tensor): + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = torch.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + self.P = torch.from_numpy(self.P).float() + self.Q = torch.from_numpy(self.Q).float() + + # Applying reconstruction for each channel + for c in range(n_channels): + Yc = img[:, :, c] + + # SVD of left matrix + UL, SL, VLh = torch.linalg.svd(self.P) + VL = VLh.T + DL = torch.cat( + ( + torch.diag(SL), + torch.zeros([self.P.shape[0] - SL.size(0), SL.size(0)], device=SL.device), + ) + ) + singLsq = SL**2 + + # SVD of right matrix + UR, SR, VRh = torch.linalg.svd(self.Q) + VR = VRh.T + DR = torch.cat( + ( + torch.diag(SR), + torch.zeros([self.Q.shape[0] - SR.size(0), SR.size(0)], device=SR.device), + ) + ) + singRsq = SR**2 + + # Applying analytical reconstruction + inner = torch.linalg.multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + torch.outer(singLsq, singRsq) + torch.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = torch.linalg.multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = torch.clamp(x_est, min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + + else: + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + # Applying reconstruction for each channel + for c in range(n_channels): + + # SVD of left matrix + UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) + VL = VLh.T + DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) + singLsq = np.square(SL) + + # SVD of right matrix + UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) + VR = VRh.T + DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) + singRsq = np.square(SR) + + # Applying analytical reconstruction + Yc = img[:, :, c] + inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = x_est.clip(min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) return x_est diff --git a/scripts/sim/mask_single_file.py b/scripts/sim/mask_single_file.py index e8a741b5..8513e75c 100644 --- a/scripts/sim/mask_single_file.py +++ b/scripts/sim/mask_single_file.py @@ -19,6 +19,11 @@ python scripts/sim/mask_single_file.py mask.type=MURA mask.n_bits=99 simulation.flatcam=True recon.algo=tikhonov ``` +Using Torch +``` +python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov use_torch=True +``` + Simulate FlatCam with PSF simulation and Tikhonov reconstuction: ``` python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=False recon.algo=tikhonov @@ -56,6 +61,7 @@ import os from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture from lensless.recon.tikhonov import CodedApertureReconstruction +import torch @hydra.main(version_base=None, config_path="../../configs", config_name="mask_sim_single") @@ -107,6 +113,9 @@ def simulate(config): # 2) simulate measurement image = load_image(fp, verbose=True) / 255 + if config.use_torch: + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).float() flatcam_sim = config.simulation.flatcam if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]: @@ -116,17 +125,29 @@ def simulate(config): flatcam_sim = False # use far field simulator to get correct object plane sizing + psf = mask.psf + if config.use_torch: + psf = psf.transpose(2, 0, 1) + psf = torch.from_numpy(psf).float() + simulator = FarFieldSimulator( - psf=mask.psf, + psf=psf, object_height=object_height, scene2mask=scene2mask, mask2sensor=mask2sensor, sensor=sensor, snr_db=snr_db, max_val=max_val, + is_torch=config.use_torch, ) image_plane, object_plane = simulator.propagate(image, return_object_plane=True) + # channels as last dimension + if config.use_torch: + image_plane = image_plane.permute(1, 2, 0) + object_plane = object_plane.permute(1, 2, 0) + image = image.permute(1, 2, 0) + if image_format == "grayscale": image_plane = rgb2gray(image_plane) object_plane = rgb2gray(object_plane) @@ -178,6 +199,12 @@ def simulate(config): else: raise ValueError(f"Reconstruction algorithm {config.recon.algo} not recognized.") + # back to numpy for evaluation and plotting + if config.use_torch: + recovered = recovered.numpy() + object_plane = object_plane.numpy() + image_plane = image_plane.numpy() + # 4) evaluate if image_format == "grayscale": object_plane = object_plane[:, :, 0] @@ -218,7 +245,7 @@ def simulate(config): ax[4].set_title("Reconstruction") for a in ax: - a.set_xticks([]), a.set_yticks([]) + a.set_axis_off() plt.tight_layout() plt.savefig("result.png") From f67985ef02b8a9cc6fc061f41dd2f428b6a49822 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 31 Aug 2023 19:15:55 -0700 Subject: [PATCH 07/12] Add torch support for rgb2gray. (#85) * Add torch support for rgb2gray. * Fix pycsou install. * Update CHANGELOG. --- .github/workflows/python_pycsou.yml | 2 +- CHANGELOG.rst | 1 + README.rst | 8 ++-- lensless/utils/image.py | 62 ++++++++++++++++++++++------- test/test_io.py | 32 +++++++++++++-- 5 files changed, 82 insertions(+), 23 deletions(-) diff --git a/.github/workflows/python_pycsou.yml b/.github/workflows/python_pycsou.yml index d5cf1e91..61f89fa5 100644 --- a/.github/workflows/python_pycsou.yml +++ b/.github/workflows/python_pycsou.yml @@ -59,5 +59,5 @@ jobs: pip install -U pytest pip install -r recon_requirements.txt pip install -r mask_requirements.txt - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e pytest \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b1657928..90db0c99 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,7 @@ Added - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. +- PyTorch support for ``lensless.utils.io.rgb2gray``. Changed diff --git a/README.rst b/README.rst index 23066e12..eb5e7e72 100644 --- a/README.rst +++ b/README.rst @@ -84,15 +84,15 @@ install the library locally. python scripts/recon/admm.py -Note (25-04-2023): for using reconstruction method based on Pycsou ``lensless.apgd.APGD``, -V2 has to be installed: +Note (25-04-2023): for using the reconstruction method based on Pycsou (now [Pyxu](https://github.com/matthieumeo/pyxu)) +``lensless.apgd.APGD``, a specific commit has to be installed (as there was no release at the time of implementation): .. code:: bash - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e If PyTorch is installed, you will need to be sure to have PyTorch 2.0 or higher, -as Pycsou V2 is not compatible with earlier versions of PyTorch. Moreover, +as Pycsou is not compatible with earlier versions of PyTorch. Moreover, Pycsou requires Python within `[3.9, 3.11) `__. diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 19c977e2..f3bbe28f 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -14,6 +14,7 @@ try: import torch import torchvision.transforms as tf + from torchvision.transforms.functional import rgb_to_grayscale torch_available = True except ImportError: @@ -82,10 +83,10 @@ def rgb2gray(rgb, weights=None, keepchanneldim=True): Parameters ---------- - rgb : :py:class:`~numpy.ndarray` + rgb : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` ([Depth,] Height, Width, Channel) image. weights : :py:class:`~numpy.ndarray` - [Optional] (3,) weights to convert from RGB to grayscale. + [Optional] (3,) weights to convert from RGB to grayscale. Only used for NumPy arrays. keepchanneldim : bool Whether to keep the channel dimension. Default is True. @@ -95,22 +96,53 @@ def rgb2gray(rgb, weights=None, keepchanneldim=True): Grayscale image of dimension ([depth,] height, width [, 1]). """ - if weights is None: - weights = np.array([0.299, 0.587, 0.114]) - assert len(weights) == 3 - - if len(rgb.shape) == 4: - image = np.tensordot(rgb, weights, axes=((3,), 0)) - elif len(rgb.shape) == 3: - image = np.tensordot(rgb, weights, axes=((2,), 0)) - else: - raise ValueError("Input must be at least 3D.") - if keepchanneldim: - return image[..., np.newaxis] - else: + use_torch = False + if torch_available: + if torch.is_tensor(rgb): + use_torch = True + + if use_torch: + + # move channel dimension to third to last + if len(rgb.shape) == 4: + rgb = rgb.permute(0, 3, 1, 2) + elif len(rgb.shape) == 3: + rgb = rgb.permute(2, 0, 1) + else: + raise ValueError("Input must be at least 3D.") + + image = rgb_to_grayscale(rgb) + + # move channel dimension to last + if len(rgb.shape) == 4: + image = image.permute(0, 2, 3, 1) + elif len(rgb.shape) == 3: + image = image.permute(1, 2, 0) + + if not keepchanneldim: + image = image.squeeze(-1) + return image + else: + + if weights is None: + weights = np.array([0.299, 0.587, 0.114]) + assert len(weights) == 3 + + if len(rgb.shape) == 4: + image = np.tensordot(rgb, weights, axes=((3,), 0)) + elif len(rgb.shape) == 3: + image = np.tensordot(rgb, weights, axes=((2,), 0)) + else: + raise ValueError("Input must be at least 3D.") + + if keepchanneldim: + return image[..., np.newaxis] + else: + return image + def gamma_correction(vals, gamma=2.2): """ diff --git a/test/test_io.py b/test/test_io.py index 16823e1f..5c2f8884 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,5 +1,4 @@ -from lensless.utils.io import load_data -import numpy as np +from lensless.utils.io import load_data, rgb2gray psf_fp = "data/psf/tape_rgb.png" data_fp = "data/raw_data/thumbs_up_rgb.png" @@ -26,4 +25,31 @@ def test_load_data(): assert data.dtype == dtype, dtype -test_load_data() +def test_rgb2gray(): + for is_torch in [True, False]: + psf, data = load_data( + psf_fp=psf_fp, + data_fp=data_fp, + downsample=downsample, + plot=False, + dtype="float32", + torch=is_torch, + ) + data = data[0] # drop first depth dimension + + # try with 4D + psf_gray = rgb2gray(psf, keepchanneldim=False) + assert len(psf_gray.shape) == 3 + psf_gray = rgb2gray(psf, keepchanneldim=True) + assert len(psf_gray.shape) == 4 + + # try with 3D + data_gray = rgb2gray(data, keepchanneldim=False) + assert len(data_gray.shape) == 2 + data_gray = rgb2gray(data, keepchanneldim=True) + assert len(data_gray.shape) == 3 + + +if __name__ == "__main__": + test_load_data() + test_rgb2gray() From 8dfdc554bdfbe00514095099204a9c65e8d9c25d Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Tue, 5 Sep 2023 23:01:51 +0200 Subject: [PATCH 08/12] Trainable mask (#81) * Add support for changing the psf * First implementation of trainable mask * Fix to projection * add support for trainable mask * new datased with trainable mask * Fix comment and dataset name * Fix for SimulatedDatasetTrainableMask * Update to trainer save * If no test dataset, sample from test * Add support for l1 regularisation on mask * Support for gray mask to rgb psf * remove update frequency param * add auto gray to rgb conversion * fix update bug * Update simulation for TrainableMask * Fix SimulatedDatasetTrainableMask * Clean and changelog * Fix simulation flip * Fix not using a mask * Default config doesn't use TrainableMask * Fix PR comment * Add config for PSF fine-tuning * Added to doc * Fix / update dataset docs. * Add method to set PSF of simulator. * Add check for dataset. * Move trainable mask. * Fix trainable mask documentation. * Fix docs. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 2 + configs/fine-tune_PSF.yaml | 119 ++++++++++++++++++++++++++++ configs/unrolled_recon.yaml | 34 +++++--- docs/requirements.txt | 2 +- docs/source/dataset.rst | 34 ++++++++ docs/source/mask.rst | 10 +++ lensless/hardware/trainable_mask.py | 86 ++++++++++++++++++++ lensless/recon/recon.py | 22 +++++ lensless/recon/rfft_convolve.py | 3 + lensless/recon/utils.py | 74 +++++++++++------ lensless/utils/dataset.py | 54 ++++++++++++- lensless/utils/simulation.py | 33 ++++++++ mask_requirements.txt | 2 +- recon_requirements.txt | 2 +- scripts/recon/train_unrolled.py | 91 +++++++++++++++++---- 15 files changed, 510 insertions(+), 58 deletions(-) create mode 100644 configs/fine-tune_PSF.yaml create mode 100644 lensless/hardware/trainable_mask.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 90db0c99..ddf3c78a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,8 @@ Added - Support for unrolled loading and inference in the script ``admm.py``. - Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. - New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. +- New ``TrainableMask`` and ``TrainablePSF`` class to train/fine-tune a mask from a dataset. +- New ``SimulatedDatasetTrainableMask`` class to train/fine-tune a mask for measurement. - PyTorch support for ``lensless.utils.io.rgb2gray``. diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml new file mode 100644 index 00000000..040dc81b --- /dev/null +++ b/configs/fine-tune_PSF.yaml @@ -0,0 +1,119 @@ +hydra: + job: + chdir: True # change to output folder + +#Reconstruction algorithm +input: + # File path for recorded PSF + psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff + dtype: float32 + +torch: True +torch_device: 'cuda' + +preprocess: + # Image shape (height, width) for reconstruction. + shape: null + # Whether image is raw bayer data. + bayer: False + blue_gain: null + red_gain: null + # Same PSF for all channels (sum) or unique PSF for RGB. + single_psf: False + # Whether to perform construction in grayscale. + gray: False + + +display: + # How many iterations to wait for intermediate plot. + # Set to negative value for no intermediate plots. + disp: 500 + # Whether to plot results. + plot: True + # Gamma factor for plotting. + gamma: null + +# Whether to save intermediate and final reconstructions. +save: True + +reconstruction: + # Method: unrolled_admm, unrolled_fista + method: unrolled_admm + + # Hyperparameters for each method + unrolled_fista: # for unrolled_fista + # Number of iterations + n_iter: 20 + tk: 1 + learn_tk: True + unrolled_admm: + # Number of iterations + n_iter: 20 + # Hyperparameters + mu1: 1e-4 + mu2: 1e-4 + mu3: 1e-4 + tau: 2e-4 + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + +#Trainable Mask +trainable_mask: + mask_type: TrainablePSF #Null or "TrainablePSF" + initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray" + mask_lr: 1e-3 + L1_strength: 1.0 #False or float + use_mask_in_dataset : False # Work only with simulated dataset + +# Train Dataset +files: + dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + n_files: null # null to use all + +target: "object_plane" # "original" or "object_plane" or "label" + +#for simulated dataset +simulation: + grayscale: False + # random variations + object_height: 0.04 # range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) + random_shift: False + random_vflip: 0.5 + random_hflip: 0.5 + random_rotate: False + # these distance parameters are typically fixed for a given PSF + # for DiffuserCam psf # for tape_rgb psf + scene2mask: 10e-2 # scene2mask: 40e-2 + mask2sensor: 9e-3 # mask2sensor: 4e-3 + # see waveprop.devices + sensor: "rpi_hq" + snr_db: 10 + # simulate different sensor resolution + # output_dim: [24, 32] # [H, W] or null + # Downsampling for PSF + downsample: 8 + # max val in simulated measured (quantized 8 bits) + max_val: 255 + +#Training + +training: + batch_size: 8 + epoch: 10 + #In case of instable training + skip_NAN: True + slow_start: False #float how much to reduce lr for first epoch + + +optimizer: + type: Adam + lr: 1e-4 + +loss: 'l2' +# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) +lpips: 1.0 \ No newline at end of file diff --git a/configs/unrolled_recon.yaml b/configs/unrolled_recon.yaml index 621e3cfa..2673f20c 100644 --- a/configs/unrolled_recon.yaml +++ b/configs/unrolled_recon.yaml @@ -27,7 +27,7 @@ preprocess: display: # How many iterations to wait for intermediate plot. # Set to negative value for no intermediate plots. - disp: 400 + disp: 500 # Whether to plot results. plot: True # Gamma factor for plotting. @@ -48,23 +48,30 @@ reconstruction: learn_tk: True unrolled_admm: # Number of iterations - n_iter: 5 + n_iter: 20 # Hyperparameters mu1: 1e-4 mu2: 1e-4 mu3: 1e-4 tau: 2e-4 pre_process: - network : UnetRes # UnetRes or DruNet or null + network : null # UnetRes or DruNet or null depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet post_process: - network : UnetRes # UnetRes or DruNet or null + network : null # UnetRes or DruNet or null depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet -# Train Dataset +#Trainable Mask +trainable_mask: + mask_type: Null #Null or "TrainablePSF" + initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray" + mask_lr: 1e-3 + L1_strength: 1.0 #False or float + use_mask_in_dataset : True # Work only with simulated dataset +# Train Dataset files: - dataset: "DiffuserCam" # "mnist", "fashion_mnist", "cifar10", "CelebA", "DiffuserCam" + dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" n_files: null # null to use all target: "object_plane" # "original" or "object_plane" or "label" @@ -73,18 +80,19 @@ target: "object_plane" # "original" or "object_plane" or "label" simulation: grayscale: False # random variations - object_height: 0.6 # range for random height or scalar + object_height: 0.04 # range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) random_shift: False random_vflip: 0.5 random_hflip: 0.5 random_rotate: False # these distance parameters are typically fixed for a given PSF - # for tape_rgb psf # for DiffuserCam psf - scene2mask: 40e-2 # scene2mask: 10e-2 - mask2sensor: 4e-3 # mask2sensor: 9e-3 + # for DiffuserCam psf # for tape_rgb psf + scene2mask: 10e-2 # scene2mask: 40e-2 + mask2sensor: 9e-3 # mask2sensor: 4e-3 # see waveprop.devices sensor: "rpi_hq" - snr_db: 40 + snr_db: 10 # simulate different sensor resolution # output_dim: [24, 32] # [H, W] or null # Downsampling for PSF @@ -96,7 +104,7 @@ simulation: training: batch_size: 8 - epoch: 50 + epoch: 10 #In case of instable training skip_NAN: True slow_start: False #float how much to reduce lr for first epoch @@ -104,7 +112,7 @@ training: optimizer: type: Adam - lr: 1e-6 + lr: 1e-4 loss: 'l2' # set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) diff --git a/docs/requirements.txt b/docs/requirements.txt index 484c5d20..3eb1e15f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,4 +6,4 @@ torch>=1.10 torchvision>=0.15.2 torchmetrics>=0.11.4 pyFFS>=2.2.3 # for waveprop -waveprop>=0.0.5 \ No newline at end of file +waveprop>=0.0.7 \ No newline at end of file diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index 1312e1cc..ad21defb 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -6,14 +6,48 @@ datasets for training and testing. .. automodule:: lensless.utils.dataset +Abstract base class +------------------- + +All dataset objects derive from this abstract base class, which +lays out the notion of a dataset with pairs of images: one image +is lensed (simulated or measured), and the other is lensless (simulated +or measured). + .. autoclass:: lensless.utils.dataset.DualDataset :members: _get_images_pair :special-members: __init__, __len__ + +Simulated dataset objects +------------------------- + +These dataset objects can be used for training and testing with +simulated data. The main assumption is that the imaging system +is linear shift-invariant (LSI), and that the lensless image is +the result of a convolution of the lensed image with a point-spread +function (PSF). Check out `this Medium post `__ +for more details on the simulation procedure. + +With simulated data, we can avoid the hassle of collecting a large +amount of data. However, it's important to note that the LSI assumption +can sometimes be too idealistic, in particular for large angles. + +Nevertheless, simulating data is the only option of learning the +mask / PSF. + .. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset :members: :special-members: __init__ +.. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask + :members: + :special-members: __init__ + + +Measured dataset objects +------------------------ + .. autoclass:: lensless.utils.dataset.MeasuredDataset :members: :special-members: __init__ diff --git a/docs/source/mask.rst b/docs/source/mask.rst index 0ad8327e..036d0f12 100644 --- a/docs/source/mask.rst +++ b/docs/source/mask.rst @@ -29,5 +29,15 @@ ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: lensless.hardware.mask.FresnelZoneAperture + :members: + :special-members: __init__ + + Trainable Mask + ~~~~~~~~~~~~~~~~~~~~~ + .. autoclass:: lensless.hardware.trainable_mask.TrainableMask + :members: + :special-members: __init__ + + .. autoclass:: lensless.hardware.trainable_mask.TrainablePSF :members: :special-members: __init__ \ No newline at end of file diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py new file mode 100644 index 00000000..593c2360 --- /dev/null +++ b/lensless/hardware/trainable_mask.py @@ -0,0 +1,86 @@ +# ############################################################################# +# trainable_mask.py +# ================== +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# ############################################################################# + +import abc +import torch + + +class TrainableMask(metaclass=abc.ABCMeta): + """ + Abstract class for defining trainable masks. + + The following abstract methods need to be defined: + + - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`: returning the PSF of the mask. + - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.project`: projecting the mask parameters to a valid space (should be a subspace of [0,1]). + + """ + + def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): + """ + Base constructor. Derived constructor may define new state variables + + Parameters + ---------- + initial_mask : :py:class:`~torch.Tensor` + Initial mask parameters. + optimizer : str, optional + Optimizer to use for updating the mask parameters, by default "Adam" + lr : float, optional + Learning rate for the mask parameters, by default 1e-3 + """ + self._mask = torch.nn.Parameter(initial_mask) + self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) + self._counter = 0 + + @abc.abstractmethod + def get_psf(self): + """ + Abstract method for getting the PSF of the mask. Should be fully compatible with pytorch autograd. + + Returns + ------- + :py:class:`~torch.Tensor` + The PSF of the mask. + """ + raise NotImplementedError + + def update_mask(self): + """Update the mask parameters. Acoording to externaly updated gradiants.""" + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self.project() + self._counter += 1 + + @abc.abstractmethod + def project(self): + """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1]).""" + raise NotImplementedError + + +class TrainablePSF(TrainableMask): + """ + Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically. + + Parameters + ---------- + is_rgb : bool, optional + Whether the mask is RGB or not, by default True. + """ + + def __init__(self, is_rgb=True, **kwargs): + super().__init__(**kwargs) + self._is_rgb = is_rgb + + def get_psf(self): + if self._is_rgb: + return self._mask.expand(-1, -1, -1, 3) + else: + return self._mask + + def project(self): + self._mask.data = torch.clamp(self._mask, 0, 1) diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 1124c289..444e3b0a 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -404,6 +404,28 @@ def get_image_estimate(self): """Get current image estimate as [Batch, Depth, Height, Width, Channels].""" return self._form_image() + def _set_psf(self, psf): + """ + Set PSF. + + Parameters + ---------- + psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + PSF to set. + """ + assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)." + assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be rgb (3) or grayscale (1)" + assert self._psf.shape == psf.shape, "new PSF must have same shape as old PSF" + assert isinstance(psf, type(self._psf)), "new PSF must have same type as old PSF" + + self._psf = psf + self._convolver = RealFFTConvolve2D( + psf, + dtype=self._convolver._psf.dtype, + pad=self._convolver.pad, + norm=self._convolver.norm, + ) + def _progress(self): """ Optional method for printing progress update, e.g. relative improvement diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index 5c867cd3..34cca96a 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -57,6 +57,9 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs): self._is_rgb = psf.shape[3] == 3 assert self._is_rgb or psf.shape[3] == 1 + # save normalization + self.norm = norm + # set dtype if dtype is None: if self.is_torch: diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 54d23a1d..5f091c53 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -1,5 +1,5 @@ # ############################################################################# -# dataset.py +# utils.py # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] @@ -15,6 +15,7 @@ import matplotlib.pyplot as plt import torch from lensless.eval.benchmark import benchmark +from lensless.hardware.trainable_mask import TrainableMask from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes @@ -222,9 +223,11 @@ def __init__( recon, train_dataset, test_dataset, + mask=None, batch_size=4, loss="l2", lpips=None, + l1_mask=None, optimizer="Adam", optimizer_lr=1e-6, slow_start=None, @@ -242,12 +245,16 @@ def __init__( Dataset to use for training. test_dataset : :py:class:`torch.utils.data.Dataset` Dataset to use for testing. + mask : TrainableMask, optional + Trainable mask to use for training. If none, training with fix psf, by default None. batch_size : int, optional Batch size to use for training, by default 4 loss : str, optional Loss function to use for training "l1" or "l2", by default "l2" lpips : float, optional the weight of the lpips(VGG) in the total loss. If None ignore. By default None + l1_mask : float, optional + the weight of the l1 norm of the mask in the total loss. If None ignore. By default None optimizer : str, optional Optimizer to use durring training. Available : "Adam". By default "Adam" optimizer_lr : float, optional @@ -263,6 +270,15 @@ def __init__( self.device = recon._psf.device self.recon = recon + + if test_dataset is None: + # split train dataset + train_size = int(0.9 * len(train_dataset)) + test_size = len(train_dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split( + train_dataset, [train_size, test_size] + ) + self.train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, @@ -273,6 +289,15 @@ def __init__( self.lpips = lpips self.skip_NAN = skip_NAN + if mask is not None: + assert isinstance(mask, TrainableMask) + self.mask = mask + self.use_mask = True + else: + self.use_mask = False + + self.l1_mask = l1_mask + # loss if loss == "l2": self.Loss = torch.nn.MSELoss() @@ -358,8 +383,8 @@ def train_epoch(self, data_loader, disp=-1): ---------- data_loader : :py:class:`torch.utils.data.DataLoader` Data loader to use for training. - disp : int, optional - Display interval, if -1, no display, by default -1 + disp : int + Display interval, if -1, no display Returns ------- @@ -374,6 +399,11 @@ def train_epoch(self, data_loader, disp=-1): X = X.to(self.device) y = y.to(self.device) + # update psf according to mask + if self.use_mask: + self.recon._set_psf(self.mask.get_psf()) + + # forward pass y_pred = self.recon.batch_call(X.to(self.device)) # normalizing each output eps = 1e-12 @@ -404,6 +434,8 @@ def train_epoch(self, data_loader, disp=-1): loss_v = loss_v + self.lpips * torch.mean( self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) ) + if self.use_mask and self.l1_mask: + loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask)) loss_v.backward() torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0) @@ -421,6 +453,10 @@ def train_epoch(self, data_loader, disp=-1): continue self.optimizer.step() + # update mask + if self.use_mask: + self.mask.update_mask() + mean_loss += (loss_v.item() - mean_loss) * (1 / i) pbar.set_description(f"loss : {mean_loss}") i += 1 @@ -488,6 +524,7 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): start_time = time.time() + self.evaluate(-1, save_pt) for epoch in range(n_epoch): print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) @@ -497,31 +534,18 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): print(f"Train time : {time.time() - start_time} s") def save(self, path="recon", include_optimizer=False): - """ - Save state of reconstruction algorithm. - - Parameters - ---------- - path : str, optional - Path to save model to, by default "recon" - include_optimizer : bool, optional - Whether to include optimizer state, by default False - - """ # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) - - # TODO : ADD mask support - # # save mask - # if self.use_mask: - # torch.save(self.mask._mask, os.path.join(path, "mask.pt")) - # torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) - # import matplotlib.pyplot as plt - - # plt.imsave( - # os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] - # ) + # save mask + if self.use_mask: + torch.save(self.mask._mask, os.path.join(path, "mask.pt")) + torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) + import matplotlib.pyplot as plt + + plt.imsave( + os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] + ) # save optimizer if include_optimizer: torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 2634cb7c..67aa7a8d 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import numpy as np @@ -144,7 +145,7 @@ def __init__( dataset : :py:class:`torch.utils.data.Dataset` Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` - Simulator object used on images from ``dataset``.Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. pre_transform : PyTorch Transform or None, optional Transform to apply to the images before simulation, by default ``None``. Note that this transform is applied on HCW images (different from torchvision). dataset_is_CHW : bool, optional @@ -176,7 +177,7 @@ def get_image(self, index): def _get_images_pair(self, index): # load image img, _ = self.get_image(index) - # convert to CHW for simulator and transform + # convert to HWC for simulator and transform if self.dataset_is_CHW: img = img.moveaxis(-3, -1) if self.flip_pre_sim: @@ -446,3 +447,52 @@ def __init__( lensed_fn="lensed", image_ext="npy", ) + + +class SimulatedDatasetTrainableMask(SimulatedFarFieldDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset with learnable mask. + The `waveprop `_ package is used for the simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + To ensure autograd compatibility, the dataloader should have ``num_workers=0``. + """ + + def __init__( + self, + mask, + dataset, + simulator, + **kwargs, + ): + """ + Parameters + ---------- + + mask : :py:class:`lensless.hardware.trainable_mask.TrainableMask` + Mask to use for simulation. Should be a 4D tensor with shape [1, H, W, C]. Simulation of multi-depth data is not supported yet. + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + """ + + self._mask = mask + + temp_psf = self._mask.get_psf() + test_sim = FarFieldSimulator(psf=temp_psf, **simulator.params) + assert ( + test_sim.conv_dim == simulator.conv_dim + ).all(), "PSF shape should match simulator shape" + assert ( + not simulator.quantize + ), "Simulator should not perform quantization to maintain differentiability. Please set quantize=False" + + super(SimulatedDatasetTrainableMask, self).__init__(dataset, simulator, **kwargs) + + def _get_images_pair(self, index): + # update psf + psf = self._mask.get_psf() + self.sim.set_psf(psf) + + # return simulated images + return super()._get_images_pair(index) diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py index 36aac243..e7f7af3a 100644 --- a/lensless/utils/simulation.py +++ b/lensless/utils/simulation.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import numpy as np @@ -27,6 +28,7 @@ def __init__( device_conv="cpu", random_shift=False, is_torch=False, + quantize=True, **kwargs ): """ @@ -52,6 +54,8 @@ def __init__( Whether to randomly shift the image, by default False. is_torch : bool, optional Whether to use pytorch, by default False. + quantize : bool, optional + Whether to quantize image, by default True. """ if psf is not None: @@ -70,9 +74,38 @@ def __init__( device_conv, random_shift, is_torch, + quantize, **kwargs ) + # save all the parameters in a dict + self.params = { + "object_height": object_height, + "scene2mask": scene2mask, + "mask2sensor": mask2sensor, + "sensor": sensor, + "output_dim": output_dim, + "snr_db": snr_db, + "max_val": max_val, + "device_conv": device_conv, + "random_shift": random_shift, + "is_torch": is_torch, + "quantize": quantize, + } + self.params.update(kwargs) + + def set_psf(self, psf): + """ + Set point spread function. + + Parameters + ---------- + psf : np.ndarray or torch.Tensor + Point spread function. + """ + psf = psf.squeeze().movedim(-1, 0) + return super().set_psf(psf) + def propagate(self, obj, return_object_plane=False): """ Parameters diff --git a/mask_requirements.txt b/mask_requirements.txt index ee87c51f..699ba552 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,3 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.4 \ No newline at end of file +waveprop>=0.0.7 \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index b9e9f324..33e12092 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -3,7 +3,7 @@ lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 click>=8.0.1 -waveprop>=0.0.3 # for simulation +waveprop>=0.0.7 # for simulation # Library for learning algorithm torch >= 2.0.0 diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 7d0a31e1..c669ea2e 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -12,6 +12,11 @@ python scripts/recon/train_unrolled.py ``` +To fine-tune the DiffuserCam PSF, use the following command: +``` +python scripts/recon/train_unrolled.py -cn fine-tune_PSF +``` + """ import hydra @@ -20,7 +25,12 @@ import numpy as np import time from lensless import UnrolledFISTA, UnrolledADMM -from lensless.utils.dataset import DiffuserCamTestDataset, SimulatedFarFieldDataset +from lensless.utils.dataset import ( + DiffuserCamTestDataset, + SimulatedFarFieldDataset, + SimulatedDatasetTrainableMask, +) +import lensless.hardware.trainable_mask from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray from lensless.utils.simulation import FarFieldSimulator @@ -29,7 +39,7 @@ from torchvision import transforms, datasets -def simulate_dataset(config, psf): +def simulate_dataset(config, psf, mask=None): # load dataset transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") @@ -71,9 +81,23 @@ def simulate_dataset(config, psf): # create Pytorch dataset and dataloader if n_files is not None: ds = torch.utils.data.Subset(ds, np.arange(n_files)) - ds_prop = SimulatedFarFieldDataset( - dataset=ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv - ) + if mask is None: + ds_prop = SimulatedFarFieldDataset( + dataset=ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + else: + ds_prop = SimulatedDatasetTrainableMask( + dataset=ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) return ds_prop @@ -96,12 +120,11 @@ def train_unrolled( data_dir=path, downsample=config.simulation.downsample ) - psf = benchmark_dataset.psf.to(device) + diffusercam_psf = benchmark_dataset.psf.to(device) background = benchmark_dataset.background # convert psf from BGR to RGB - if config.files.dataset in ["DiffuserCam"]: - psf = psf[..., [2, 1, 0]] + diffusercam_psf = diffusercam_psf[..., [2, 1, 0]] # if using a portrait dataset rotate the PSF @@ -130,7 +153,7 @@ def train_unrolled( # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( - psf, + diffusercam_psf, n_iter=config.reconstruction.unrolled_fista.n_iter, tk=config.reconstruction.unrolled_fista.tk, pad=True, @@ -140,7 +163,7 @@ def train_unrolled( ).to(device) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( - psf, + diffusercam_psf, n_iter=config.reconstruction.unrolled_admm.n_iter, mu1=config.reconstruction.unrolled_admm.mu1, mu2=config.reconstruction.unrolled_admm.mu2, @@ -164,6 +187,25 @@ def train_unrolled( # transform from BGR to RGB transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + # create mask + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) + if config.trainable_mask.initial_value == "random": + mask = mask_class( + torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr + ) + elif config.trainable_mask.initial_value == "DiffuserCam": + mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) + elif config.trainable_mask.initial_value == "DiffuserCam_gray": + mask = mask_class( + diffusercam_psf[:, :, :, 0, None], + optimizer="Adam", + lr=config.trainable_mask.mask_lr, + is_rgb=not config.simulation.grayscale, + ) + else: + mask = None + # load dataset and create dataloader if config.files.dataset == "DiffuserCam": # Use a ParallelDataset @@ -174,11 +216,12 @@ def train_unrolled( max_indices = config.files.n_files + 1000 data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") + assert os.path.exists(data_path), "DiffuserCam dataset not found" dataset = MeasuredDataset( root_dir=data_path, indices=range(1000, max_indices), background=background, - psf=psf, + psf=diffusercam_psf, lensless_fn="diffuser_images", lensed_fn="ground_truth_lensed", downsample=config.simulation.downsample / 4, @@ -187,17 +230,25 @@ def train_unrolled( ) else: # Use a simulated dataset - dataset = simulate_dataset(config, psf) + if config.trainable_mask.use_mask_in_dataset: + dataset = simulate_dataset(config, diffusercam_psf, mask=mask) + # the mask use will differ from the one in the benchmark dataset + print("Trainable Mask will be used in the test dataset") + benchmark_dataset = None + else: + dataset = simulate_dataset(config, diffusercam_psf, mask=None) print(f"Setup time : {time.time() - start_time} s") - + print(f"PSF shape : {diffusercam_psf.shape}") trainer = Trainer( recon, dataset, benchmark_dataset, + mask=mask, batch_size=config.training.batch_size, loss=config.loss, lpips=config.lpips, + l1_mask=config.trainable_mask.L1_strength, optimizer=config.optimizer.type, optimizer_lr=config.optimizer.lr, slow_start=config.training.slow_start, @@ -205,8 +256,18 @@ def train_unrolled( algorithm_name=algorithm_name, ) - trainer.train(n_epoch=config.training.epoch, save_pt=save) - trainer.save(path=os.path.join(save, "recon.pt")) + trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp) + + if mask is not None: + print("Saving mask") + print(f"mask shape: {mask._mask.shape}") + torch.save(mask._mask, os.path.join(save, "mask.pt")) + # save as image using plt + import matplotlib.pyplot as plt + + print(f"mask max: {mask._mask.max()}") + print(f"mask min: {mask._mask.min()}") + plt.imsave(os.path.join(save, "mask.png"), mask._mask.detach().cpu().numpy()[0, ...]) if __name__ == "__main__": From 3f78c2421a170057718f01952c1cc53c3e8d42d7 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:23:31 -0700 Subject: [PATCH 09/12] Fix readme rendering. (#88) --- README.rst | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index eb5e7e72..5a88de08 100644 --- a/README.rst +++ b/README.rst @@ -60,7 +60,8 @@ Python 3.9, as some Python library versions may not be available with earlier versions of Python. Moreover, its `end-of-life `__ is Oct 2025. -**Local machine** +*Local machine setup* +===================== Below are commands that worked for our configuration (Ubuntu 21.04), but there are certainly other ways to download a repository and @@ -83,9 +84,13 @@ install the library locally. # (optional) try reconstruction on local machine python scripts/recon/admm.py + # (optional) try reconstruction on local machine with GPU + python scripts/recon/admm.py -cn pytorch -Note (25-04-2023): for using the reconstruction method based on Pycsou (now [Pyxu](https://github.com/matthieumeo/pyxu)) -``lensless.apgd.APGD``, a specific commit has to be installed (as there was no release at the time of implementation): + +Note (25-04-2023): for using the :py:class:`~lensless.recon.apgd.APGD` reconstruction method based on Pycsou +(now `Pyxu `__), a specific commit has +to be installed (as there was no release at the time of implementation): .. code:: bash @@ -102,7 +107,8 @@ Moreover, ``numba`` (requirement for Pycsou V2) may require an older version of pip install numpy==1.23.5 -**Raspberry Pi** +*Raspberry Pi setup* +==================== After `flashing your Raspberry Pi with SSH enabled `__, you need to set it up for `passwordless access `__. From 4ae024747a0dbf0b56728accf5087f59b61bc696 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:28:28 -0700 Subject: [PATCH 10/12] Bump version to v1.0.5. --- CHANGELOG.rst | 18 ++++++++++++++++++ lensless/version.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ddf3c78a..3ae1221a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,24 @@ Unreleased Added ~~~~~ +- Nothing + +Changed +~~~~~~~ + +- Nothing + +Bugfix +~~~~~~ + +- Nothing + +1.0.5 - (2023-09-05) +-------------------- + +Added +~~~~~ + - Sensor module. - Single-script and Telegram demo. - Link and citation for JOSS. diff --git a/lensless/version.py b/lensless/version.py index 92192eed..68cdeee4 100644 --- a/lensless/version.py +++ b/lensless/version.py @@ -1 +1 @@ -__version__ = "1.0.4" +__version__ = "1.0.5" From fa91052906ad1ba174cca4956f3297f819c08888 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 5 Sep 2023 16:56:11 -0700 Subject: [PATCH 11/12] Simplify setup config for PyPI rendering. --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 79468810..392fc7fe 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ exec(f.read()) assert __version__ is not None -with open("README.rst", "r", encoding="utf-8") as fh: - long_description = fh.read() +# with open("README.rst", "r", encoding="utf-8") as fh: +# long_description = fh.read() +long_description = "See the documentation at https://lensless.readthedocs.io/en/latest/" setuptools.setup( name="lensless", version=__version__, author="Eric Bezzam", author_email="ebezzam@gmail.com", - description="Package to control and image with a lensless camera running on a Raspberry Pi.", + description="All-in-one package for lensless imaging: design, simulation, measurement, reconstruction.", long_description=long_description, long_description_content_type="text/x-rst", url="https://github.com/LCAV/LenslessPiCam", From 753a64a1a712ab521751d26528d2c757ed7ddf45 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 20 Sep 2023 17:52:44 +0200 Subject: [PATCH 12/12] Clean up unrolled training + PSF learning + simulated datasets. (#90) * New default config * Add new config to train psf from scratch Namely, it simulates the dataset with the mask/PSF that's being optimized. * Clean up configs. * Clean up config, clearer downsample. * Clean up test set loading. * Add download for drunet. * Fix message. * Improve diffusercam test dataset api. * New object for Mirflickr dataset. * Index train set correctly. * Update requirements for reconstruction. * Update documentation. * Adapt ADMM script for intermediate output. * remove raise error * Update number epochs. * Fix normalization. * Add logic for saving best model. * Clean up PSF fine-tuning. * Clean up fine-tuning PSF. * Clean up training with simulated dataset. * Update CHANGELOG. --------- Co-authored-by: YohannPerron Co-authored-by: YohannPerron <73244423+YohannPerron@users.noreply.github.com> --- CHANGELOG.rst | 11 +- configs/defaults_recon.yaml | 2 + .../diffusercam_mirflickr_single_admm.yaml | 43 +++ configs/fine-tune_PSF.yaml | 117 +------ configs/train_pre-post-processing.yaml | 24 ++ configs/train_psf_from_scratch.yaml | 18 ++ ...led_recon.yaml => train_unrolledADMM.yaml} | 42 +-- lensless/eval/benchmark.py | 15 +- lensless/hardware/trainable_mask.py | 32 +- lensless/recon/trainable_recon.py | 44 ++- lensless/recon/utils.py | 150 +++++++-- lensless/utils/dataset.py | 152 +++++++-- lensless/utils/image.py | 17 + lensless/utils/io.py | 19 +- lensless/utils/simulation.py | 58 +++- mask_requirements.txt | 2 +- recon_requirements.txt | 3 +- scripts/eval/benchmark_recon.py | 4 +- scripts/recon/admm.py | 114 ++++++- scripts/recon/train_unrolled.py | 295 +++++++++++------- 20 files changed, 815 insertions(+), 347 deletions(-) create mode 100644 configs/diffusercam_mirflickr_single_admm.yaml create mode 100644 configs/train_pre-post-processing.yaml create mode 100644 configs/train_psf_from_scratch.yaml rename configs/{unrolled_recon.yaml => train_unrolledADMM.yaml} (74%) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3ae1221a..a0492898 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,17 +13,22 @@ Unreleased Added ~~~~~ -- Nothing +- Trainable reconstruction can return intermediate outputs (between pre- and post-processing). +- Auto-download for DRUNet model. +- ``utils.dataset.DiffuserCamMirflickr`` helper class for Mirflickr dataset. Changed ~~~~~~~ -- Nothing +- Better logic for saving best model. Based on desired metric rather than last epoch, and intermediate models can be saved. +- Optional normalization in ``utils.io.load_image``. Bugfix ~~~~~~ -- Nothing +- Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS. +- Fix bad train/test split for DiffuserCamMirflickr in unrolled training. + 1.0.5 - (2023-09-05) -------------------- diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 324aa679..1771ff8a 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -8,11 +8,13 @@ input: # File path for raw data data: data/raw_data/thumbs_up_rgb.png dtype: float32 + original: null # ground truth image torch: False torch_device: 'cpu' preprocess: + normalize: True # Downsampling factor along X and Y downsample: 4 # Image shape (height, width) for reconstruction. diff --git a/configs/diffusercam_mirflickr_single_admm.yaml b/configs/diffusercam_mirflickr_single_admm.yaml new file mode 100644 index 00000000..5055bf6f --- /dev/null +++ b/configs/diffusercam_mirflickr_single_admm.yaml @@ -0,0 +1,43 @@ +# python scripts/recon/admm.py -cn diffusercam_mirflickr_single_admm +defaults: + - defaults_recon + - _self_ + + +display: + gamma: null + +input: + # File path for recorded PSF + psf: data/DiffuserCam_Test/psf.tiff + # File path for raw data + data: data/DiffuserCam_Test/diffuser/im5.npy + dtype: float32 + original: data/DiffuserCam_Test/lensed/im5.npy + +torch: True +torch_device: 'cuda:0' + +preprocess: + downsample: 8 # factor for PSF, which is 4x resolution of image + normalize: False + +admm: + # Number of iterations + n_iter: 20 + # Hyperparameters + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + #Loading unrolled model + unrolled: True + # checkpoint_fp: pretrained_models/Pre_Unrolled_Post-DiffuserCam/model_weights.pt + checkpoint_fp: outputs/2023-09-11/22-06-49/recon.pt # pre unet and post drunet + pre_process_model: + network : UnetRes # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process_model: + network : DruNet # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + \ No newline at end of file diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index 040dc81b..af55e03a 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -1,119 +1,18 @@ -hydra: - job: - chdir: True # change to output folder - -#Reconstruction algorithm -input: - # File path for recorded PSF - psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff - dtype: float32 - -torch: True -torch_device: 'cuda' - -preprocess: - # Image shape (height, width) for reconstruction. - shape: null - # Whether image is raw bayer data. - bayer: False - blue_gain: null - red_gain: null - # Same PSF for all channels (sum) or unique PSF for RGB. - single_psf: False - # Whether to perform construction in grayscale. - gray: False - - -display: - # How many iterations to wait for intermediate plot. - # Set to negative value for no intermediate plots. - disp: 500 - # Whether to plot results. - plot: True - # Gamma factor for plotting. - gamma: null - -# Whether to save intermediate and final reconstructions. -save: True - -reconstruction: - # Method: unrolled_admm, unrolled_fista - method: unrolled_admm - - # Hyperparameters for each method - unrolled_fista: # for unrolled_fista - # Number of iterations - n_iter: 20 - tk: 1 - learn_tk: True - unrolled_admm: - # Number of iterations - n_iter: 20 - # Hyperparameters - mu1: 1e-4 - mu2: 1e-4 - mu3: 1e-4 - tau: 2e-4 - pre_process: - network : null # UnetRes or DruNet or null - depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet - post_process: - network : null # UnetRes or DruNet or null - depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet +# python scripts/recon/train_unrolled.py -cn fine-tune_PSF +defaults: + - train_unrolledADMM + - _self_ #Trainable Mask trainable_mask: mask_type: TrainablePSF #Null or "TrainablePSF" - initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray" + initial_value: psf mask_lr: 1e-3 L1_strength: 1.0 #False or float - use_mask_in_dataset : False # Work only with simulated dataset - -# Train Dataset -files: - dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - n_files: null # null to use all - -target: "object_plane" # "original" or "object_plane" or "label" - -#for simulated dataset -simulation: - grayscale: False - # random variations - object_height: 0.04 # range for random height or scalar - flip: True # change the orientation of the object (from vertical to horizontal) - random_shift: False - random_vflip: 0.5 - random_hflip: 0.5 - random_rotate: False - # these distance parameters are typically fixed for a given PSF - # for DiffuserCam psf # for tape_rgb psf - scene2mask: 10e-2 # scene2mask: 40e-2 - mask2sensor: 9e-3 # mask2sensor: 4e-3 - # see waveprop.devices - sensor: "rpi_hq" - snr_db: 10 - # simulate different sensor resolution - # output_dim: [24, 32] # [H, W] or null - # Downsampling for PSF - downsample: 8 - # max val in simulated measured (quantized 8 bits) - max_val: 255 #Training - training: - batch_size: 8 - epoch: 10 - #In case of instable training - skip_NAN: True - slow_start: False #float how much to reduce lr for first epoch - + save_every: 5 -optimizer: - type: Adam - lr: 1e-4 - -loss: 'l2' -# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) -lpips: 1.0 \ No newline at end of file +display: + gamma: 2.2 diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml new file mode 100644 index 00000000..f4d6ba98 --- /dev/null +++ b/configs/train_pre-post-processing.yaml @@ -0,0 +1,24 @@ +# python scripts/recon/train_unrolled.py -cn train_pre-post-processing +defaults: + - train_unrolledADMM + - _self_ + +display: + disp: 400 + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: DruNet + depth: 4 + +training: + epoch: 50 + slow_start: 0.01 + +loss: l2 +lpips: 1.0 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml new file mode 100644 index 00000000..b4eef0ed --- /dev/null +++ b/configs/train_psf_from_scratch.yaml @@ -0,0 +1,18 @@ +# python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: /scratch/bezzam + downsample: 8 + +#Trainable Mask +trainable_mask: + mask_type: TrainablePSF #Null or "TrainablePSF" + initial_value: "random" + +simulation: + grayscale: False diff --git a/configs/unrolled_recon.yaml b/configs/train_unrolledADMM.yaml similarity index 74% rename from configs/unrolled_recon.yaml rename to configs/train_unrolledADMM.yaml index 2673f20c..3871be0d 100644 --- a/configs/unrolled_recon.yaml +++ b/configs/train_unrolledADMM.yaml @@ -1,29 +1,20 @@ +# python scripts/recon/train_unrolled.py hydra: job: chdir: True # change to output folder -#Reconstruction algorithm -input: - # File path for recorded PSF - psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff - dtype: float32 +# Dataset +files: + dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + psf: data/psf.tiff + diffusercam_psf: True + n_files: null # null to use all for both train/test + downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution torch: True torch_device: 'cuda' -preprocess: - # Image shape (height, width) for reconstruction. - shape: null - # Whether image is raw bayer data. - bayer: False - blue_gain: null - red_gain: null - # Same PSF for all channels (sum) or unique PSF for RGB. - single_psf: False - # Whether to perform construction in grayscale. - gray: False - - display: # How many iterations to wait for intermediate plot. # Set to negative value for no intermediate plots. @@ -64,15 +55,11 @@ reconstruction: #Trainable Mask trainable_mask: mask_type: Null #Null or "TrainablePSF" - initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray" + # "random" (with shape of config.files.psf) or "psf" (using config.files.psf) + initial_value: psf + grayscale: False mask_lr: 1e-3 L1_strength: 1.0 #False or float - use_mask_in_dataset : True # Work only with simulated dataset - -# Train Dataset -files: - dataset: "DiffuserCam" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - n_files: null # null to use all target: "object_plane" # "original" or "object_plane" or "label" @@ -98,13 +85,16 @@ simulation: # Downsampling for PSF downsample: 8 # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability max_val: 255 #Training training: batch_size: 8 - epoch: 10 + epoch: 50 + metric_for_best_model: null # e.g. LPIPS_Vgg, null does test loss + save_every: null #In case of instable training skip_NAN: True slow_start: False #float how much to reduce lr for first epoch diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index f93b754d..885766f3 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -93,7 +93,20 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): if metric == "ReconstructionError": metrics_values[metric] += model.reconstruction_error().cpu().item() else: - metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + if "LPIPS" in metric: + if prediction.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric] += ( + metrics[metric]( + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() model.reset() diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 593c2360..9bc70bc8 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -3,13 +3,15 @@ # ================== # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import abc import torch +from lensless.utils.image import is_grayscale -class TrainableMask(metaclass=abc.ABCMeta): +class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): """ Abstract class for defining trainable masks. @@ -33,6 +35,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): lr : float, optional Learning rate for the mask parameters, by default 1e-3 """ + super().__init__() self._mask = torch.nn.Parameter(initial_mask) self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) self._counter = 0 @@ -68,18 +71,31 @@ class TrainablePSF(TrainableMask): Parameters ---------- - is_rgb : bool, optional - Whether the mask is RGB or not, by default True. + grayscale : bool, optional + Whether mask should be returned as grayscale when calling :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`. + Otherwise PSF will be returned as RGB. By default False. """ - def __init__(self, is_rgb=True, **kwargs): - super().__init__(**kwargs) - self._is_rgb = is_rgb + def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs): + super().__init__(initial_mask, optimizer, lr, **kwargs) + assert ( + len(initial_mask.shape) == 4 + ), "Mask must be of shape (depth, height, width, channels)" + self.grayscale = grayscale + self._is_grayscale = is_grayscale(initial_mask) + if grayscale: + assert self._is_grayscale, "Mask must be grayscale" def get_psf(self): - if self._is_rgb: - return self._mask.expand(-1, -1, -1, 3) + if self._is_grayscale: + if self.grayscale: + # simulation in grayscale + return self._mask + else: + # replicate to 3 channels + return self._mask.expand(-1, -1, -1, 3) else: + # assume RGB return self._mask def project(self): diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index e554f6b0..82fd883d 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -5,7 +5,10 @@ # Yohann PERRON [yohann.perron@gmail.com] # ############################################################################# +import pathlib as plib +from matplotlib import pyplot as plt from lensless.recon.recon import ReconstructionAlgorithm +from lensless.utils.plot import plot_image try: import torch @@ -153,7 +156,15 @@ def batch_call(self, batch): return image_est def apply( - self, disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True + self, + disp_iter=10, + plot_pause=0.2, + plot=True, + save=False, + gamma=None, + ax=None, + reset=True, + output_intermediate=False, ): """ Method for performing iterative reconstruction. Contrary to non-trainable reconstruction @@ -178,6 +189,8 @@ def apply( Gamma correction factor to apply for plots. Default is None. ax : :py:class:`~matplotlib.axes.Axes`, optional `Axes` object to fill for plotting/saving, default is to create one. + output_intermediate : bool, optional + Whether to output intermediate reconstructions after preprocessing and before postprocessing. Returns ------- @@ -188,8 +201,11 @@ def apply( returning if `plot` or `save` is True. """ + pre_processed_image = None if self.pre_process is not None: self._data = self.pre_process(self._data, self.pre_process_param) + if output_intermediate: + pre_processed_image = self._data[0, ...].clone() im = super(TrainableReconstructionAlgorithm, self).apply( n_iter=self._n_iter, @@ -201,6 +217,30 @@ def apply( ax=ax, reset=reset, ) + + # remove plot if returned + if plot: + im, _ = im + + # post process data + pre_post_process_image = None if self.post_process is not None: + # apply post process + if output_intermediate: + pre_post_process_image = im.clone() im = self.post_process(im, self.post_process_param) - return im + + if plot: + ax = plot_image(self._get_numpy_data(im[0]), ax=ax, gamma=gamma) + ax.set_title( + "Final reconstruction after {} iterations and post process".format(self._n_iter) + ) + if save: + plt.savefig(plib.Path(save) / "final.png") + + if output_intermediate: + return im, pre_processed_image, pre_post_process_image + elif plot: + return im, ax + else: + return im diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 5f091c53..2409dd80 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -9,25 +9,28 @@ import json import math +import numpy as np +import matplotlib.pyplot as plt import time from hydra.utils import get_original_cwd import os -import matplotlib.pyplot as plt import torch from lensless.eval.benchmark import benchmark from lensless.hardware.trainable_mask import TrainableMask from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image -def load_drunet(model_path, n_channels=3, requires_grad=False): +def load_drunet(model_path=None, n_channels=3, requires_grad=False): """ Load a pre-trained Drunet model. Parameters ---------- - model_path : str - Path to pre-trained model. + model_path : str, optional + Path to pre-trained model. Download if not provided. n_channels : int Number of channels in input image. requires_grad : bool @@ -39,6 +42,25 @@ def load_drunet(model_path, n_channels=3, requires_grad=False): Loaded model. """ + if model_path is None: + model_path = os.path.join(get_original_cwd(), "models", "drunet_color.pth") + if not os.path.exists(model_path): + try: + from torchvision.datasets.utils import download_url + except ImportError: + exit() + msg = "Do you want to download the pretrained DRUNet model (130MB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + output_path = os.path.join(get_original_cwd(), "models") + if valid: + url = "https://drive.switch.ch/index.php/s/jTdeMHom025RFRQ/download" + filename = "drunet_color.pth" + download_url(url, output_path, filename=filename) + + assert os.path.exists(model_path), f"Model path {model_path} does not exist" + model = UNetRes( in_nc=n_channels + 1, out_nc=n_channels, @@ -192,9 +214,7 @@ def create_process_network(network, depth, device="cpu"): if network == "DruNet": from lensless.recon.utils import load_drunet - process = load_drunet( - os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True - ).to(device) + process = load_drunet(requires_grad=True).to(device) process_name = "DruNet" elif network == "UnetRes": from lensless.recon.drunet.network_unet import UNetRes @@ -223,6 +243,7 @@ def __init__( recon, train_dataset, test_dataset, + test_size=0.15, mask=None, batch_size=4, loss="l2", @@ -233,10 +254,19 @@ def __init__( slow_start=None, skip_NAN=False, algorithm_name="Unknown", + metric_for_best_model=None, + save_every=None, + gamma=None, ): """ Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__. + The train and test metrics at the end of each epoch can be found in ``self.metrics``, + with "LOSS" being the train loss. The test loss can be found in "MSE" (if loss is "l2") or + "MAE" (if loss is "l1"). If ``lpips`` is not None, the LPIPS loss is also added + to the train loss, such that the test loss can be computed as "MSE" + ``lpips`` * "LPIPS_Vgg" + (or "MAE" + ``lpips`` * "LPIPS_Vgg"). + Parameters ---------- recon : :py:class:`lensless.TrainableReconstructionAlgorithm` @@ -245,39 +275,51 @@ def __init__( Dataset to use for training. test_dataset : :py:class:`torch.utils.data.Dataset` Dataset to use for testing. + test_size : float, optional + If test_dataset is None, fraction of the train dataset to use for testing, by default 0.15. mask : TrainableMask, optional Trainable mask to use for training. If none, training with fix psf, by default None. batch_size : int, optional - Batch size to use for training, by default 4 + Batch size to use for training, by default 4. loss : str, optional - Loss function to use for training "l1" or "l2", by default "l2" + Loss function to use for training "l1" or "l2", by default "l2". lpips : float, optional - the weight of the lpips(VGG) in the total loss. If None ignore. By default None + the weight of the lpips(VGG) in the total loss. If None ignore. By default None. l1_mask : float, optional - the weight of the l1 norm of the mask in the total loss. If None ignore. By default None + the weight of the l1 norm of the mask in the total loss. If None ignore. By default None. optimizer : str, optional - Optimizer to use durring training. Available : "Adam". By default "Adam" + Optimizer to use durring training. Available : "Adam". By default "Adam". optimizer_lr : float, optional - Learning rate for the optimizer, by default 1e-6 + Learning rate for the optimizer, by default 1e-6. slow_start : float, optional Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None. skip_NAN : bool, optional Whether to skip update if any gradiant are NAN (True) or to throw an error(False), by default False algorithm_name : str, optional Algorithm name for logging, by default "Unknown". + metric_for_best_model : str, optional + Metric to use for saving the best model. If None, will default to evaluation loss. Default is None. + save_every : int, optional + Save model every ``save_every`` epochs. If None, just save best model. + gamma : float, optional + Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None. + """ self.device = recon._psf.device self.recon = recon + assert train_dataset is not None if test_dataset is None: + assert test_size < 1.0 and test_size > 0.0 # split train dataset - train_size = int(0.9 * len(train_dataset)) + train_size = int((1 - test_size) * len(train_dataset)) test_size = len(train_dataset) - train_size train_dataset, test_dataset = torch.utils.data.random_split( train_dataset, [train_size, test_size] ) + print(f"Train size : {train_size}, Test size : {test_size}") self.train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, @@ -297,6 +339,7 @@ def __init__( self.use_mask = False self.l1_mask = l1_mask + self.gamma = gamma # loss if loss == "l2": @@ -345,7 +388,7 @@ def learning_rate_function(epoch): ) self.metrics = { - "LOSS": [], + "LOSS": [], # train loss "MSE": [], "MAE": [], "LPIPS_Vgg": [], @@ -355,7 +398,15 @@ def learning_rate_function(epoch): "ReconstructionError": [], "n_iter": self.recon._n_iter, "algorithm": algorithm_name, + "metric_for_best_model": metric_for_best_model, + "best_epoch": 0, + "best_eval_score": 0 + if metric_for_best_model == "PSNR" or metric_for_best_model == "SSIM" + else np.inf, } + if metric_for_best_model is not None: + assert metric_for_best_model in self.metrics.keys() + self.save_every = save_every # Backward hook that detect NAN in the gradient and print the layer weights if not self.skip_NAN: @@ -430,6 +481,12 @@ def train_epoch(self, data_loader, disp=-1): loss_v = self.Loss(y_pred, y) if self.lpips: + + if y_pred.shape[1] == 1: + # if only one channel, repeat for LPIPS + y_pred = y_pred.repeat(1, 3, 1, 1) + y = y.repeat(1, 3, 1, 1) + # value for LPIPS needs to be in range [-1, 1] loss_v = loss_v + self.lpips * torch.mean( self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) @@ -489,7 +546,18 @@ def evaluate(self, mean_loss, save_pt): with open(os.path.join(save_pt, "metrics.json"), "w") as f: json.dump(self.metrics, f) - def on_epoch_end(self, mean_loss, save_pt): + # check best metric + if self.metrics["metric_for_best_model"] is None: + eval_loss = current_metrics["MSE"] + if self.lpips is not None: + eval_loss += self.lpips * current_metrics["LPIPS_Vgg"] + if self.use_mask and self.l1_mask: + eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy())) + return eval_loss + else: + return current_metrics[self.metrics["metric_for_best_model"]] + + def on_epoch_end(self, mean_loss, save_pt, epoch): """ Called at the end of each epoch. @@ -499,14 +567,35 @@ def on_epoch_end(self, mean_loss, save_pt): Mean loss of the last epoch. save_pt : str Path to save metrics dictionary to. If None, no logging of metrics. + epoch : int + Current epoch. """ if save_pt is None: # Use current directory save_pt = os.getcwd() # save model - self.save(path=save_pt, include_optimizer=False) - self.evaluate(mean_loss, save_pt) + # self.save(path=save_pt, include_optimizer=False) + epoch_eval_metric = self.evaluate(mean_loss, save_pt) + new_best = False + if ( + self.metrics["metric_for_best_model"] == "PSNR" + or self.metrics["metric_for_best_model"] == "SSIM" + ): + if epoch_eval_metric > self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + else: + if epoch_eval_metric < self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + + if new_best: + self.metrics["best_epoch"] = epoch + self.save(path=save_pt, include_optimizer=False, epoch="BEST") + + if self.save_every is not None and epoch % self.save_every == 0: + self.save(path=save_pt, include_optimizer=False, epoch=epoch) def train(self, n_epoch=1, save_pt=None, disp=-1): """ @@ -528,26 +617,31 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): for epoch in range(n_epoch): print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) - self.on_epoch_end(mean_loss, save_pt) + # offset because of evaluate before loop + self.on_epoch_end(mean_loss, save_pt, epoch + 1) self.scheduler.step() print(f"Train time : {time.time() - start_time} s") - def save(self, path="recon", include_optimizer=False): + def save(self, epoch, path="recon", include_optimizer=False): # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) # save mask if self.use_mask: - torch.save(self.mask._mask, os.path.join(path, "mask.pt")) - torch.save(self.mask._optimizer.state_dict(), os.path.join(path, "mask_optim.pt")) - import matplotlib.pyplot as plt - - plt.imsave( - os.path.join(path, "psf.png"), self.mask.get_psf().detach().cpu().numpy()[0, ...] + torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) + torch.save( + self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt") ) + + psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] + psf_np = psf_np.squeeze() # remove (potential) singleton color channel + save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) + plot_image(psf_np, gamma=self.gamma) + plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + # save optimizer if include_optimizer: - torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) + torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt")) # save recon - torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt")) + torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 67aa7a8d..a5a2e8a9 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -26,7 +26,9 @@ class DualDataset(Dataset): def __init__( self, indices=None, + # psf_path=None, background=None, + # background_pix=(0, 15), downsample=1, flip=False, transform_lensless=None, @@ -38,18 +40,22 @@ def __init__( Parameters ---------- - indices : range or int or None - Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. - background : :py:class:`~torch.Tensor` or None, optional - If not ``None``, background is removed from lensless images, by default ``None``. - downsample : int, optional - Downsample factor of the lensless images, by default 1. - flip : bool, optional - If ``True``, lensless images are flipped, by default ``False``. - transform_lensless : PyTorch Transform or None, optional - Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). - transform_lensed : PyTorch Transform or None, optional - Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + indices : range or int or None + Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. + psf_path : str + Path to the PSF of the imaging system, by default None. + background : :py:class:`~torch.Tensor` or None, optional + If not ``None``, background is removed from lensless images, by default ``None``. If PSF is provided, background is estimated from the PSF. + background_pix : tuple, optional + Pixels to use for background estimation, by default (0, 15). + downsample : int, optional + Downsample factor of the lensless images, by default 1. + flip : bool, optional + If ``True``, lensless images are flipped, by default ``False``. + transform_lensless : PyTorch Transform or None, optional + Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + transform_lensed : PyTorch Transform or None, optional + Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). """ if isinstance(indices, int): indices = range(indices) @@ -60,6 +66,21 @@ def __init__( self.transform_lensless = transform_lensless self.transform_lensed = transform_lensed + # self.psf = None + # if psf_path is not None: + # psf, background = load_psf( + # psf_path, + # downsample=downsample, + # return_float=True, + # return_bg=True, + # bg_pix=background_pix, + # ) + # if self.background is None: + # self.background = background + # self.psf = torch.from_numpy(psf) + # if self.transform_lensless is not None: + # self.psf = self.transform_lensless(self.psf) + @abstractmethod def __len__(self): """ @@ -151,7 +172,7 @@ def __init__( dataset_is_CHW : bool, optional If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``. flip : bool, optional - If True, images are flipped beffore the simulation, by default ``False``.. + If True, images are flipped beffore the simulation, by default ``False``. """ # we do the flipping before the simualtion @@ -171,6 +192,10 @@ def __init__( assert simulator.fft_shape is not None, "Simulator should have a psf" self.sim = simulator + @property + def psf(self): + return self.sim.get_psf() + def get_image(self, index): return self.dataset[index] @@ -185,7 +210,14 @@ def _get_images_pair(self, index): if self._pre_transform is not None: img = self._pre_transform(img) - lensless, lensed = self.sim.propagate(img, return_object_plane=True) + lensless, lensed = self.sim.propagate_image(img, return_object_plane=True) + + if lensed.shape[-1] == 1 and lensless.shape[-1] == 3: + # copy to 3 channels + lensed = lensed.repeat(1, 1, 3) + assert ( + lensed.shape[-1] == lensless.shape[-1] + ), "Lensed and lensless should have same number of channels" return lensless, lensed @@ -240,6 +272,9 @@ def __init__( self.root_dir = root_dir self.lensless_dir = os.path.join(root_dir, lensless_fn) self.original_dir = os.path.join(root_dir, original_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.original_dir) + self.image_ext = image_ext.lower() self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower() @@ -295,7 +330,7 @@ def _get_images_pair(self, idx): # project original image to lensed space with torch.no_grad(): - lensed = self.sim.propagate() + lensed = self.sim.propagate_image() return lensless, lensed @@ -336,6 +371,9 @@ def __init__( self.root_dir = root_dir self.lensless_dir = os.path.join(root_dir, lensless_fn) self.lensed_dir = os.path.join(root_dir, lensed_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.lensed_dir) + self.image_ext = image_ext.lower() files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) @@ -359,6 +397,7 @@ def _get_images_pair(self, idx): lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) lensless = np.load(lensless_fp) lensed = np.load(lensed_fp) + else: # more standard image formats: png, jpg, tiff, etc. lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) @@ -378,6 +417,59 @@ def _get_images_pair(self, idx): return lensless, lensed +class DiffuserCamMirflickr(MeasuredDataset): + """ + Helper class for DiffuserCam Mirflickr dataset. + + Note that image colors are in BGR format: https://github.com/Waller-Lab/LenslessLearning/blob/master/utils.py#L432 + """ + + def __init__( + self, + dataset_dir, + psf_path, + downsample=2, + **kwargs, + ): + + psf, background = load_psf( + psf_path, + downsample=downsample * 4, # PSF is 4x the resolution of the images + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + self.allowed_idx = np.arange(2, 25001) + + super().__init__( + root_dir=dataset_dir, + background=background, + downsample=downsample, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser_images", + lensed_fn="ground_truth_lensed", + image_ext="npy", + **kwargs, + ) + + def _get_images_pair(self, idx): + + assert idx >= self.allowed_idx.min(), f"idx should be >= {self.allowed_idx.min()}" + assert idx <= self.allowed_idx.max(), f"idx should be <= {self.allowed_idx.max()}" + + fn = f"im{idx}.npy" + lensless_fp = os.path.join(self.lensless_dir, fn) + lensed_fp = os.path.join(self.lensed_dir, fn) + lensless = np.load(lensless_fp) + lensed = np.load(lensed_fp) + + return lensless, lensed + + class DiffuserCamTestDataset(MeasuredDataset): """ Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. @@ -385,8 +477,8 @@ class DiffuserCamTestDataset(MeasuredDataset): def __init__( self, - data_dir="data", - n_files=200, + data_dir=None, + n_files=None, downsample=2, ): """ @@ -396,17 +488,20 @@ def __init__( Parameters ---------- data_dir : str, optional - The path to the folder containing the DiffuserCam_Test dataset, by default "data". + The path to ``DiffuserCam_Test`` dataset, by default looks inside the ``data`` folder. n_files : int, optional - Number of image pairs to load in the dataset , by default 200. + Number of image pairs to load in the dataset , by default use all. downsample : int, optional - Downsample factor of the lensless images, by default 8. + Downsample factor of the lensless images, by default 2. Note that the PSF has a resolution of 4x of the images. """ # download dataset if necessary - main_dir = data_dir - data_dir = os.path.join(data_dir, "DiffuserCam_Test") + if data_dir is None: + data_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "data", "DiffuserCam_Test" + ) if not os.path.isdir(data_dir): + main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data") print("No dataset found for benchmarking.") try: from torchvision.datasets.utils import download_and_extract_archive @@ -424,7 +519,7 @@ def __init__( psf_fp = os.path.join(data_dir, "psf.tiff") psf, background = load_psf( psf_fp, - downsample=downsample, + downsample=downsample * 4, # PSF is 4x the resolution of the images return_float=True, return_bg=True, bg_pix=(0, 15), @@ -435,11 +530,16 @@ def __init__( self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + if n_files is None: + indices = None + else: + indices = range(n_files) + super().__init__( root_dir=data_dir, - indices=range(n_files), + indices=indices, background=background, - downsample=downsample / 4, + downsample=downsample, flip=False, transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, @@ -492,7 +592,7 @@ def __init__( def _get_images_pair(self, index): # update psf psf = self._mask.get_psf() - self.sim.set_psf(psf) + self.sim.set_point_spread_function(psf) # return simulated images return super()._get_images_pair(index) diff --git a/lensless/utils/image.py b/lensless/utils/image.py index f3bbe28f..748aaf50 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -77,6 +77,23 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC): return np.clip(resized, min_val, max_val) +def is_grayscale(img): + """ + Check if image is RGB. Assuming image is of shape ([depth,] height, width, color). + + Parameters + ---------- + img : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Image array. + + Returns + ------- + bool + Whether image is RGB. + """ + return img.shape[-1] == 1 + + def rgb2gray(rgb, weights=None, keepchanneldim=True): """ Convert RGB array to grayscale. diff --git a/lensless/utils/io.py b/lensless/utils/io.py index f502719a..1b2b234f 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -34,6 +34,7 @@ def load_image( return_float=False, shape=None, dtype=None, + normalize=True, ): """ Load image as numpy array. @@ -73,6 +74,8 @@ def load_image( Shape (H, W, C) to resize to. dtype : str, optional Data type of returned data. Default is to use that of input. + normalize : bool, default True + If ``return_float``, whether to normalize data to maximum value of 1. Returns ------- @@ -136,7 +139,7 @@ def load_image( if bg is not None: # if bg is float vector, turn into int-valued vector - if bg.max() <= 1: + if bg.max() <= 1 and img.dtype not in [np.float32, np.float64]: bg = bg * get_max_val(img) img = img - bg @@ -160,7 +163,8 @@ def load_image( dtype = np.float32 assert dtype == np.float32 or dtype == np.float64 img = img.astype(dtype) - img /= img.max() + if normalize: + img /= img.max() else: if dtype is None: @@ -336,6 +340,7 @@ def load_psf( def load_data( psf_fp, data_fp, + return_float=True, downsample=None, bg_pix=(5, 25), plot=True, @@ -350,6 +355,7 @@ def load_data( shape=None, torch=False, torch_device="cpu", + normalize=False, ): """ Load data for image reconstruction. @@ -360,6 +366,8 @@ def load_data( Full path to PSF file. data_fp : str Full path to measurement file. + return_float : bool, optional + Whether to return PSF as float array, or unsigned int. downsample : int or float Downsampling factor. bg_pix : tuple, optional @@ -386,6 +394,8 @@ def load_data( Whether to sum RGB channels into single PSF, same across channels. Done in "Learned reconstructions for practical mask-based lensless imaging" of Kristina Monakhova et. al. + normalize : bool default True + Whether to normalize data to maximum value of 1. Returns ------- @@ -415,7 +425,7 @@ def load_data( psf, bg = load_psf( psf_fp, downsample=downsample, - return_float=True, + return_float=return_float, bg_pix=bg_pix, return_bg=True, flip=flip, @@ -437,8 +447,9 @@ def load_data( red_gain=red_gain, bg=bg, as_4d=True, - return_float=True, + return_float=return_float, shape=shape, + normalize=normalize, ) if data.shape != psf.shape: diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py index e7f7af3a..b77fabcb 100644 --- a/lensless/utils/simulation.py +++ b/lensless/utils/simulation.py @@ -6,8 +6,8 @@ # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# -import numpy as np from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp +import torch class FarFieldSimulator(FarFieldSimulator_wp): @@ -34,7 +34,7 @@ def __init__( """ Parameters ---------- - psf : np.ndarray, optional. + psf : np.ndarray or torch.Tensor, optional. Point spread function. If not provided, return image at object plane. object_height : float or (float, float) Height of object in meters. Or range of values to randomly sample from. @@ -58,9 +58,15 @@ def __init__( Whether to quantize image, by default True. """ - if psf is not None: - # convert HWC to CHW - psf = psf.squeeze().movedim(-1, 0) + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # drop depth dimension, and convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" super().__init__( object_height, @@ -78,6 +84,13 @@ def __init__( **kwargs ) + if self.is_torch: + assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + assert ( + self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 + ), "PSF must have 1 or 3 channels" + # save all the parameters in a dict self.params = { "object_height": object_height, @@ -94,7 +107,15 @@ def __init__( } self.params.update(kwargs) - def set_psf(self, psf): + def get_psf(self): + if self.is_torch: + # convert CHW to HWC + return self.psf.movedim(0, -1).unsqueeze(0) + else: + return self.psf[None, ...] + + # needs different name from parent class + def set_point_spread_function(self, psf): """ Set point spread function. @@ -103,19 +124,32 @@ def set_psf(self, psf): psf : np.ndarray or torch.Tensor Point spread function. """ - psf = psf.squeeze().movedim(-1, 0) + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" + return super().set_psf(psf) - def propagate(self, obj, return_object_plane=False): + def propagate_image(self, obj, return_object_plane=False): """ Parameters ---------- obj : np.ndarray or torch.Tensor - Single image to propagate at format HWC. + Single image to propagate of format HWC. return_object_plane : bool, optional Whether to return object plane, by default False. """ + + assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels" + if self.is_torch: + # channel in first dimension as expected by waveprop for pytorch obj = obj.moveaxis(-1, 0) res = super().propagate(obj, return_object_plane) if isinstance(res, tuple): @@ -124,10 +158,6 @@ def propagate(self, obj, return_object_plane=False): res = res.moveaxis(-3, -1) return res else: - obj = np.moveaxis(obj, -1, 0) + # TODO: not tested, but normally don't need to move dimensions for numpy res = super().propagate(obj, return_object_plane) - if isinstance(res, tuple): - res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1) - else: - res = np.moveaxis(res, -3, -1) return res diff --git a/mask_requirements.txt b/mask_requirements.txt index 699ba552..9e9c28a4 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,3 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.7 \ No newline at end of file +waveprop>=0.0.8 \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index 33e12092..0b90adf2 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -3,9 +3,10 @@ lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 click>=8.0.1 -waveprop>=0.0.7 # for simulation +waveprop>=0.0.8 # for simulation # Library for learning algorithm torch >= 2.0.0 torchvision +torchmetrics lpips \ No newline at end of file diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 6611ceec..89a31309 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -44,9 +44,7 @@ def benchmark_recon(config): device = "cpu" # Benchmark dataset - benchmark_dataset = DiffuserCamTestDataset( - data_dir=os.path.join(get_original_cwd(), "data"), n_files=n_files, downsample=downsample - ) + benchmark_dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) psf = benchmark_dataset.psf.to(device) model_list = [] # list of algoritms to benchmark diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 2a053722..c84d5b92 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -13,8 +13,9 @@ import pathlib as plib import matplotlib.pyplot as plt import numpy as np -from lensless.utils.io import load_data +from lensless.utils.io import load_data, load_image from lensless import ADMM +from lensless.utils.plot import plot_image @hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") @@ -42,6 +43,7 @@ def admm(config): torch=config.torch, torch_device=config.torch_device, bg_pix=config.preprocess.bg_pix, + normalize=config.preprocess.normalize, ) disp = config["display"]["disp"] @@ -52,6 +54,15 @@ def admm(config): if save: save = os.getcwd() + if save: + if config.torch: + org_data = data.cpu().numpy() + else: + org_data = data + ax = plot_image(org_data, gamma=config["display"]["gamma"]) + ax.set_title("Original measurement") + plt.savefig(plib.Path(save) / "lensless.png") + start_time = time.time() if not config.admm.unrolled: recon = ADMM(psf, **config.admm) @@ -60,14 +71,14 @@ def admm(config): from lensless.recon.unrolled_admm import UnrolledADMM import lensless.recon.utils - pre_process = lensless.recon.utils.create_process_network( + pre_process, _ = lensless.recon.utils.create_process_network( network=config.admm.pre_process_model.network, - depth=config.admm.pre_process_depth.depth, + depth=config.admm.pre_process_model.depth, device=config.torch_device, ) - post_process = lensless.recon.utils.create_process_network( + post_process, _ = lensless.recon.utils.create_process_network( network=config.admm.post_process_model.network, - depth=config.admm.post_process_depth.depth, + depth=config.admm.post_process_model.depth, device=config.torch_device, ) @@ -76,18 +87,28 @@ def admm(config): print("Loading checkpoint from : ", path) assert os.path.exists(path), "Checkpoint does not exist" recon.load_state_dict(torch.load(path, map_location=config.torch_device)) + recon.set_data(data) print(f"Setup time : {time.time() - start_time} s") start_time = time.time() if config.torch: with torch.no_grad(): - res = recon.apply( - disp_iter=disp, - save=save, - gamma=config["display"]["gamma"], - plot=config["display"]["plot"], - ) + if config.admm.unrolled: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + else: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) else: res = recon.apply( disp_iter=disp, @@ -105,7 +126,78 @@ def admm(config): if config["display"]["plot"]: plt.show() if save: + + if config.admm.unrolled: + # Save intermediate results + if res[1] is not None: + pre_processed_image = res[1].cpu().numpy() + ax = plot_image(pre_processed_image, gamma=config["display"]["gamma"]) + ax.set_title("Image after preprocessing") + plt.savefig(plib.Path(save) / "pre_processed.png") + + if res[2] is not None: + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + ax.set_title("Image prior to post-processing") + plt.savefig(plib.Path(save) / "pre_post_process.png") + np.save(plib.Path(save) / "final_reconstruction.npy", img) + + if config.input.original is not None: + original = load_image( + to_absolute_path(config.input.original), + flip=config["preprocess"]["flip"], + red_gain=config["preprocess"]["red_gain"], + blue_gain=config["preprocess"]["blue_gain"], + shape=img.shape[-3:], + ) + ax = plot_image(original, gamma=config["display"]["gamma"]) + ax.set_title("Ground truth image") + plt.savefig(plib.Path(save) / "original.png") + + # compute metrics + from torchmetrics.image import lpip, psnr + + lpips_func = lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True) + psnr_funct = psnr.PeakSignalNoiseRatio() + + img_torch = torch.from_numpy(img).squeeze(0) + original_torch = torch.from_numpy(original).unsqueeze(0) + + # channel as first dimension + img_torch = img_torch.movedim(-1, -3) + original_torch = original_torch.movedim(-1, -3) + + # normalize, TODO img max value is 14 which seems strange + img_torch = img_torch / torch.amax(img_torch) + + # compute metrics + lpips = lpips_func(img_torch, original_torch) + psnr = psnr_funct(img_torch, original_torch) + print(f"LPIPS : {lpips}") + print(f"PSNR : {psnr}") + + # If the recon algorithm is unrolled and has a preprocessing step, plot result without preprocessing + if config.admm.unrolled and recon.pre_process is not None: + recon.set_data(data) + recon.pre_process = None + with torch.no_grad(): + res = recon.apply( + disp_iter=disp, + save=False, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + + img = res[0].cpu().numpy() + np.save(plib.Path(save) / "final_reconstruction_no_preprocessing.npy", img[0]) + ax = plot_image(img, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "final_reconstruction_no_preprocessing.png") + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "pre_post_process_no_preprocessing.png") + print(f"Files saved to : {save}") diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index c669ea2e..5cbee7bf 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# """ @@ -12,13 +13,26 @@ python scripts/recon/train_unrolled.py ``` +By default it uses the configuration from the file `configs/train_unrolledADMM.yaml`. + +To train pre- and post-processing networks, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_pre-post-processing +``` + To fine-tune the DiffuserCam PSF, use the following command: ``` python scripts/recon/train_unrolled.py -cn fine-tune_PSF ``` +To train a PSF from scratch with a simulated dataset, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +``` + """ +import logging import hydra from hydra.utils import get_original_cwd import os @@ -26,20 +40,50 @@ import time from lensless import UnrolledFISTA, UnrolledADMM from lensless.utils.dataset import ( - DiffuserCamTestDataset, + DiffuserCamMirflickr, SimulatedFarFieldDataset, SimulatedDatasetTrainableMask, ) +from torch.utils.data import Subset import lensless.hardware.trainable_mask from lensless.recon.utils import create_process_network -from lensless.utils.image import rgb2gray +from lensless.utils.image import rgb2gray, is_grayscale from lensless.utils.simulation import FarFieldSimulator from lensless.recon.utils import Trainer import torch from torchvision import transforms, datasets +from lensless.utils.io import load_psf +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image +import matplotlib.pyplot as plt + +# A logger for this file +log = logging.getLogger(__name__) + +def simulate_dataset(config): + + if config.torch_device == "cuda" and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # prepare PSF + psf_fp = os.path.join(get_original_cwd(), config.files.psf) + psf, _ = load_psf( + psf_fp, + downsample=config.files.downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + if config.files.diffusercam_psf: + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + psf = transform_BRG2RGB(torch.from_numpy(psf)) + + # drop depth dimension + psf = psf.to(device) -def simulate_dataset(config, psf, mask=None): # load dataset transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") @@ -47,26 +91,38 @@ def simulate_dataset(config, psf, mask=None): transforms_list.append(transforms.Grayscale()) transform = transforms.Compose(transforms_list) if config.files.dataset == "mnist": - ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "fashion_mnist": - ds = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.FashionMNIST( + root=data_path, train=True, download=True, transform=transform + ) + test_ds = datasets.FashionMNIST( + root=data_path, train=False, download=True, transform=transform + ) elif config.files.dataset == "cifar10": - ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "CelebA": - ds = datasets.CelebA(root=data_path, split="train", download=True, transform=transform) + root = config.files.celeba_root + data_path = os.path.join(root, "celeba") + assert os.path.isdir( + data_path + ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" + train_ds = datasets.CelebA(root=root, split="train", download=False, transform=transform) + test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) else: raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") # convert PSF - if config.simulation.grayscale: + if config.simulation.grayscale and not is_grayscale(psf): psf = rgb2gray(psf) - if not isinstance(psf, torch.Tensor): - psf = transforms.ToTensor()(psf) - n_files = config.files.n_files - device_conv = config.torch_device + # prepare mask + mask = prep_trainable_mask(config, psf, grayscale=config.simulation.grayscale) # check if gpu is available + device_conv = config.torch_device if device_conv == "cuda" and torch.cuda.is_available(): device_conv = "cuda" else: @@ -78,55 +134,74 @@ def simulate_dataset(config, psf, mask=None): is_torch=True, **config.simulation, ) + # create Pytorch dataset and dataloader + n_files = config.files.n_files if n_files is not None: - ds = torch.utils.data.Subset(ds, np.arange(n_files)) + train_ds = torch.utils.data.Subset(train_ds, np.arange(n_files)) + test_ds = torch.utils.data.Subset(test_ds, np.arange(n_files)) if mask is None: - ds_prop = SimulatedFarFieldDataset( - dataset=ds, + train_ds_prop = SimulatedFarFieldDataset( + dataset=train_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedFarFieldDataset( + dataset=test_ds, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv, flip=config.simulation.flip, ) else: - ds_prop = SimulatedDatasetTrainableMask( - dataset=ds, + train_ds_prop = SimulatedDatasetTrainableMask( + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedDatasetTrainableMask( + dataset=test_ds, mask=mask, simulator=simulator, dataset_is_CHW=True, device_conv=device_conv, flip=config.simulation.flip, ) - return ds_prop + return train_ds_prop, test_ds_prop, mask -@hydra.main(version_base=None, config_path="../../configs", config_name="unrolled_recon") -def train_unrolled( - config, -): - if config.torch_device == "cuda" and torch.cuda.is_available(): - print("Using GPU for training.") - device = "cuda" - else: - print("Using CPU for training.") - device = "cpu" - # torch.autograd.set_detect_anomaly(True) +def prep_trainable_mask(config, psf, grayscale=False): + mask = None + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - # benchmarking dataset: - path = os.path.join(get_original_cwd(), "data") - benchmark_dataset = DiffuserCamTestDataset( - data_dir=path, downsample=config.simulation.downsample - ) + if config.trainable_mask.initial_value == "random": + initial_mask = torch.rand_like(psf) + elif config.trainable_mask.initial_value == "psf": + initial_mask = psf.clone() + else: + raise ValueError( + f"Initial PSF value {config.trainable_mask.initial_value} not supported" + ) + + if config.trainable_mask.grayscale and not is_grayscale(initial_mask): + initial_mask = rgb2gray(initial_mask) + + mask = mask_class( + initial_mask, optimizer="Adam", lr=config.trainable_mask.mask_lr, grayscale=grayscale + ) - diffusercam_psf = benchmark_dataset.psf.to(device) - background = benchmark_dataset.background + return mask - # convert psf from BGR to RGB - diffusercam_psf = diffusercam_psf[..., [2, 1, 0]] - # if using a portrait dataset rotate the PSF +@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") +def train_unrolled(config): disp = config.display.disp if disp < 0: @@ -136,6 +211,63 @@ def train_unrolled( if save: save = os.getcwd() + if config.torch_device == "cuda" and torch.cuda.is_available(): + print("Using GPU for training.") + device = "cuda" + else: + print("Using CPU for training.") + device = "cpu" + + # load dataset and create dataloader + train_set = None + test_set = None + psf = None + if "DiffuserCam" in config.files.dataset: + + original_path = os.path.join(get_original_cwd(), config.files.dataset) + psf_path = os.path.join(get_original_cwd(), config.files.psf) + dataset = DiffuserCamMirflickr( + dataset_dir=original_path, + psf_path=psf_path, + downsample=config.files.downsample, + ) + dataset.psf = dataset.psf.to(device) + # train-test split as in https://waller-lab.github.io/LenslessLearning/dataset.html + # first 1000 files for test, the rest for training + train_indices = dataset.allowed_idx[dataset.allowed_idx > 1000] + test_indices = dataset.allowed_idx[dataset.allowed_idx <= 1000] + if config.files.n_files is not None: + train_indices = train_indices[: config.files.n_files] + test_indices = test_indices[: config.files.n_files] + + train_set = Subset(dataset, train_indices) + test_set = Subset(dataset, test_indices) + + # -- if learning mask + mask = prep_trainable_mask(config, dataset.psf) + if mask is not None: + # plot initial PSF + psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] + if config.trainable_mask.grayscale: + psf_np = psf_np[:, :, -1] + + save_image(psf_np, os.path.join(save, "psf_initial.png")) + plot_image(psf_np, gamma=config.display.gamma) + plt.savefig(os.path.join(save, "psf_initial_plot.png")) + + psf = dataset.psf + + else: + + train_set, test_set, mask = simulate_dataset(config) + psf = train_set.psf + + assert train_set is not None + assert psf is not None + + print("Train test size : ", len(train_set)) + print("Test test size : ", len(test_set)) + start_time = time.time() # Load pre process model @@ -150,10 +282,11 @@ def train_unrolled( config.reconstruction.post_process.depth, device=device, ) + # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( - diffusercam_psf, + psf, n_iter=config.reconstruction.unrolled_fista.n_iter, tk=config.reconstruction.unrolled_fista.tk, pad=True, @@ -163,7 +296,7 @@ def train_unrolled( ).to(device) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( - diffusercam_psf, + psf, n_iter=config.reconstruction.unrolled_admm.n_iter, mu1=config.reconstruction.unrolled_admm.mu1, mu2=config.reconstruction.unrolled_admm.mu2, @@ -183,67 +316,17 @@ def train_unrolled( algorithm_name += "_" + post_process_name # print number of parameters - print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters") - # transform from BGR to RGB - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - # create mask - if config.trainable_mask.mask_type is not None: - mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - if config.trainable_mask.initial_value == "random": - mask = mask_class( - torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr - ) - elif config.trainable_mask.initial_value == "DiffuserCam": - mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) - elif config.trainable_mask.initial_value == "DiffuserCam_gray": - mask = mask_class( - diffusercam_psf[:, :, :, 0, None], - optimizer="Adam", - lr=config.trainable_mask.mask_lr, - is_rgb=not config.simulation.grayscale, - ) - else: - mask = None - - # load dataset and create dataloader - if config.files.dataset == "DiffuserCam": - # Use a ParallelDataset - from lensless.utils.dataset import MeasuredDataset - - max_indices = 30000 - if config.files.n_files is not None: - max_indices = config.files.n_files + 1000 - - data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") - assert os.path.exists(data_path), "DiffuserCam dataset not found" - dataset = MeasuredDataset( - root_dir=data_path, - indices=range(1000, max_indices), - background=background, - psf=diffusercam_psf, - lensless_fn="diffuser_images", - lensed_fn="ground_truth_lensed", - downsample=config.simulation.downsample / 4, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - ) - else: - # Use a simulated dataset - if config.trainable_mask.use_mask_in_dataset: - dataset = simulate_dataset(config, diffusercam_psf, mask=mask) - # the mask use will differ from the one in the benchmark dataset - print("Trainable Mask will be used in the test dataset") - benchmark_dataset = None - else: - dataset = simulate_dataset(config, diffusercam_psf, mask=None) + n_param = sum(p.numel() for p in recon.parameters()) + if mask is not None: + n_param += sum(p.numel() for p in mask.parameters()) + log.info(f"Training model with {n_param} parameters") print(f"Setup time : {time.time() - start_time} s") - print(f"PSF shape : {diffusercam_psf.shape}") + print(f"PSF shape : {psf.shape}") trainer = Trainer( - recon, - dataset, - benchmark_dataset, + recon=recon, + train_dataset=train_set, + test_dataset=test_set, mask=mask, batch_size=config.training.batch_size, loss=config.loss, @@ -254,21 +337,13 @@ def train_unrolled( slow_start=config.training.slow_start, skip_NAN=config.training.skip_NAN, algorithm_name=algorithm_name, + metric_for_best_model=config.training.metric_for_best_model, + save_every=config.training.save_every, + gamma=config.display.gamma, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp) - if mask is not None: - print("Saving mask") - print(f"mask shape: {mask._mask.shape}") - torch.save(mask._mask, os.path.join(save, "mask.pt")) - # save as image using plt - import matplotlib.pyplot as plt - - print(f"mask max: {mask._mask.max()}") - print(f"mask min: {mask._mask.min()}") - plt.imsave(os.path.join(save, "mask.png"), mask._mask.detach().cpu().numpy()[0, ...]) - if __name__ == "__main__": train_unrolled()