From b11568f429a3f8f4a6a1b65ca4ddc7c4507b9fcb Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Sat, 18 Nov 2023 08:23:14 +0100 Subject: [PATCH] Fixes and better support for RPi Global shutter, and unrolled training for DigiCam + HITL. (#96) * Add support for training from measured CelebA. * Update example to compare with original. * Update default unrolled config. * Clean up global shutter capture. * Fix nbits for global shutter. * Long exposure comments. * Fix path. * Fix aperture. * Update setup for Python 3.11 * Improve benchmarking. * Use natural sorting. * Save analysis. * Save eval examples. * Set seed. * Fix typo. * Add support to benchmark on DigiCamCelebA dataset. * Better align simulated PSF. * Add support to train adafruit mask. * Fix data type of shape for new PyTorch. * Add sensor. * Add option to set number of channels. * Add more options to analyzing measured dataset. * Fix resizing. * Update configs. * Add option to train mask color filter. * Add and improve hardware utilities. * Add more features to unrolled training. * Update configs. * update changelog. --------- Co-authored-by: Eric Bezzam --- CHANGELOG.rst | 8 + README.rst | 15 +- configs/analyze_dataset.yaml | 5 +- configs/benchmark.yaml | 70 ++- configs/capture_bayer.yaml | 2 +- ...apture_rpi_gs.yaml => capture_rpi_gs.yaml} | 4 +- configs/compute_metrics_from_original.yaml | 4 + configs/digicam.yaml | 13 +- configs/fine-tune_PSF.yaml | 22 +- configs/sim_digicam_psf.yaml | 13 +- configs/train_celeba_digicam.yaml | 72 +++ configs/train_celeba_digicam_mask.yaml | 97 ++++ configs/train_pre-post-processing.yaml | 3 + configs/train_psf_from_scratch.yaml | 17 +- configs/train_unrolledADMM.yaml | 47 +- lensless/eval/benchmark.py | 52 +- lensless/eval/metric.py | 2 + lensless/hardware/aperture.py | 3 +- lensless/hardware/sensor.py | 2 +- lensless/hardware/slm.py | 122 +++-- lensless/hardware/trainable_mask.py | 132 +++++- lensless/hardware/utils.py | 249 +++++++++- lensless/recon/drunet/network_unet.py | 2 + lensless/recon/trainable_recon.py | 109 ++++- lensless/recon/unrolled_admm.py | 22 +- lensless/recon/unrolled_fista.py | 15 +- lensless/recon/utils.py | 303 ++++++++---- lensless/utils/dataset.py | 444 +++++++++++++++--- lensless/utils/io.py | 61 ++- lensless/utils/simulation.py | 32 +- scripts/eval/benchmark_recon.py | 178 +++++-- scripts/eval/compute_metrics_from_original.py | 4 + scripts/hardware/config_digicam.py | 28 +- scripts/hardware/set_digicam_mask_distance.py | 16 + scripts/measure/analyze_image.py | 14 +- scripts/measure/analyze_measured_dataset.py | 30 +- scripts/measure/collect_dataset_on_device.py | 6 +- scripts/measure/on_device_capture.py | 9 + scripts/measure/remote_capture.py | 66 ++- scripts/recon/admm.py | 6 +- scripts/recon/train_unrolled.py | 385 ++++++++++++--- scripts/sim/digicam_psf.py | 72 ++- setup.py | 6 +- 43 files changed, 2312 insertions(+), 450 deletions(-) rename configs/{remote_capture_rpi_gs.yaml => capture_rpi_gs.yaml} (74%) create mode 100644 configs/train_celeba_digicam.yaml create mode 100644 configs/train_celeba_digicam_mask.yaml create mode 100644 scripts/hardware/set_digicam_mask_distance.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e1ceb17d..df62cc13 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,13 @@ Added - Trainable reconstruction can return intermediate outputs (between pre- and post-processing). - Auto-download for DRUNet model. - ``utils.dataset.DiffuserCamMirflickr`` helper class for Mirflickr dataset. 
+- Option to crop section of image for computing loss when training unrolled.
+- Option to learn color filter of RGB mask.
+- Trainable mask for Adafruit LCD.
+- Utility for capturing images.
+- Option to freeze/unfreeze/add pre- and post-processor components during training.
+- Option to skip unrolled training and just use U-Net.
+- Dataset objects for Adafruit LCD: measured CelebA and hardware-in-the-loop.
 
 Changed
 ~~~~~~~
@@ -29,6 +36,7 @@ Bugfix
 - Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS.
 - Fix bad train/test split for DiffuserCamMirflickr in unrolled training.
 - Resize utility.
+- Aperture and index-to-dimension conversion.
 
 
 1.0.5 - (2023-09-05)
diff --git a/README.rst b/README.rst
index 5a88de08..620041b9 100644
--- a/README.rst
+++ b/README.rst
@@ -71,14 +71,21 @@ install the library locally.
 
    # download from GitHub
    git clone git@github.com:LCAV/LenslessPiCam.git
-
-   # install in virtual environment
    cd LenslessPiCam
-   python3 -m venv lensless_env
+
+   # create virtual environment (as of Oct 4 2023, rawpy is not compatible with Python 3.12)
+   # -- using conda
+   conda create -n lensless python=3.11
+   conda activate lensless
+
+   # -- OR venv
+   python3.11 -m venv lensless_env
    source lensless_env/bin/activate
+
+   # install package
    pip install -e .
 
-   # -- extra dependencies for local machine for plotting/reconstruction
+   # extra dependencies for local machine for plotting/reconstruction
    pip install -r recon_requirements.txt
 
    # (optional) try reconstruction on local machine
diff --git a/configs/analyze_dataset.yaml b/configs/analyze_dataset.yaml
index d39ccfec..53d6a130 100644
--- a/configs/analyze_dataset.yaml
+++ b/configs/analyze_dataset.yaml
@@ -3,6 +3,7 @@ hydra:
     chdir: True # change to output folder
 
 dataset_path: null
-desired_range: [190, 254]
-delete_saturated: True
+desired_range: [150, 254]
+delete_bad: False
 n_files: null
+start_idx: null
diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml
index c1169551..47915514 100644
--- a/configs/benchmark.yaml
+++ b/configs/benchmark.yaml
@@ -1,3 +1,4 @@
+# python scripts/eval/benchmark_recon.py
 #Hydra config
 hydra:
   run:
@@ -5,16 +6,25 @@ hydra:
   job:
     chdir: True
 
+
+dataset: DiffuserCam # DiffuserCam, DigiCamCelebA
+seed: 0
+
 device: "cuda"
 # numbers of iterations to benchmark
-n_iter_range: [5, 10, 30, 60, 100, 200, 300]
+n_iter_range: [5, 10, 20, 50, 100, 200, 300]
 # number of files to benchmark
-n_files: 200
+n_files: 200 # null for all files
 #How much should the image be downsampled
-downsample: 8
+downsample: 2
 #algorithm to benchmark
 algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]
 
+# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502
+baseline: "MONAKHOVA 100iter"
+
+save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. 
[1, 5, 10] + # Hyperparameters nesterov: p: 0 @@ -25,4 +35,56 @@ admm: mu1: 1e-6 mu2: 1e-5 mu3: 4e-5 - tau: 0.0001 \ No newline at end of file + tau: 0.0001 + + +# for DigiCamCelebA +files: + test_size: 0.15 + downsample: 1 + celeba_root: /scratch/bezzam + + + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + +# for prepping ground truth data +#for simulated dataset +simulation: + grayscale: False + output_dim: null # should be set if no PSF is used + # random variations + object_height: 0.33 # [m], range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) + random_shift: False + random_vflip: 0.5 + random_hflip: 0.5 + random_rotate: False + # these distance parameters are typically fixed for a given PSF + # for DiffuserCam psf # for tape_rgb psf + # scene2mask: 10e-2 # scene2mask: 40e-2 + # mask2sensor: 9e-3 # mask2sensor: 4e-3 + # -- for CelebA + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + # see waveprop.devices + sensor: "rpi_hq" + snr_db: 10 + # simulate different sensor resolution + # output_dim: [24, 32] # [H, W] or null + # Downsampling for PSF + downsample: 8 + # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability + max_val: 255 diff --git a/configs/capture_bayer.yaml b/configs/capture_bayer.yaml index 155a6c21..dd980eaa 100644 --- a/configs/capture_bayer.yaml +++ b/configs/capture_bayer.yaml @@ -15,7 +15,7 @@ capture: gamma: null # for visualization exp: 0.02 delay: 2 - script: ~/LenslessPiCam/scripts/on_device_capture.py + script: ~/LenslessPiCam/scripts/measure/on_device_capture.py iso: 100 config_pause: 2 sensor_mode: "0" diff --git a/configs/remote_capture_rpi_gs.yaml b/configs/capture_rpi_gs.yaml similarity index 74% rename from configs/remote_capture_rpi_gs.yaml rename to configs/capture_rpi_gs.yaml index 0466e689..455b1ef0 100644 --- a/configs/remote_capture_rpi_gs.yaml +++ b/configs/capture_rpi_gs.yaml @@ -1,4 +1,4 @@ -# python scripts/measure/remote_capture.py -cn remote_capture_rpi_gs +# python scripts/measure/remote_capture.py -cn capture_rpi_gs defaults: - demo - _self_ @@ -21,3 +21,5 @@ capture: gray: False down: null awb_gains: null + nbits_out: 10 + nbits: 10 # 8 or 10 for global shutter diff --git a/configs/compute_metrics_from_original.yaml b/configs/compute_metrics_from_original.yaml index ab41b388..c5ccb4ef 100644 --- a/configs/compute_metrics_from_original.yaml +++ b/configs/compute_metrics_from_original.yaml @@ -1,3 +1,7 @@ +hydra: + job: + chdir: True # change to output folder + files: # Can be downloaded here: https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Freconstruction recon: data/reconstruction/admm_thumbs_up_rgb.npy diff --git a/configs/digicam.yaml b/configs/digicam.yaml index d84b3a89..040af6ab 100644 --- a/configs/digicam.yaml +++ b/configs/digicam.yaml @@ -1,3 +1,7 @@ +# -- setting mask-to-sensor distance and mask pattern +# python scripts/hardware/config_digicam.py +# -- just setting mask-to-sensor distance +# python scripts/hardware/set_digicam_mask_distance.py rpi: username: null hostname: null @@ -17,7 +21,10 @@ center: [0, 0] aperture: - center: [59,76] - 
shape: [19,26] + center: null + shape: null +# aperture: +# center: [59,76] +# shape: [19,26] -z: 4 # mask to sensor distance +z: 4 # mask to sensor distance (mm) diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index af55e03a..c7ff09c9 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -5,14 +5,32 @@ defaults: #Trainable Mask trainable_mask: - mask_type: TrainablePSF #Null or "TrainablePSF" + mask_type: TrainablePSF initial_value: psf mask_lr: 1e-3 L1_strength: 1.0 #False or float #Training training: - save_every: 5 + save_every: 10 + epoch: 50 + crop_preloss: False display: gamma: 2.2 + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: DruNet + depth: 4 + +optimizer: + slow_start: 0.01 + +loss: l2 +lpips: 1.0 diff --git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml index 216455cd..e101767f 100644 --- a/configs/sim_digicam_psf.yaml +++ b/configs/sim_digicam_psf.yaml @@ -3,21 +3,24 @@ hydra: job: chdir: True # change to output folder -use_torch: False +use_torch: True dtype: float32 torch_device: cuda -requires_grad: True +requires_grad: False digicam: slm: adafruit sensor: rpi_hq + downsample: null # null for no downsampling # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf pattern: data/psf/adafruit_random_pattern_20230719.npy ap_center: [59, 76] ap_shape: [19, 26] rotate: -0.8 # rotation in degrees + vertical_shift: -20 # [px] + horizontal_shift: -100 # [px] # optionally provide measured PSF for side-by-side comparison # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf @@ -26,13 +29,13 @@ digicam: sim: - # whether SLM is fliped + # whether SLM is flipped flipud: True # in practice found waveprop=True or False doesn't make difference - waveprop: False + waveprop: True # below are ignored if waveprop=False - scene2mask: 0.03 # [m] + scene2mask: 0.3 # [m] mask2sensor: 0.002 # [m] \ No newline at end of file diff --git a/configs/train_celeba_digicam.yaml b/configs/train_celeba_digicam.yaml new file mode 100644 index 00000000..bf06742a --- /dev/null +++ b/configs/train_celeba_digicam.yaml @@ -0,0 +1,72 @@ +# python scripts/recon/train_unrolled.py -cn train_celeba_digicam +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + downsample: 2 + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + + # ? - 25999 + # vertical_shift: -95 + # horizontal_shift: -30 + # crop: + # vertical: [22, 547] + # horizontal: [260, 690] + + # 0-3000? + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + + celeba_root: /scratch/bezzam + +test_idx: [0, 1, 2, 3, 4] +# test_idx: [1000, 2000, 3000, 4000] + +# for prepping ground truth data +simulation: + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + object_height: 0.33 # [m] + sensor: "rpi_hq" + snr_db: null + downsample: null + random_vflip: False + random_hflip: False + quantize: False + + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. 
Ignore if network is DruNet + nc : null + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + nc : [32, 64, 128, 256] + + +#Training +training: + batch_size: 2 + epoch: 25 + eval_batch_size: 16 + crop_preloss: True + diff --git a/configs/train_celeba_digicam_mask.yaml b/configs/train_celeba_digicam_mask.yaml new file mode 100644 index 00000000..9657e248 --- /dev/null +++ b/configs/train_celeba_digicam_mask.yaml @@ -0,0 +1,97 @@ +# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask +defaults: + - train_celeba_digicam + - _self_ + +# Train Dataset +files: + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + downsample: 2 + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + + celeba_root: /scratch/bezzam + + +# for prepping ground truth data +simulation: + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + object_height: 0.33 # [m] + snr_db: null + downsample: null + random_vflip: False + random_hflip: False + quantize: False + + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + +#Training +training: + batch_size: 2 + epoch: 25 + eval_batch_size: 15 + crop_preloss: True + + save_every: 5 + +#Trainable Mask +trainable_mask: + mask_type: AdafruitLCD #Null or "TrainablePSF" or "AdafruitLCD" + # "random" (with shape of config.files.psf) or path to npy file + grayscale: False + mask_lr: 1e-3 + L1_strength: False + + # for fine-tuning mask values + train_mask_vals: True + train_color_filter: True + + # -- only for AdafruitLCD + # initial_value: data/psf/adafruit_random_pattern_20230719.npy + # ap_center: [59, 76] + # ap_shape: [19, 26] + # rotate: -0.8 # rotation in degrees + # vertical_shift: -20 # [px] + # horizontal_shift: -100 # [px] + + + initial_value: adafruit_random_pattern_20231004_174047.npy + ap_center: [58, 76] + ap_shape: [19, 25] + rotate: 0 # rotation in degrees + # to align with measured PSF (so reconstruction also aligned) + vertical_shift: -80 # [px] + horizontal_shift: -60 # [px] + + slm: adafruit + sensor: rpi_hq + flipud: True + waveprop: False + + # below are ignored if waveprop=False + scene2mask: 0.3 # [m] + mask2sensor: 0.002 # [m] + \ No newline at end of file diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml index f4d6ba98..bf7c3d04 100644 --- a/configs/train_pre-post-processing.yaml +++ b/configs/train_pre-post-processing.yaml @@ -18,6 +18,9 @@ reconstruction: training: epoch: 50 + crop_preloss: False + +optimizer: slow_start: 0.01 loss: l2 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml index b4eef0ed..82586751 100644 --- a/configs/train_psf_from_scratch.yaml +++ b/configs/train_psf_from_scratch.yaml @@ -11,8 +11,21 @@ files: #Trainable Mask trainable_mask: - mask_type: TrainablePSF #Null or "TrainablePSF" - initial_value: "random" + mask_type: 
TrainablePSF
+  initial_value: random
 
 simulation:
   grayscale: False
+  flip: False
+  scene2mask: 40e-2
+  mask2sensor: 2e-3
+  sensor: "rpi_hq"
+  downsample: 16
+  object_height: 0.30
+
+training:
+  crop_preloss: False # crop region for computing loss
+  batch_size: 8
+  epoch: 25
+  eval_batch_size: 16
+  save_every: 5
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 3871be0d..f7602f01 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -3,22 +3,37 @@ hydra:
   job:
     chdir: True # change to output folder
 
+
+seed: 0
+start_delay: null
+
 # Dataset
 files:
-  dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
+  dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated: "mnist", "fashion_mnist", "cifar10", "CelebA". Measured: "DiffuserCam"
   celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
-  psf: data/psf.tiff
+  psf: data/psf/diffusercam_psf.tiff
+  diffusercam_psf: True
   n_files: null # null to use all for both train/test
   downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
+  test_size: 0.15
+
+  vertical_shift: null
+  horizontal_shift: null
+  crop: null
+    # vertical: null
+    # horizontal: null
 
 torch: True
 torch_device: 'cuda'
 
+measure: null # if measuring data on-the-fly
+
+# see some outputs of classical ADMM before training
+test_idx: [0, 1, 2, 3, 4]
+
+# test set example to visualize at the end of every epoch
+eval_disp_idx: [0, 1, 2, 3, 4]
+
 display:
-  # How many iterations to wait for intermediate plot.
-  # Set to negative value for no intermediate plots.
-  disp: 500
   # Whether to plot results.
   plot: True
   # Gamma factor for plotting.
@@ -30,6 +45,7 @@ save: True
 reconstruction:
   # Method: unrolled_admm, unrolled_fista
   method: unrolled_admm
+  skip_unrolled: False
 
   # Hyperparameters for each method
   unrolled_fista: # for unrolled_fista
@@ -48,13 +64,22 @@ reconstruction:
   pre_process:
     network : null # UnetRes or DruNet or null
     depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
+    nc: null
+    delay: null # add component after this many epochs
+    freeze: null
+    unfreeze: null
   post_process:
     network : null # UnetRes or DruNet or null
     depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
+    nc: null
+    delay: null # add component after this many epochs
+    freeze: null
+    unfreeze: null
+    train_last_layer: False
 
 #Trainable Mask
 trainable_mask:
-  mask_type: Null #Null or "TrainablePSF"
+  mask_type: null #Null or "TrainablePSF" or "AdafruitLCD"
   # "random" (with shape of config.files.psf) or "psf" (using config.files.psf)
   initial_value: psf
   grayscale: False
@@ -66,6 +91,7 @@ target: "object_plane" # "original" or "object_plane" or "label"
 #for simulated dataset
 simulation:
   grayscale: False
+  output_dim: null # should be set if no PSF is used
   # random variations
   object_height: 0.04 # range for random height or scalar
   flip: True # change the orientation of the object (from vertical to horizontal)
@@ -93,16 +119,23 @@ simulation:
 
 training:
   batch_size: 8
   epoch: 50
+  eval_batch_size: 10
   metric_for_best_model: null # e.g. 
LPIPS_Vgg, null does test loss
   save_every: null
   #In case of unstable training
   skip_NAN: True
-  slow_start: False #float how much to reduce lr for first epoch
+  clip_grad: 1.0
+  crop_preloss: True # crop region for computing loss
 
 optimizer:
   type: Adam
   lr: 1e-4
+  slow_start: False # float, how much to reduce lr for first epoch
+  # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
+  step: False # int, period of learning rate decay. False to not apply
+  gamma: 0.1 # float, factor for learning rate decay
+
 
 loss: 'l2'
 # set lpips to false to deactivate. Otherwise, give the weight for the loss (the main loss l2/l1 always having a weight of 1)
diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py
index 885766f3..8abd254e 100644
--- a/lensless/eval/benchmark.py
+++ b/lensless/eval/benchmark.py
@@ -8,7 +8,10 @@
 
 from lensless.utils.dataset import DiffuserCamTestDataset
+from lensless.utils.io import save_image
 from tqdm import tqdm
+import os
+import numpy as np
 
 try:
     import torch
@@ -22,7 +25,16 @@
 )
 
 
-def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
+def benchmark(
+    model,
+    dataset,
+    batchsize=1,
+    metrics=None,
+    crop=None,
+    save_idx=None,
+    output_dir=None,
+    **kwargs,
+):
     """
     Compute multiple metrics for a reconstruction algorithm.
 
@@ -36,6 +48,12 @@
         Batch size for processing. For maximum compatibility use 1 (batch sizes above 1 are not supported by all algorithms), by default 1
     metrics : dict, optional
         Dictionary of metrics to compute. If None, MSE, MAE, SSIM, LPIPS and PSNR are computed.
+    save_idx : list of int, optional
+        List of indices of predictions to save, by default None (do not save any).
+    output_dir : str, optional
+        Directory to save the predictions; by default, the working directory is used if ``save_idx`` is provided.
+    crop : dict, optional
+        Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]), by default None (no crop). 
Returns
     -------
     Dict of average metrics
     """
     assert isinstance(model._psf, torch.Tensor), "model needs to be constructed with torch support"
     device = model._psf.device
 
+    if output_dir is None:
+        output_dir = os.getcwd()
+    else:
+        output_dir = str(output_dir)
+        if not os.path.exists(output_dir):
+            os.mkdir(output_dir)
+
     if metrics is None:
         metrics = {
             "MSE": MSELoss().to(device),
@@ -64,6 +89,7 @@
     # loop over batches
     dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device != "cpu"))
     model.reset()
+    idx = 0
     for lensless, lensed in tqdm(dataloader):
         lensless = lensless.to(device)
         lensed = lensed.to(device)
@@ -80,6 +106,29 @@
             # Convert to [N*D, C, H, W] for torchmetrics
             prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
             lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)
+
+            if crop is not None:
+                prediction = prediction[
+                    ...,
+                    crop["vertical"][0] : crop["vertical"][1],
+                    crop["horizontal"][0] : crop["horizontal"][1],
+                ]
+                lensed = lensed[
+                    ...,
+                    crop["vertical"][0] : crop["vertical"][1],
+                    crop["horizontal"][0] : crop["horizontal"][1],
+                ]
+
+            if save_idx is not None:
+                batch_idx = np.arange(idx, idx + batchsize)
+
+                for i, _idx in enumerate(batch_idx):  # _idx so the file counter `idx` is not clobbered
+                    if _idx in save_idx:
+                        prediction_np = prediction.cpu().numpy()[i].squeeze()
+                        # switch to [H, W, C]
+                        prediction_np = np.moveaxis(prediction_np, 0, -1)
+                        save_image(prediction_np, fp=os.path.join(output_dir, f"{_idx}.png"))
+
             # normalization
             prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
             if torch.all(prediction_max != 0):
@@ -109,6 +158,7 @@
             metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item()
 
         model.reset()
+        idx += batchsize
 
     # average metrics
     for metric in metrics:
diff --git a/lensless/eval/metric.py b/lensless/eval/metric.py
index 0e1780e8..ae11e0af 100644
--- a/lensless/eval/metric.py
+++ b/lensless/eval/metric.py
@@ -301,6 +301,8 @@ def extract(
     estimate = rotate(
         estimate[vertical_crop[0] : vertical_crop[1], horizontal_crop[0] : horizontal_crop[1]],
         angle=rotation,
+        mode="nearest",
+        reshape=False,
     )
     estimate /= estimate.max()
     estimate = np.clip(estimate, 0, 1)
diff --git a/lensless/hardware/aperture.py b/lensless/hardware/aperture.py
index 37e8e37b..c2e0e62b 100644
--- a/lensless/hardware/aperture.py
+++ b/lensless/hardware/aperture.py
@@ -183,6 +183,7 @@ def rect_aperture(slm_shape, pixel_pitch, apert_dim, center=None):
     apert_dim = np.array(apert_dim)
     top_left = center - apert_dim / 2
     bottom_right = top_left + apert_dim
+
     if (
         top_left[0] < 0
         or top_left[1] < 0
@@ -329,7 +330,7 @@ def _m_to_cell_idx(val, cell_m):
 
     :return: The cell index. 
:rtype: int """ - return int(val / cell_m) + return int(np.round(val / cell_m)) def prepare_index_vals(key, pixel_pitch): diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index c842dc2d..0785204e 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -88,7 +88,7 @@ class SensorParam: SensorParam.RESOLUTION: np.array([1088, 1456]), SensorParam.DIAGONAL: 6.3e-3, SensorParam.COLOR: True, - SensorParam.BIT_DEPTH: [8, 12], + SensorParam.BIT_DEPTH: [8, 10], SensorParam.MAX_EXPOSURE: 15534385e-6, SensorParam.MIN_EXPOSURE: 29e-6, }, diff --git a/lensless/hardware/slm.py b/lensless/hardware/slm.py index 572ae4a7..7e9c9f0e 100644 --- a/lensless/hardware/slm.py +++ b/lensless/hardware/slm.py @@ -9,12 +9,12 @@ import os import numpy as np from lensless.hardware.utils import check_username_hostname -from lensless.utils.io import get_dtype, get_ctypes +from lensless.utils.io import get_ctypes from slm_controller.hardware import SLMParam, slm_devices from waveprop.spherical import spherical_prop from waveprop.color import ColorSystem from waveprop.rs import angular_spectrum -from waveprop.slm import get_centers, get_color_filter +from waveprop.slm import get_centers from waveprop.devices import SLMParam as SLMParam_wp from scipy.ndimage import rotate as rotate_func @@ -35,7 +35,7 @@ } -def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): +def set_programmable_mask(pattern, device, rpi_username, rpi_hostname, verbose=False): """ Set LCD pattern on Raspberry Pi. @@ -79,9 +79,12 @@ def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): # copy pattern to Raspberry Pi remote_path = f"~/{pattern_fn}" - print(f"PUTTING {local_path} to {remote_path}") + if verbose: + print(f"PUTTING {local_path} to {remote_path}") - os.system('scp %s "%s@%s:%s" ' % (local_path, rpi_username, rpi_hostname, remote_path)) + os.system( + 'scp %s "%s@%s:%s" >/dev/null 2>&1' % (local_path, rpi_username, rpi_hostname, remote_path) + ) # # -- not sure why this doesn't work... permission denied # sftp = client.open_sftp() # sftp.put(local_path, remote_path, confirm=True) @@ -89,9 +92,11 @@ def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): # run script on Raspberry Pi to set mask pattern command = f"{rpi_python} {script} --file_path {remote_path}" - print(f"COMMAND : {command}") + if verbose: + print(f"COMMAND : {command}") _stdin, _stdout, _stderr = client.exec_command(command) - print(_stdout.read().decode()) + if verbose: + print(_stdout.read().decode()) client.close() os.remove(local_path) @@ -104,6 +109,7 @@ def get_programmable_mask( rotate=None, flipud=False, nbits=8, + color_filter=None, ): """ Get mask as a numpy or torch array. Return same type. 
@@ -136,22 +142,21 @@ def get_programmable_mask( pixel_pitch = slm_param[SLMParam_wp.PITCH] centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch) - if SLMParam_wp.COLOR_FILTER in slm_param.keys(): + if color_filter is None and SLMParam_wp.COLOR_FILTER in slm_param.keys(): color_filter = slm_param[SLMParam_wp.COLOR_FILTER] - if flipud: - color_filter = np.flipud(color_filter) - - cf = get_color_filter( - slm_dim=n_active_slm_pixels, - color_filter=color_filter, - shift=0, - flat=True, - ) + if isinstance(vals, torch.Tensor): + color_filter = torch.tensor(color_filter).to(vals) - else: + if color_filter is not None: - # monochrome - cf = None + if isinstance(color_filter, np.ndarray): + if flipud: + color_filter = np.flipud(color_filter) + elif isinstance(color_filter, torch.Tensor): + if flipud: + color_filter = torch.flip(color_filter, dims=(0,)) + else: + raise ValueError("color_filter must be numpy array or torch tensor") d1 = sensor.pitch _height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int) @@ -171,29 +176,26 @@ def get_programmable_mask( _center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int), ) - if cf is not None: - _rect = np.tile(cf[i][:, np.newaxis, np.newaxis], (1, _height_pixel, _width_pixel)) - else: - _rect = np.ones((1, _height_pixel, _width_pixel)) - - if use_torch: - _rect = torch.tensor(_rect).to(slm_vals_flat) + color_filter_idx = i // n_active_slm_pixels[1] % n_color_filter + mask_val = slm_vals_flat[i] * color_filter[color_filter_idx][0] + if isinstance(mask_val, np.ndarray): + mask_val = mask_val[:, np.newaxis, np.newaxis] + elif isinstance(mask_val, torch.Tensor): + mask_val = mask_val.unsqueeze(-1).unsqueeze(-1) mask[ :, _center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel, _center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel, - ] = ( - slm_vals_flat[i] * _rect - ) + ] = mask_val - # quantize mask - if use_torch: - mask = mask / torch.max(mask) - mask = torch.round(mask * (2**nbits - 1)) / (2**nbits - 1) - else: - mask = mask / np.max(mask) - mask = np.round(mask * (2**nbits - 1)) / (2**nbits - 1) + # # quantize mask + # if use_torch: + # mask = mask / torch.max(mask) + # mask = torch.round(mask * (2**nbits - 1)) / (2**nbits - 1) + # else: + # mask = mask / np.max(mask) + # mask = np.round(mask * (2**nbits - 1)) / (2**nbits - 1) # rotate if rotate is not None: @@ -205,6 +207,46 @@ def get_programmable_mask( return mask +def adafruit_sub2full( + subpattern, + center, +): + sub_shape = subpattern.shape + controllable_shape = (3, sub_shape[0] // 3, sub_shape[1]) + subpattern_rgb = subpattern.reshape(controllable_shape, order="F") + subpattern_rgb *= 255 + + # pad to full pattern + pattern = np.zeros((3, 128, 160), dtype=np.uint8) + topleft = [center[0] - controllable_shape[1] // 2, center[1] - controllable_shape[2] // 2] + pattern[ + :, + topleft[0] : topleft[0] + controllable_shape[1], + topleft[1] : topleft[1] + controllable_shape[2], + ] = subpattern_rgb.astype(np.uint8) + return pattern + + +def full2subpattern( + pattern, + shape, + center, + slm=None, +): + shape = np.array(shape) + center = np.array(center) + + # extract region + idx_1 = center[0] - shape[0] // 2 + idx_2 = center[1] - shape[1] // 2 + subpattern = pattern[:, idx_1 : idx_1 + shape[0], idx_2 : idx_2 + shape[1]] + subpattern = subpattern / 255.0 + if slm == "adafruit": + # flatten color channel along rows + subpattern = subpattern.reshape((-1, subpattern.shape[-1]), order="F") + return subpattern + + def 
get_intensity_psf(
     mask,
     waveprop=False,
@@ -239,8 +281,8 @@ def get_intensity_psf(
     is_torch = False
     device = None
-    if torch_available:
-        is_torch = isinstance(mask, torch.Tensor)
+    if torch_available and isinstance(mask, torch.Tensor):
+        is_torch = True
         device = mask.device
 
     dtype = mask.dtype
@@ -268,7 +310,7 @@
             wv=color_system.wv,
             dz=scene2mask,
             return_psf=True,
-            is_torch=True,
+            is_torch=is_torch,
             device=device,
             dtype=dtype,
         )
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index 9bc70bc8..a20d502d 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -9,6 +9,9 @@
 import abc
 import torch
 from lensless.utils.image import is_grayscale
+from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
+from lensless.hardware.sensor import VirtualSensor
+from waveprop.devices import slm_dict
 
 
 class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
@@ -38,6 +41,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
         super().__init__()
         self._mask = torch.nn.Parameter(initial_mask)
         self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs)
+        self.train_mask_vals = True
         self._counter = 0
 
     @abc.abstractmethod
@@ -53,12 +57,16 @@ def get_psf(self):
         raise NotImplementedError
 
     def update_mask(self):
-        """Update the mask parameters. Acoording to externaly updated gradiants."""
+        """Update the mask parameters, according to externally updated gradients."""
         self._optimizer.step()
         self._optimizer.zero_grad(set_to_none=True)
         self.project()
         self._counter += 1
 
+    def get_vals(self):
+        """Get the mask parameters."""
+        return self._mask
+
     @abc.abstractmethod
     def project(self):
         """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1])."""
@@ -100,3 +108,125 @@
 
     def project(self):
         self._mask.data = torch.clamp(self._mask, 0, 1)
+
+
+class AdafruitLCD(TrainableMask):
+    def __init__(
+        self,
+        initial_vals,
+        sensor,
+        slm,
+        optimizer="Adam",
+        lr=1e-3,
+        train_mask_vals=True,
+        color_filter=None,
+        rotate=None,
+        flipud=False,
+        use_waveprop=None,
+        vertical_shift=None,
+        horizontal_shift=None,
+        scene2mask=None,
+        mask2sensor=None,
+        downsample=None,
+        min_val=0,
+        **kwargs
+    ):
+        """
+        Parameters
+        ----------
+        initial_vals : :py:class:`~torch.Tensor`
+            Initial mask parameters.
+        sensor : str
+            Name of the (virtual) sensor.
+        slm : str
+            Name of the SLM; its parameters are looked up in ``waveprop.devices.slm_dict``. 
+ rotate : float, optional + Rotation angle in degrees, by default None + flipud : bool, optional + Whether to flip the mask vertically, by default False + """ + + super().__init__(initial_vals, **kwargs) + self.train_mask_vals = train_mask_vals + if color_filter is not None: + self.color_filter = torch.nn.Parameter(color_filter) + if train_mask_vals: + param = [self._mask, self.color_filter] + else: + del self._mask + self._mask = initial_vals + param = [self.color_filter] + self._optimizer = getattr(torch.optim, optimizer)(param, lr=lr) + else: + self.color_filter = None + assert ( + train_mask_vals + ), "If color filter is not trainable, mask values must be trainable" + + self.slm_param = slm_dict[slm] + self.device = slm + self.sensor = VirtualSensor.from_name(sensor, downsample=downsample) + self.rotate = rotate + self.flipud = flipud + self.use_waveprop = use_waveprop + self.scene2mask = scene2mask + self.mask2sensor = mask2sensor + self.vertical_shift = vertical_shift + self.horizontal_shift = horizontal_shift + self.min_val = min_val + if downsample is not None and vertical_shift is not None: + self.vertical_shift = vertical_shift // downsample + if downsample is not None and horizontal_shift is not None: + self.horizontal_shift = horizontal_shift // downsample + if self.use_waveprop: + assert self.scene2mask is not None + assert self.mask2sensor is not None + + def get_psf(self): + + mask = get_programmable_mask( + vals=self._mask, + sensor=self.sensor, + slm_param=self.slm_param, + rotate=self.rotate, + flipud=self.flipud, + color_filter=self.color_filter, + ) + + if self.vertical_shift is not None: + mask = torch.roll(mask, self.vertical_shift, dims=1) + + if self.horizontal_shift is not None: + mask = torch.roll(mask, self.horizontal_shift, dims=2) + + psf_in = get_intensity_psf( + mask=mask, + sensor=self.sensor, + waveprop=self.use_waveprop, + scene2mask=self.scene2mask, + mask2sensor=self.mask2sensor, + ) + + # add first dimension (depth) + psf_in = psf_in.unsqueeze(0) + + # move channels to last dimension + psf_in = psf_in.permute(0, 2, 3, 1) + + # flip mask + psf_in = torch.flip(psf_in, dims=[-3, -2]) + + # normalize + psf_in = psf_in / psf_in.norm() + + return psf_in + + def project(self): + if self.train_mask_vals: + self._mask.data = torch.clamp(self._mask, self.min_val, 1) + if self.color_filter is not None: + self.color_filter.data = torch.clamp(self.color_filter, 0, 1) + # normalize each row to 1 + self.color_filter.data = self.color_filter / self.color_filter.sum( + dim=[1, 2] + ).unsqueeze(-1).unsqueeze(-1) diff --git a/lensless/hardware/utils.py b/lensless/hardware/utils.py index a52668c5..adadbc1c 100644 --- a/lensless/hardware/utils.py +++ b/lensless/hardware/utils.py @@ -3,9 +3,235 @@ import socket import subprocess import time - import paramiko +from pprint import pprint from paramiko.ssh_exception import AuthenticationException, BadHostKeyException, SSHException +from lensless.hardware.sensor import SensorOptions +import cv2 +from lensless.utils.image import print_image_info +from lensless.utils.io import load_image + + +import logging + +logging.getLogger("paramiko").setLevel(logging.WARNING) + + +def capture( + rpi_username, + rpi_hostname, + sensor, + bayer, + exp, + fn="capture", + iso=100, + config_pause=2, + sensor_mode="0", + nbits_out=12, + legacy=True, + rgb=False, + gray=False, + nbits=12, + down=None, + awb_gains=None, + rpi_python="~/LenslessPiCam/lensless_env/bin/python", + capture_script="~/LenslessPiCam/scripts/measure/on_device_capture.py", + 
verbose=False,
+    output_path=None,
+    **kwargs,
+):
+    """
+    Capture image.
+
+    Parameters
+    ----------
+    fn : str
+        File name for the captured image.
+    rpi_username : str
+        Username of Raspberry Pi.
+    rpi_hostname : str
+        Hostname of Raspberry Pi.
+    sensor : str
+        Sensor name.
+    bayer : bool
+        Whether to return bayer data (larger file size to transfer back).
+    exp : int
+        Exposure time in microseconds.
+    iso : int
+        ISO.
+    config_pause : int
+        Time to pause after configuring camera.
+    sensor_mode : str
+        Sensor mode.
+    nbits_out : int
+        Number of bits of output image.
+    legacy : bool
+        Whether to use the legacy capture software on the Raspberry Pi.
+    rgb : bool
+        Whether to capture RGB image.
+    gray : bool
+        Whether to capture grayscale image.
+    nbits : int
+        Number of bits of image.
+    down : int
+        Downsample factor.
+    awb_gains : list
+        AWB gains (red, blue).
+    rpi_python : str
+        Path to Python on Raspberry Pi.
+    capture_script : str
+        Path to capture script on Raspberry Pi.
+    output_path : str
+        Path to save image.
+    verbose : bool
+        Whether to print extra info.
+
+    """
+
+    # check_username_hostname(rpi_username, rpi_hostname)
+    assert sensor in SensorOptions.values(), f"Sensor must be one of {SensorOptions.values()}"
+
+    # form command
+    remote_fn = "remote_capture"
+    pic_command = (
+        f"{rpi_python} {capture_script} sensor={sensor} bayer={bayer} fn={remote_fn} exp={exp} iso={iso} "
+        f"config_pause={config_pause} sensor_mode={sensor_mode} nbits_out={nbits_out} "
+        f"legacy={legacy} rgb={rgb} gray={gray} "
+    )
+    if nbits > 8:
+        pic_command += " sixteen=True"
+    if down:
+        pic_command += f" down={down}"
+    if awb_gains:
+        pic_command += f" awb_gains=[{awb_gains[0]},{awb_gains[1]}]"
+
+    if verbose:
+        print(f"COMMAND : {pic_command}")
+
+    # take picture
+    ssh = subprocess.Popen(
+        ["ssh", "%s@%s" % (rpi_username, rpi_hostname), pic_command],
+        shell=False,
+        # stdout=DEVNULL,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    result = ssh.stdout.readlines()
+    error = ssh.stderr.readlines()
+
+    if error != [] and legacy:  # new camera software seems to return error even if it works
+        print("ERROR: %s" % error)
+        return
+    if result == []:
+        error = ssh.stderr.readlines()
+        print("ERROR: %s" % error)
+        return
+    else:
+        result = [res.decode("UTF-8") for res in result]
+        result = [res for res in result if len(res) > 3]
+        result_dict = dict()
+        for res in result:
+            _key = res.split(":")[0].strip()
+            _val = "".join(res.split(":")[1:]).strip()
+            result_dict[_key] = _val
+        # result_dict = dict(map(lambda s: map(str.strip, s.split(":")), result))
+        if verbose:
+            print("COMMAND OUTPUT : ")
+            pprint(result_dict)
+
+    # copy over file
+    if (
+        "RPi distribution" in result_dict.keys()
+        and "bullseye" in result_dict["RPi distribution"]
+        and not legacy
+    ):
+
+        if bayer:
+
+            # copy over DNG file
+            remotefile = f"~/{remote_fn}.dng"
+            localfile = f"{fn}.dng"
+            if output_path is not None:
+                localfile = os.path.join(output_path, localfile)
+            if verbose:
+                print(f"\nCopying over picture as {localfile}...")
+            os.system(
+                'scp "%s@%s:%s" %s >/dev/null 2>&1'
+                % (rpi_username, rpi_hostname, remotefile, localfile)
+            )
+
+            img = load_image(localfile, verbose=True, bayer=bayer, nbits_out=nbits_out)
+
+            # print image properties
+            print_image_info(img)
+
+            # save as PNG
+            png_out = f"{fn}.png"
+            print(f"Saving RGB file as: {png_out}")
+            cv2.imwrite(png_out, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+        else:
+
+            remotefile = f"~/{remote_fn}.png"
+            localfile = f"{fn}.png"
+            if output_path is not None:
+                localfile = 
os.path.join(output_path, localfile) + if verbose: + print(f"\nCopying over picture as {localfile}...") + os.system( + 'scp "%s@%s:%s" %s >/dev/null 2>&1' + % (rpi_username, rpi_hostname, remotefile, localfile) + ) + + img = load_image(localfile, verbose=True) + + # legacy software running on RPi + else: + # copy over file + # more pythonic? https://stackoverflow.com/questions/250283/how-to-scp-in-python + remotefile = f"~/{remote_fn}.png" + localfile = f"{fn}.png" + if output_path is not None: + localfile = os.path.join(output_path, localfile) + if verbose: + print(f"\nCopying over picture as {localfile}...") + os.system( + 'scp "%s@%s:%s" %s >/dev/null 2>&1' + % (rpi_username, rpi_hostname, remotefile, localfile) + ) + + if rgb or gray: + img = load_image(localfile, verbose=verbose) + + else: + + if not bayer: + # red_gain = config.camera.red_gain + # blue_gain = config.camera.blue_gain + red_gain = awb_gains[0] + blue_gain = awb_gains[1] + else: + red_gain = None + blue_gain = None + + # load image + if verbose: + print("\nLoading picture...") + + img = load_image( + localfile, + verbose=True, + bayer=bayer, + blue_gain=blue_gain, + red_gain=red_gain, + nbits_out=nbits_out, + ) + + # write RGB data + if not bayer: + cv2.imwrite(localfile, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + return localfile, img def display( @@ -18,6 +244,7 @@ def display( pad=0, vshift=0, hshift=0, + verbose=False, **kwargs, ): """ @@ -43,16 +270,20 @@ def display( remote_tmp_file = "~/tmp_display.png" display_path = "~/LenslessPiCam_display/test.png" - os.system('scp %s "%s@%s:%s" ' % (fp, rpi_username, rpi_hostname, remote_tmp_file)) + os.system( + 'scp %s "%s@%s:%s" >/dev/null 2>&1' % (fp, rpi_username, rpi_hostname, remote_tmp_file) + ) # run script on Raspberry Pi to prepare image to display prep_command = f"{rpi_python} {script} --fp {remote_tmp_file} \ --pad {pad} --vshift {vshift} --hshift {hshift} --screen_res {screen_res[0]} {screen_res[1]} \ --brightness {brightness} --rot90 {rot90} --output_path {display_path} " - # print(f"COMMAND : {prep_command}") + if verbose: + print(f"COMMAND : {prep_command}") subprocess.Popen( ["ssh", "%s@%s" % (rpi_username, rpi_hostname), prep_command], shell=False, + # stdout=DEVNULL ) @@ -65,6 +296,7 @@ def check_username_hostname(username, hostname, timeout=10): client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) try: + # with suppress_stdout(): client.connect(hostname, username=username, timeout=timeout) except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: raise ValueError(f"Could not connect to {username}@{hostname}\n{e}") @@ -98,7 +330,9 @@ def get_distro(): return f"{RELEASE_DATA['NAME']} {RELEASE_DATA['VERSION']}" -def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): +def set_mask_sensor_distance( + distance, rpi_username, rpi_hostname, motor=1, max_distance=16, timeout=5 +): """ Set the distance between the mask and sensor. @@ -115,13 +349,10 @@ def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): Hostname of Raspberry Pi. 
""" - MAX_DISTANCE = 16 # mm - timeout = 5 - client = check_username_hostname(rpi_username, rpi_hostname) assert motor in [0, 1] assert distance >= 0, "Distance must be non-negative" - assert distance < MAX_DISTANCE, f"Distance must be less than {MAX_DISTANCE} mm" + assert distance <= max_distance, f"Distance must be less than {max_distance} mm" # assumes that `StepperDriver` is in home directory rpi_python = "python3" @@ -130,7 +361,7 @@ def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): # reset to zero print("Resetting to zero distance...") try: - command = f"{rpi_python} {script} {motor} REV {MAX_DISTANCE * 1000}" + command = f"{rpi_python} {script} {motor} REV {max_distance * 1000}" _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) except socket.timeout: # socket.timeout pass diff --git a/lensless/recon/drunet/network_unet.py b/lensless/recon/drunet/network_unet.py index 6f9c390e..8d51d00b 100644 --- a/lensless/recon/drunet/network_unet.py +++ b/lensless/recon/drunet/network_unet.py @@ -112,6 +112,8 @@ def __init__( ): super(UNetRes, self).__init__() + assert len(nc) == 4, "nc's length should be 4." + self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C") # downsample diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 82fd883d..52343c19 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -49,6 +49,7 @@ def __init__( n_iter=1, pre_process=None, post_process=None, + skip_unrolled=False, **kwargs, ): """ @@ -79,19 +80,13 @@ def __init__( psf, dtype=dtype, n_iter=n_iter, **kwargs ) - # pre processing - ( - self.pre_process, - self.pre_process_model, - self.pre_process_param, - ) = self._prepare_process_block(pre_process) - - # post processing - ( - self.post_process, - self.post_process_model, - self.post_process_param, - ) = self._prepare_process_block(post_process) + self.set_pre_process(pre_process) + self.set_post_process(post_process) + self.skip_unrolled = skip_unrolled + if self.skip_unrolled: + assert ( + post_process is not None or pre_process is not None + ), "If skip_unrolled is True, pre_process or post_process must be defined." def _prepare_process_block(self, process): """ @@ -115,6 +110,7 @@ def _prepare_process_block(self, process): else: process_function = None process_model = None + if process_function is not None: process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device)) else: @@ -122,6 +118,60 @@ def _prepare_process_block(self, process): return process_function, process_model, process_param + def set_pre_process(self, pre_process): + ( + self.pre_process, + self.pre_process_model, + self.pre_process_param, + ) = self._prepare_process_block(pre_process) + + def set_post_process(self, post_process): + ( + self.post_process, + self.post_process_model, + self.post_process_param, + ) = self._prepare_process_block(post_process) + + def freeze_pre_process(self): + """ + Method for freezing the pre process block. + """ + if self.pre_process_param is not None: + self.pre_process_param.requires_grad = False + if self.pre_process_model is not None: + for param in self.pre_process_model.parameters(): + param.requires_grad = False + + def freeze_post_process(self): + """ + Method for freezing the post process block. 
+ """ + if self.post_process_param is not None: + self.post_process_param.requires_grad = False + if self.post_process_model is not None: + for param in self.post_process_model.parameters(): + param.requires_grad = False + + def unfreeze_pre_process(self): + """ + Method for unfreezing the pre process block. + """ + if self.pre_process_param is not None: + self.pre_process_param.requires_grad = True + if self.pre_process_model is not None: + for param in self.pre_process_model.parameters(): + param.requires_grad = True + + def unfreeze_post_process(self): + """ + Method for unfreezing the post process block. + """ + if self.post_process_param is not None: + self.post_process_param.requires_grad = True + if self.post_process_model is not None: + for param in self.post_process_model.parameters(): + param.requires_grad = True + def batch_call(self, batch): """ Method for performing iterative reconstruction on a batch of images. @@ -147,10 +197,14 @@ def batch_call(self, batch): self.reset(batch_size=batch_size) - for i in range(self._n_iter): - self._update(i) + if not self.skip_unrolled: + for i in range(self._n_iter): + self._update(i) + image_est = self._form_image() + + else: + image_est = self._data - image_est = self._form_image() if self.post_process is not None: image_est = self.post_process(image_est, self.post_process_param) return image_est @@ -207,16 +261,19 @@ def apply( if output_intermediate: pre_processed_image = self._data[0, ...].clone() - im = super(TrainableReconstructionAlgorithm, self).apply( - n_iter=self._n_iter, - disp_iter=disp_iter, - plot_pause=plot_pause, - plot=plot, - save=save, - gamma=gamma, - ax=ax, - reset=reset, - ) + if not self.skip_unrolled: + im = super(TrainableReconstructionAlgorithm, self).apply( + n_iter=self._n_iter, + disp_iter=disp_iter, + plot_pause=plot_pause, + plot=plot, + save=save, + gamma=gamma, + ax=ax, + reset=reset, + ) + else: + im = self._data # remove plot if returned if plot: diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py index 43b6b956..d428ac17 100644 --- a/lensless/recon/unrolled_admm.py +++ b/lensless/recon/unrolled_admm.py @@ -78,10 +78,24 @@ def __init__( psf, n_iter=n_iter, dtype=dtype, pad=pad, norm=norm, reset=False, **kwargs ) - self._mu1_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu1) - self._mu2_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu2) - self._mu3_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu3) - self._tau_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * tau) + if not self.skip_unrolled: + self._mu1_p = torch.nn.Parameter( + torch.ones(self._n_iter, device=self._psf.device) * mu1 + ) + self._mu2_p = torch.nn.Parameter( + torch.ones(self._n_iter, device=self._psf.device) * mu2 + ) + self._mu3_p = torch.nn.Parameter( + torch.ones(self._n_iter, device=self._psf.device) * mu3 + ) + self._tau_p = torch.nn.Parameter( + torch.ones(self._n_iter, device=self._psf.device) * tau + ) + else: + self._mu1_p = torch.ones(self._n_iter, device=self._psf.device) * mu1 + self._mu2_p = torch.ones(self._n_iter, device=self._psf.device) * mu2 + self._mu3_p = torch.ones(self._n_iter, device=self._psf.device) * mu3 + self._tau_p = torch.ones(self._n_iter, device=self._psf.device) * tau # set prior if psi is None: diff --git a/lensless/recon/unrolled_fista.py b/lensless/recon/unrolled_fista.py index 1361cda1..7d083424 100644 --- a/lensless/recon/unrolled_fista.py +++ 
b/lensless/recon/unrolled_fista.py @@ -61,17 +61,22 @@ def __init__(self, psf, n_iter=5, dtype=None, proj=non_neg, learn_tk=True, tk=1, # learnable step size initialize as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) - self._alpha_p = torch.nn.Parameter( - torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device) - * (1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) - ) + if not self.skip_unrolled: + self._alpha_p = torch.nn.Parameter( + torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device) + * (1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) + ) + else: + self._alpha_p = torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device) * ( + 1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values + ) # set tk, can be learnt if learn_tk=True self._tk_p = [tk] for i in range(self._n_iter): self._tk_p.append((1 + np.sqrt(1 + 4 * self._tk_p[i] ** 2)) / 2) self._tk_p = torch.Tensor(self._tk_p) - if learn_tk: + if learn_tk and not self.skip_unrolled: self._tk_p = torch.nn.Parameter(self._tk_p).to(psf.device) self.reset() diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 2ca758c6..53f23a1b 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -193,7 +193,7 @@ def measure_gradient(model): return total_norm -def create_process_network(network, depth, device="cpu"): +def create_process_network(network, depth, device="cpu", nc=None): """ Helper function to create a process network. @@ -211,6 +211,12 @@ def create_process_network(network, depth, device="cpu"): :py:class:`torch.nn.Module` New process network. Already trained for Drunet. """ + + if nc is None: + nc = [64, 128, 256, 512] + else: + assert len(nc) == 4 + if network == "DruNet": from lensless.recon.utils import load_drunet @@ -223,7 +229,7 @@ def create_process_network(network, depth, device="cpu"): process = UNetRes( in_nc=n_channels + 1, out_nc=n_channels, - nc=[64, 128, 256, 512], + nc=nc, nb=depth, act_mode="R", downsample_mode="strideconv", @@ -250,15 +256,24 @@ def __init__( loss="l2", lpips=None, l1_mask=None, - optimizer="Adam", - optimizer_lr=1e-6, - slow_start=None, + optimizer=None, skip_NAN=False, algorithm_name="Unknown", metric_for_best_model=None, save_every=None, gamma=None, logger=None, + crop=None, + clip_grad=1.0, + # for adding components during training + pre_process=None, + pre_process_delay=None, + pre_process_freeze=None, + pre_process_unfreeze=None, + post_process=None, + post_process_delay=None, + post_process_freeze=None, + post_process_unfreeze=None, ): """ Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__. @@ -291,12 +306,8 @@ def __init__( the weight of the lpips(VGG) in the total loss. If None ignore. By default None. l1_mask : float, optional the weight of the l1 norm of the mask in the total loss. If None ignore. By default None. - optimizer : str, optional - Optimizer to use durring training. Available : "Adam". By default "Adam". - optimizer_lr : float, optional - Learning rate for the optimizer, by default 1e-6. - slow_start : float, optional - Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None. + optimizer : dict + Optimizer configuration. 
skip_NAN : bool, optional
            Whether to skip the update if any gradient is NaN (True) or to throw an error (False), by default False
        algorithm_name : str, optional
@@ -309,13 +320,55 @@
            Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
        logger : :py:class:`logging.Logger`, optional
            Logger to use for logging. If None, just print to terminal. Default is None.
+        crop : dict, optional
+            Crop to apply to images before computing the loss. If None, no crop is applied. Default is None.
+        pre_process : :py:class:`torch.nn.Module`, optional
+            Pre process component to add during training. Default is None.
+        pre_process_delay : int, optional
+            Epoch at which to add pre process component. Default is None.
+        pre_process_freeze : int, optional
+            Epoch at which to freeze pre process component. Default is None.
+        pre_process_unfreeze : int, optional
+            Epoch at which to unfreeze pre process component. Default is None.
+        post_process : :py:class:`torch.nn.Module`, optional
+            Post process component to add during training. Default is None.
+        post_process_delay : int, optional
+            Epoch at which to add post process component. Default is None.
+        post_process_freeze : int, optional
+            Epoch at which to freeze post process component. Default is None.
+        post_process_unfreeze : int, optional
+            Epoch at which to unfreeze post process component. Default is None.
 
         """
+        global print
+
         self.device = recon._psf.device
         self.logger = logger
+        if self.logger is not None:
+            self.print = self.logger.info
+        else:
+            self.print = print
         self.recon = recon
 
+        self.pre_process = pre_process
+        self.pre_process_delay = pre_process_delay
+        self.pre_process_freeze = pre_process_freeze
+        self.pre_process_unfreeze = pre_process_unfreeze
+        if pre_process_delay is not None:
+            assert pre_process is not None
+        else:
+            self.pre_process_delay = -1
+
+        self.post_process = post_process
+        self.post_process_delay = post_process_delay
+        self.post_process_freeze = post_process_freeze
+        self.post_process_unfreeze = post_process_unfreeze
+        if post_process_delay is not None:
+            assert post_process is not None
+        else:
+            self.post_process_delay = -1
+
         assert train_dataset is not None
         if test_dataset is None:
             assert test_size < 1.0 and test_size > 0.0
@@ -325,10 +378,7 @@
             train_dataset, test_dataset = torch.utils.data.random_split(
                 train_dataset, [train_size, test_size]
             )
-        if self.logger is not None:
-            self.logger.info(f"Train size : {train_size}, Test size : {test_size}")
-        else:
-            print(f"Train size : {train_size}, Test size : {test_size}")
+        self.print(f"Train size : {train_size}, Test size : {test_size}")
 
         self.train_dataloader = torch.utils.data.DataLoader(
             dataset=train_dataset,
@@ -341,9 +391,9 @@
         self.skip_NAN = skip_NAN
         self.eval_batch_size = eval_batch_size
 
+        self.mask = mask
         if mask is not None:
             assert isinstance(mask, TrainableMask)
-            self.mask = mask
             self.use_mask = True
         else:
             self.use_mask = False
@@ -370,32 +420,12 @@
                 "lpips package is needed for LPIPS loss. 
Install using : pip install lpips" ) - # optimizer - if optimizer == "Adam": - # the parameters of the base model and non torch.Module process must be added separatly - parameters = [{"params": recon.parameters()}] - self.optimizer = torch.optim.Adam(parameters, lr=optimizer_lr) - else: - raise ValueError(f"Unsuported optimizer : {optimizer}") - # Scheduler - if slow_start: - - def learning_rate_function(epoch): - if epoch == 0: - return slow_start - elif epoch == 1: - return math.sqrt(slow_start) - else: - return 1 - - else: - - def learning_rate_function(epoch): - return 1 + self.crop = crop - self.scheduler = torch.optim.lr_scheduler.LambdaLR( - self.optimizer, lr_lambda=learning_rate_function - ) + # optimizer + self.clip_grad_norm = clip_grad + self.optimizer_config = optimizer + self.set_optimizer() self.metrics = { "LOSS": [], # train loss @@ -429,10 +459,7 @@ def detect_nan(grad): print(grad, flush=True) for name, param in recon.named_parameters(): if param.requires_grad: - if self.logger: - self.logger.info(name, param) - else: - print(name, param) + self.print(name, param) raise ValueError("Gradient is NaN") return grad @@ -442,7 +469,49 @@ def detect_nan(grad): if param.requires_grad: param.register_hook(detect_nan) - def train_epoch(self, data_loader, disp=-1): + def set_optimizer(self, last_epoch=-1): + + if self.optimizer_config.type == "Adam": + parameters = [{"params": self.recon.parameters()}] + self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr) + else: + raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}") + + # Scheduler + if self.optimizer_config.slow_start: + + def learning_rate_function(epoch): + if epoch == 0: + return self.optimizer_config.slow_start + elif epoch == 1: + return math.sqrt(self.optimizer_config.slow_start) + else: + return 1 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch + ) + + elif self.optimizer_config.step: + + self.scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, + step_size=self.optimizer_config.step, + gamma=self.optimizer_config.gamma, + last_epoch=last_epoch, + verbose=True, + ) + + else: + + def learning_rate_function(epoch): + return 1 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch + ) + + def train_epoch(self, data_loader): """ Train for one epoch. @@ -450,8 +519,6 @@ def train_epoch(self, data_loader, disp=-1): ---------- data_loader : :py:class:`torch.utils.data.DataLoader` Data loader to use for training. 
-        disp : int
-            Display interval, if -1, no display
 
        Returns
        -------
@@ -468,7 +535,7 @@
            # update psf according to mask
            if self.use_mask:
-                self.recon._set_psf(self.mask.get_psf())
+                self.recon._set_psf(self.mask.get_psf().to(self.device))
 
            # forward pass
            y_pred = self.recon.batch_call(X.to(self.device))
@@ -481,20 +548,24 @@
                y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps
                y = y / y_max
 
-            if i % disp == 1:
-                img_pred = y_pred[0, 0].cpu().detach().numpy()
-                img_truth = y[0, 0].cpu().detach().numpy()
-
-                plt.imshow(img_pred)
-                plt.savefig(f"y_pred_{i-1}.png")
-                plt.imshow(img_truth)
-                plt.savefig(f"y_{i-1}.png")
-
            self.optimizer.zero_grad(set_to_none=True)
            # convert to CHW for loss and remove depth
            y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
            y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)
 
+            # crop
+            if self.crop is not None:
+                y_pred = y_pred[
+                    ...,
+                    self.crop["vertical"][0] : self.crop["vertical"][1],
+                    self.crop["horizontal"][0] : self.crop["horizontal"][1],
+                ]
+                y = y[
+                    ...,
+                    self.crop["vertical"][0] : self.crop["vertical"][1],
+                    self.crop["horizontal"][0] : self.crop["horizontal"][1],
+                ]
+
            loss_v = self.Loss(y_pred, y)
            if self.lpips:
@@ -511,20 +582,18 @@
                loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask))
            loss_v.backward()
-            torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0)
+            if self.clip_grad_norm is not None:
+                torch.nn.utils.clip_grad_norm_(self.recon.parameters(), self.clip_grad_norm)
 
            # if any gradient is NaN, skip training step
            if self.skip_NAN:
                is_NAN = False
                for param in self.recon.parameters():
-                    if torch.isnan(param.grad).any():
+                    if param.grad is not None and torch.isnan(param.grad).any():
                        is_NAN = True
                        break
                if is_NAN:
-                    if self.logger is not None:
-                        self.logger.info("NAN detected in gradiant, skipping training step")
-                    else:
-                        print("NAN detected in gradiant, skipping training step")
+                    self.print("NaN detected in gradient, skipping training step")
                    i += 1
                    continue
            self.optimizer.step()
@@ -537,12 +606,11 @@
            pbar.set_description(f"loss : {mean_loss}")
            i += 1
 
-        if self.logger is not None:
-            self.logger.info(f"loss : {mean_loss}")
+        self.print(f"loss : {mean_loss}")
 
        return mean_loss
 
-    def evaluate(self, mean_loss, save_pt):
+    def evaluate(self, mean_loss, save_pt, epoch, disp=None):
        """
        Evaluate the reconstruction algorithm on the test dataset.
 
@@ -552,11 +620,28 @@
            Mean loss of the last epoch.
        save_pt : str
            Path to save metrics dictionary to. If None, no logging of metrics.
+        epoch : int
+            Current epoch, used to name the folder where example reconstructions are saved.
+        disp : list of int, optional
+            Test set examples to visualize at the end of each epoch, by default None. 
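+            For example, ``disp=[0, 1, 2]`` saves the reconstructions of the first
+            three test-set examples under ``eval_recon/<epoch>/`` at each evaluation.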
""" if self.test_dataset is None: return + + output_dir = None + if disp is not None: + output_dir = os.path.join("eval_recon") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + output_dir = os.path.join(output_dir, str(epoch)) + # benchmarking - current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size) + current_metrics = benchmark( + self.recon, + self.test_dataset, + batchsize=self.eval_batch_size, + save_idx=disp, + output_dir=output_dir, + crop=self.crop, + ) # update metrics with current metrics self.metrics["LOSS"].append(mean_loss) @@ -566,7 +651,7 @@ def evaluate(self, mean_loss, save_pt): if save_pt: # save dictionary metrics to file with json with open(os.path.join(save_pt, "metrics.json"), "w") as f: - json.dump(self.metrics, f) + json.dump(self.metrics, f, indent=4) # check best metric if self.metrics["metric_for_best_model"] is None: @@ -579,7 +664,7 @@ def evaluate(self, mean_loss, save_pt): else: return current_metrics[self.metrics["metric_for_best_model"]] - def on_epoch_end(self, mean_loss, save_pt, epoch): + def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): """ Called at the end of each epoch. @@ -591,6 +676,8 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): Path to save metrics dictionary to. If None, no logging of metrics. epoch : int Current epoch. + disp : list of int, optional + Test set examples to visualize at the end of each epoch, by default None. """ if save_pt is None: # Use current directory @@ -598,7 +685,7 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): # save model # self.save(path=save_pt, include_optimizer=False) - epoch_eval_metric = self.evaluate(mean_loss, save_pt) + epoch_eval_metric = self.evaluate(mean_loss, save_pt, epoch, disp=disp) new_best = False if ( self.metrics["metric_for_best_model"] == "PSNR" @@ -619,7 +706,7 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): if self.save_every is not None and epoch % self.save_every == 0: self.save(path=save_pt, include_optimizer=False, epoch=epoch) - def train(self, n_epoch=1, save_pt=None, disp=-1): + def train(self, n_epoch=1, save_pt=None, disp=None): """ Train the reconstruction algorithm. @@ -629,27 +716,56 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): Number of epochs to train for, by default 1 save_pt : str, optional Path to save metrics dictionary to. If None, use current directory, by default None - disp : int, optional - Display interval, if -1, no display. Default is -1. + disp : list of int, optional + test set examples to visualize at the end of each epoch, by default None. 
""" start_time = time.time() - self.evaluate(-1, save_pt) + self.evaluate(-1, save_pt, epoch=0, disp=disp) for epoch in range(n_epoch): - if self.logger is not None: - self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") - else: - print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") - mean_loss = self.train_epoch(self.train_dataloader, disp=disp) + + # add extra components (if specified) + changing_n_param = False + if epoch == self.pre_process_delay: + self.print("Adding pre process component") + self.recon.set_pre_process(self.pre_process) + changing_n_param = True + if epoch == self.post_process_delay: + self.print("Adding post process component") + self.recon.set_post_process(self.post_process) + changing_n_param = True + if epoch == self.pre_process_freeze: + self.print("Freezing pre process") + self.recon.freeze_pre_process() + changing_n_param = True + if epoch == self.post_process_freeze: + self.print("Freezing post process") + self.recon.freeze_post_process() + changing_n_param = True + if epoch == self.pre_process_unfreeze: + self.print("Unfreezing pre process") + self.recon.unfreeze_pre_process() + changing_n_param = True + if epoch == self.post_process_unfreeze: + self.print("Unfreezing post process") + self.recon.unfreeze_post_process() + changing_n_param = True + + # count number of parameters with requires_grad = True + if changing_n_param: + n_param = sum(p.numel() for p in self.recon.parameters() if p.requires_grad) + if self.mask is not None: + n_param += sum(p.numel() for p in self.mask.parameters() if p.requires_grad) + self.print(f"Training {n_param} parameters") + + self.print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + mean_loss = self.train_epoch(self.train_dataloader) # offset because of evaluate before loop - self.on_epoch_end(mean_loss, save_pt, epoch + 1) + self.on_epoch_end(mean_loss, save_pt, epoch + 1, disp=disp) self.scheduler.step() - if self.logger is not None: - self.logger.info(f"Train time : {time.time() - start_time} s") - else: - print(f"Train time : {time.time() - start_time} s") + self.print(f"Train time : {time.time() - start_time} s") def save(self, epoch, path="recon", include_optimizer=False): # create directory if it does not exist @@ -657,7 +773,22 @@ def save(self, epoch, path="recon", include_optimizer=False): os.makedirs(path) # save mask if self.use_mask: - torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) + # torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) + + # save mask as numpy array + if self.mask.train_mask_vals: + np.save( + os.path.join(path, f"mask_epoch{epoch}.npy"), + self.mask._mask.cpu().detach().numpy(), + ) + + if self.mask.color_filter is not None: + # save save numpy array + np.save( + os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"), + self.mask.color_filter.cpu().detach().numpy(), + ) + torch.save( self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt") ) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index a5a2e8a9..093cc298 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -16,6 +16,22 @@ from lensless.utils.simulation import FarFieldSimulator from lensless.utils.io import load_image, load_psf from lensless.utils.image import resize +import re +from lensless.hardware.utils import capture +from lensless.hardware.utils import display +from lensless.hardware.slm import set_programmable_mask, adafruit_sub2full + + 
+def convert(text):
+    # convert runs of digits to integers so that they compare numerically
+    return int(text) if text.isdigit() else text.lower()
+
+
+def alphanum_key(key):
+    # split into digit and non-digit chunks, e.g. "img10" -> ["img", 10, ""]
+    return [convert(c) for c in re.split("([0-9]+)", key)]
+
+
+def natural_sort(arr):
+    return sorted(arr, key=alphanum_key)
 
 
 class DualDataset(Dataset):
@@ -133,8 +149,8 @@
            # flip image x and y if needed
            if self.flip:
-                lensless = torch.rot90(lensless, dims=(-3, -2))
-                lensed = torch.rot90(lensed, dims=(-3, -2))
+                lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
+                lensed = torch.rot90(lensed, dims=(-3, -2), k=2)
            if self.transform_lensless:
                lensless = self.transform_lensless(lensless)
            if self.transform_lensed:
@@ -157,6 +173,10 @@
        pre_transform=None,
        dataset_is_CHW=False,
        flip=False,
+        vertical_shift=None,
+        horizontal_shift=None,
+        crop=None,
+        downsample=1,
        **kwargs,
    ):
        """
@@ -181,11 +201,25 @@
        assert isinstance(dataset, Dataset)
        self.dataset = dataset
        self.n_files = len(dataset)
-
        self.dataset_is_CHW = dataset_is_CHW
        self._pre_transform = pre_transform
        self.flip_pre_sim = flip
 
+        self.vertical_shift = vertical_shift
+        self.horizontal_shift = horizontal_shift
+        self.crop = crop
+        if downsample != 1:
+            if self.vertical_shift is not None:
+                self.vertical_shift = int(self.vertical_shift // downsample)
+            if self.horizontal_shift is not None:
+                self.horizontal_shift = int(self.horizontal_shift // downsample)
+
+            if crop is not None:
+                self.crop["vertical"][0] = int(self.crop["vertical"][0] // downsample)
+                self.crop["vertical"][1] = int(self.crop["vertical"][1] // downsample)
+                self.crop["horizontal"][0] = int(self.crop["horizontal"][0] // downsample)
+                self.crop["horizontal"][1] = int(self.crop["horizontal"][1] // downsample)
+
        # check simulator
        assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator"
        assert simulator.is_torch, "Simulator should be a pytorch simulator"
@@ -212,6 +246,11 @@
 
        lensless, lensed = self.sim.propagate_image(img, return_object_plane=True)
 
+        if self.vertical_shift is not None:
+            lensed = torch.roll(lensed, self.vertical_shift, dims=-3)
+        if self.horizontal_shift is not None:
+            lensed = torch.roll(lensed, self.horizontal_shift, dims=-2)
+
        if lensed.shape[-1] == 1 and lensless.shape[-1] == 3:
            # copy to 3 channels
            lensed = lensed.repeat(1, 1, 3)
@@ -230,20 +269,27 @@
 
 class MeasuredDatasetSimulatedOriginal(DualDataset):
    """
+    Abstract class for defining a dataset of paired lensed and lensless images.
+
    Dataset consisting of lensless images captured from a screen and the corresponding images shown on the screen.
    Unlike :py:class:`lensless.utils.dataset.MeasuredDataset`, the ground-truth lensed image is simulated using a :py:class:`lensless.utils.simulation.FarFieldSimulator` object rather than measured with a lensed camera.
+
+    The class assumes that the ``measured_dir`` and ``original_dir`` have file names that match.
+
+    The method ``_get_images_pair`` must be defined.
    """
 
    def __init__(
        self,
-        root_dir,
+        measured_dir,
+        original_dir,
        simulator,
-        lensless_fn="diffuser",
-        original_fn="lensed",
-        image_ext="npy",
-        original_ext=None,
+        measurement_ext="png",
+        original_ext="jpg",
        downsample=1,
+        background=None,
+        flip=False,
        **kwargs,
    ):
        """
@@ -251,42 +297,34 @@
 
        Parameters
        ----------
-        root_dir : str
-            Path to the test dataset. It is expected to contain two folders: one of lensless images and one of original images. 
-        simulator : :py:class:`lensless.utils.simulatorFarFieldSimulator`
-            Simulator to use for the projection of the original image to object space. The PSF **should not** be specified, and it is expect to have ``is_torch = True``.
-        lensless_fn : str, optional
-            Name of the folder containing the lensless images, by default "diffuser".
-        lensed_fn : str, optional
-            Name of the folder containing the lensed images, by default "lensed".
-        image_ext : str, optional
-            Extension of the images, by default "npy".
-        original_ext : str, optional
-            Extension of the original image if different from lenless, by default None.
-        downsample : int, optional
-            Downsample factor of the lensless images, by default 1.
+        measured_dir : str
+            Path to the folder of measured (lensless) images.
+        original_dir : str
+            Path to the folder of original images shown on the screen, with matching file names.
+        simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator`
+            Simulator to use for projecting the original image to the object plane. It is expected to have ``is_torch = True``.
+        measurement_ext : str, optional
+            Extension of the measured images, by default "png".
+        original_ext : str, optional
+            Extension of the original images, by default "jpg".
+        downsample : int, optional
+            Downsample factor of the lensless images, by default 1.
        """
-        super(MeasuredDatasetSimulatedOriginal, self).__init__(downsample=1, **kwargs)
+        super(MeasuredDatasetSimulatedOriginal, self).__init__(
+            downsample=1, background=background, flip=flip, **kwargs
+        )
        self.pre_downsample = downsample
 
-        self.root_dir = root_dir
-        self.lensless_dir = os.path.join(root_dir, lensless_fn)
-        self.original_dir = os.path.join(root_dir, original_fn)
-        assert os.path.isdir(self.lensless_dir)
+        self.measured_dir = measured_dir
+        self.original_dir = original_dir
+        assert os.path.isdir(self.measured_dir)
        assert os.path.isdir(self.original_dir)
 
-        self.image_ext = image_ext.lower()
-        self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower()
+        self.measurement_ext = measurement_ext.lower()
+        self.original_ext = original_ext.lower()
+
+        files = natural_sort(glob.glob(os.path.join(self.measured_dir, "*." + measurement_ext)))
 
-        files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext))
-        files.sort()
        self.files = [os.path.basename(fn) for fn in files]
 
        if len(self.files) == 0:
            raise FileNotFoundError(
-                f"No files found in {self.lensless_dir} with extension {image_ext}"
+                f"No files found in {self.measured_dir} with extension {self.measurement_ext}"
            )
 
+        # check that corresponding files exist
+        for fn in self.files:
+            original_fp = os.path.join(self.original_dir, fn[:-3] + self.original_ext)
+            assert os.path.exists(original_fp), f"File {original_fp} does not exist"
+
        # check simulator
        assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator"
        assert simulator.is_torch, "Simulator should be a pytorch simulator"
@@ -299,30 +337,200 @@
        else:
            return len([i for i in self.indices if i < len(self.files)])
 
-    def _get_images_pair(self, idx):
-        if self.image_ext == "npy" or self.image_ext == "npz":
-            lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
-            original_fp = os.path.join(self.original_dir, self.files[idx])
-            lensless = np.load(lensless_fp)
-            lensless = resize(lensless, factor=1 / self.downsample)
-            original = np.load(original_fp[:-3] + self.original_ext)
-        else:
-            # more standard image formats: png, jpg, tiff, etc. 
- lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - original_fp = os.path.join(self.original_dir, self.files[idx]) - lensless = load_image(lensless_fp, downsample=self.pre_downsample) - original = load_image( - original_fp[:-3] + self.original_ext, downsample=self.pre_downsample + # def _get_images_pair(self, idx): + # if self.image_ext == "npy" or self.image_ext == "npz": + # lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + # original_fp = os.path.join(self.original_dir, self.files[idx]) + # lensless = np.load(lensless_fp) + # lensless = resize(lensless, factor=1 / self.downsample) + # original = np.load(original_fp[:-3] + self.original_ext) + # else: + # # more standard image formats: png, jpg, tiff, etc. + # lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + # original_fp = os.path.join(self.original_dir, self.files[idx]) + # lensless = load_image(lensless_fp, downsample=self.pre_downsample) + # original = load_image( + # original_fp[:-3] + self.original_ext, downsample=self.pre_downsample + # ) + + # # convert to float + # if lensless.dtype == np.uint8: + # lensless = lensless.astype(np.float32) / 255 + # original = original.astype(np.float32) / 255 + # else: + # # 16 bit + # lensless = lensless.astype(np.float32) / 65535 + # original = original.astype(np.float32) / 65535 + + # # convert to torch + # lensless = torch.from_numpy(lensless) + # original = torch.from_numpy(original) + + # # project original image to lensed space + # with torch.no_grad(): + # lensed = self.sim.propagate_image() + + # return lensless, lensed + + +class DigiCamCelebA(MeasuredDatasetSimulatedOriginal): + def __init__( + self, + celeba_root, + data_dir=None, + psf_path=None, + downsample=1, + flip=True, + vertical_shift=None, + horizontal_shift=None, + crop=None, + simulation_config=None, + **kwargs, + ): + """ + + Some parameters default to work for the ``celeba_adafruit_random_2mm_20230720_10K`` dataset, + namely: flip, vertical_shift, horizontal_shift, crop, simulation_config. + + Parameters + ---------- + celeba_root : str + Path to the CelebA dataset. + data_dir : str, optional + Path to the lensless images, by default looks inside the ``data`` folder. Can download if not available. + psf_path : str, optional + Path to the PSF of the imaging system, by default looks inside the ``data/psf`` folder. Can download if not available. + downsample : int, optional + Downsample factor of the lensless images, by default 1. + flip : bool, optional + If True, measurements are flipped, by default ``True``. Does not get applied to the original images. + vertical_shift : int, optional + Vertical shift (in pixels) of the lensed images to align. + horizontal_shift : int, optional + Horizontal shift (in pixels) of the lensed images to align. + crop : dict, optional + Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]) to select region of interest. 
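+        simulation_config : dict, optional
+            Settings for the ``FarFieldSimulator`` used to simulate the lensed
+            ground truth (its output dimension is set internally from the PSF shape).
+
+        Example (illustrative; paths and simulation settings are user-specific)::
+
+            dataset = DigiCamCelebA(
+                celeba_root="/path/to/celeba",
+                downsample=8,
+                simulation_config=sim_cfg,  # dict of FarFieldSimulator settings
+            )
+            lensless, lensed = dataset[0]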
+ """ + + if vertical_shift is None: + # default to (no downsampling) of celeba_adafruit_random_2mm_20230720_10K + vertical_shift = -85 + horizontal_shift = -5 + + if crop is None: + crop = {"vertical": [30, 560], "horizontal": [285, 720]} + self.crop = crop + + self.vertical_shift = vertical_shift + self.horizontal_shift = horizontal_shift + if downsample != 1: + self.vertical_shift = int(self.vertical_shift // downsample) + self.horizontal_shift = int(self.horizontal_shift // downsample) + + self.crop["vertical"][0] = int(self.crop["vertical"][0] // downsample) + self.crop["vertical"][1] = int(self.crop["vertical"][1] // downsample) + self.crop["horizontal"][0] = int(self.crop["horizontal"][0] // downsample) + self.crop["horizontal"][1] = int(self.crop["horizontal"][1] // downsample) + + # download dataset if necessary + if data_dir is None: + data_dir = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "data", + "celeba_adafruit_random_2mm_20230720_10K", ) + if not os.path.isdir(data_dir): + main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data") + print("DigiCam CelebA dataset not found.") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download this dataset of 10K examples (12.2GB)?" - # convert to float - if lensless.dtype == np.uint8: - lensless = lensless.astype(np.float32) / 255 - original = original.astype(np.float32) / 255 - else: - # 16 bit - lensless = lensless.astype(np.float32) / 65535 - original = original.astype(np.float32) / 65535 + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/9NNGCJs3DoBDGlY/download" + filename = "celeba_adafruit_random_2mm_20230720_10K.zip" + download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) + + # download PSF if necessary + if psf_path is None: + psf_path = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "data", + "psf", + "adafruit_random_2mm_20231907.png", + ) + if not os.path.exists(psf_path): + try: + from torchvision.datasets.utils import download_url + except ImportError: + exit() + msg = "Do you want to download the PSF (38.8MB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + output_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "psf") + if valid: + url = "https://drive.switch.ch/index.php/s/kfN5vOqvVkNyHmc/download" + filename = "adafruit_random_2mm_20231907.png" + download_url(url, output_path, filename=filename) + + # load PSF + self.flip_measurement = flip + psf, background = load_psf( + psf_path, + downsample=downsample * 4, # PSF is 4x the resolution of the images + return_float=True, + return_bg=True, + flip=flip, + bg_pix=(0, 15), + ) + self.psf = torch.from_numpy(psf) + + # create simulator + simulation_config["output_dim"] = tuple(self.psf.shape[-3:-1]) + simulator = FarFieldSimulator( + is_torch=True, + **simulation_config, + ) + + super().__init__( + measured_dir=data_dir, + original_dir=os.path.join(celeba_root, "celeba", "img_align_celeba"), + simulator=simulator, + measurement_ext="png", + original_ext="jpg", + downsample=downsample, + background=background, + flip=False, # will do flipping only on measurement + **kwargs, + ) + + def _get_images_pair(self, idx): + + # more standard image formats: png, jpg, tiff, etc. 
+        lensless_fp = os.path.join(self.measured_dir, self.files[idx])
+        original_fp = os.path.join(self.original_dir, self.files[idx][:-3] + self.original_ext)
+        lensless = load_image(
+            lensless_fp, downsample=self.pre_downsample, flip=self.flip_measurement
+        )
+        # original_fp already includes the original extension
+        original = load_image(original_fp)
+
+        # convert to float
+        if lensless.dtype == np.uint8:
+            lensless = lensless.astype(np.float32) / 255
+            original = original.astype(np.float32) / 255
+        else:
+            # 16 bit
+            lensless = lensless.astype(np.float32) / 65535
+            original = original.astype(np.float32) / 65535
 
        # convert to torch
        lensless = torch.from_numpy(lensless)
@@ -330,7 +538,12 @@
 
        # project original image to lensed space
        with torch.no_grad():
-            lensed = self.sim.propagate_image()
+            lensed = self.sim.propagate_image(original, return_object_plane=True)
+
+        if self.vertical_shift is not None:
+            lensed = torch.roll(lensed, self.vertical_shift, dims=-3)
+        if self.horizontal_shift is not None:
+            lensed = torch.roll(lensed, self.horizontal_shift, dims=-2)
 
        return lensless, lensed
 
@@ -376,8 +589,7 @@
 
        self.image_ext = image_ext.lower()
 
-        files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext))
-        files.sort()
+        files = natural_sort(glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)))
 
        self.files = [os.path.basename(fn) for fn in files]
 
        if len(self.files) == 0:
@@ -432,6 +644,26 @@
        **kwargs,
    ):
 
+        # check that the PSF path exists; otherwise, fall back to (and possibly download) the DiffuserCam PSF
+        if not os.path.exists(psf_path):
+            psf_path = os.path.join(
+                os.path.dirname(__file__), "..", "..", "data", "psf", "diffusercam_psf.tiff"
+            )
+
+            try:
+                from torchvision.datasets.utils import download_url
+            except ImportError:
+                exit()
+            msg = "Do you want to download the DiffuserCam PSF (5.9MB)?"
+
+            # default to yes if no input is given
+            valid = input("%s (Y/n) " % msg).lower() != "n"
+            output_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "psf")
+            if valid:
+                url = "https://drive.switch.ch/index.php/s/BteiuEcONmhmDSn/download"
+                filename = "diffusercam_psf.tiff"
+                download_url(url, output_path, filename=filename)
+
        psf, background = load_psf(
            psf_path,
            downsample=downsample * 4,  # PSF is 4x the resolution of the images
@@ -443,6 +675,10 @@
        self.psf = transform_BRG2RGB(torch.from_numpy(psf))
        self.allowed_idx = np.arange(2, 25001)
 
+        assert os.path.isdir(os.path.join(dataset_dir, "diffuser_images")) and os.path.isdir(
+            os.path.join(dataset_dir, "ground_truth_lensed")
+        ), "Dataset should contain 'diffuser_images' and 'ground_truth_lensed' folders. It can be downloaded from https://waller-lab.github.io/LenslessLearning/dataset.html"
+
        super().__init__(
            root_dir=dataset_dir,
            background=background,
@@ -502,12 +738,12 @@
        )
        if not os.path.isdir(data_dir):
            main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data")
-            print("No dataset found for benchmarking.")
+            print("DiffuserCam test set not found for benchmarking.")
            try:
                from torchvision.datasets.utils import download_and_extract_archive
            except ImportError:
                exit()
-            msg = "Do you want to download the sample dataset (3.5GB)?"
+            msg = "Do you want to download the dataset (3.5GB)?"
 
            # default to yes if no input is given
            valid = input("%s (Y/n) " % msg).lower() != "n"
@@ -596,3 +832,93 @@
 
        # return simulated images
        return super()._get_images_pair(index)
+
+
+class HITLDatasetTrainableMask(SimulatedDatasetTrainableMask):
+    """
+    Dataset of on-the-fly measurements and simulated ground-truth. 
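+
+    For each queried example: the original image is shown on a remote display,
+    the current trainable mask pattern is set on the device, and a lensless
+    measurement is captured, while the ground truth is obtained by propagating
+    the original image through the digital model (as in the parent class).
+    Assumes a train-test split of a subset of CelebA (see ``__getitem__``).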
+    """
+
+    def __init__(
+        self,
+        rpi_username,
+        rpi_hostname,
+        celeba_root,
+        display_config,
+        capture_config,
+        mask_center,
+        **kwargs,
+    ):
+        self.rpi_username = rpi_username
+        self.rpi_hostname = rpi_hostname
+        self.celeba_root = celeba_root
+        assert os.path.isdir(self.celeba_root)
+
+        self.display_config = display_config
+        self.capture_config = capture_config
+        self.mask_center = mask_center
+
+        super(HITLDatasetTrainableMask, self).__init__(**kwargs)
+
+    def __getitem__(self, index):
+
+        # propagate through mask in digital model
+        _, lensed = super().__getitem__(index)
+
+        ## measure lensless image
+        # get image file path
+        idx = self.dataset.indices[index]
+
+        # twice nested as we do train-test split of subset of CelebA
+        fn = self.dataset.dataset.dataset.filename[idx]
+        fp = os.path.join(self.celeba_root, "celeba", "img_align_celeba", fn)
+
+        # display on screen
+        display(
+            fp=fp,
+            rpi_username=self.rpi_username,
+            rpi_hostname=self.rpi_hostname,
+            **self.display_config,
+        )
+
+        # set mask
+        with torch.no_grad():
+            subpattern = self._mask.get_vals()
+            subpattern_np = subpattern.detach().cpu().numpy().copy()
+            pattern = adafruit_sub2full(
+                subpattern_np,
+                center=self.mask_center,
+            )
+            set_programmable_mask(
+                pattern,
+                self._mask.device,
+                self.rpi_username,
+                self.rpi_hostname,
+            )
+
+        # take picture
+        _, img = capture(
+            rpi_username=self.rpi_username,
+            rpi_hostname=self.rpi_hostname,
+            verbose=False,
+            **self.capture_config,
+        )
+
+        # -- normalize
+        img = img.astype(np.float32) / img.max()
+
+        # prep
+        img = torch.from_numpy(img)
+        # -- if [H, W, C] -> [D, H, W, C]
+        if len(img.shape) == 3:
+            img = img.unsqueeze(0)
+
+        if self.background is not None:
+            img = img - self.background
+
+        # flip image x and y if needed
+        if self.capture_config.flip:
+            img = torch.rot90(img, dims=(-3, -2), k=2)
+
+        # return simulated images (replace simulated with measured)
+        return img, lensed
diff --git a/lensless/utils/io.py b/lensless/utils/io.py
index 1b2b234f..750b0e0e 100644
--- a/lensless/utils/io.py
+++ b/lensless/utils/io.py
@@ -83,13 +83,36 @@
        RGB image of dimension (height, width, 3).
    """
    assert os.path.isfile(fp)
+
+    nbits = None  # input bit depth
    if "dng" in fp:
        import rawpy
 
        assert bayer
        raw = rawpy.imread(fp)
        img = raw.raw_image
-        # TODO : use raw.postprocess?
+        # TODO : use raw.postprocess? too much unknown processing...
+        # img = raw.postprocess(
+        #     adjust_maximum_thr=0,  # default 0.75
+        #     no_auto_scale=False,
+        #     # no_auto_scale=True,
+        #     gamma=(1, 1),
+        #     bright=1,  # default 1
+        #     exp_shift=1,
+        #     no_auto_bright=True,
+        #     # use_camera_wb=True,
+        #     # use_auto_wb=False,
+        #     # -- gives better balance for PSF measurement
+        #     use_camera_wb=False,
+        #     use_auto_wb=True,  # default is False? if both use_camera_wb and use_auto_wb are True, then use_auto_wb has priority. 
+ # ) + + # if red_gain is None or blue_gain is None: + # camera_wb = raw.camera_whitebalance + # red_gain = camera_wb[0] + # blue_gain = camera_wb[1] + + nbits = int(np.ceil(np.log2(raw.white_level))) ccm = raw.color_matrix[:, :3] black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32) elif "npy" in fp or "npz" in fp: @@ -99,11 +122,12 @@ def load_image( if bayer: assert len(img.shape) == 2, img.shape - if img.max() > 255: - # HQ camera - n_bits = 12 - else: - n_bits = 8 + if nbits is None: + if img.max() > 255: + # HQ camera + nbits = 12 + else: + nbits = 8 if back: back_img = cv2.imread(back, cv2.IMREAD_UNCHANGED) @@ -112,10 +136,11 @@ def load_image( img = np.clip(img, a_min=0, a_max=img.max()) img = img.astype(dtype) if nbits_out is None: - nbits_out = n_bits + nbits_out = nbits + img = bayer2rgb_cc( img, - nbits=n_bits, + nbits=nbits, blue_gain=blue_gain, red_gain=red_gain, black_level=black_level, @@ -504,17 +529,19 @@ def load_data( def save_image(img, fp, max_val=255): """Save as uint8 image.""" - if img.dtype == np.uint16: - img = img.astype(np.float32) + img_tmp = img.copy() + + if img_tmp.dtype == np.uint16: + img_tmp = img_tmp.astype(np.float32) - if img.dtype == np.float64 or img.dtype == np.float32: - img -= img.min() - img /= img.max() - img *= max_val - img = img.astype(np.uint8) + if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32: + img_tmp -= img_tmp.min() + img_tmp /= img_tmp.max() + img_tmp *= max_val + img_tmp = img_tmp.astype(np.uint8) - img = Image.fromarray(img) - img.save(fp) + img_tmp = Image.fromarray(img_tmp) + img_tmp.save(fp) def get_dtype(dtype=None, is_torch=False): diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py index b77fabcb..53d6257b 100644 --- a/lensless/utils/simulation.py +++ b/lensless/utils/simulation.py @@ -58,15 +58,16 @@ def __init__( Whether to quantize image, by default True. 
""" - assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + if psf is not None: + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" - if torch.is_tensor(psf): - # drop depth dimension, and convert HWC to CHW - psf = psf[0].movedim(-1, 0) - assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" - else: - psf = psf[0] - assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" + if torch.is_tensor(psf): + # drop depth dimension, and convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" super().__init__( object_height, @@ -84,12 +85,15 @@ def __init__( **kwargs ) - if self.is_torch: - assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels" - else: - assert ( - self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 - ), "PSF must have 1 or 3 channels" + if psf is not None: + if self.is_torch: + assert ( + self.psf.shape[0] == 1 or self.psf.shape[0] == 3 + ), "PSF must have 1 or 3 channels" + else: + assert ( + self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 + ), "PSF must have 1 or 3 channels" # save all the parameters in a dict self.params = { diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 89a31309..565864f4 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -3,12 +3,13 @@ # ========= # Authors : # Yohann PERRON +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# """ Benchmark reconstruction algorithms ============== -This script benchmarks reconstruction algorithms on the DiffuserCam dataset. +This script benchmarks reconstruction algorithms on the DiffuserCam test dataset. The algorithm benchmarked and the number of iterations can be set in the config file : benchmark.yaml. For unrolled algorithms, the results of the unrolled training (json file) are loaded from the benchmark/results folder. 
""" @@ -16,6 +17,8 @@ import hydra from hydra.utils import get_original_cwd +import time +import numpy as np import glob import json import os @@ -23,16 +26,21 @@ from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent -from lensless.utils.dataset import DiffuserCamTestDataset +from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA +from lensless.utils.io import save_image -try: - import torch -except ImportError: - raise ImportError("Torch and torchmetrics are needed to benchmark reconstruction algorithm") +import torch +from torch.utils.data import Subset @hydra.main(version_base=None, config_path="../../configs", config_name="benchmark") def benchmark_recon(config): + + # set seed + torch.manual_seed(config.seed) + np.random.seed(config.seed) + generator = torch.Generator().manual_seed(config.seed) + downsample = config.downsample n_files = config.n_files n_iter_range = config.n_iter_range @@ -44,8 +52,41 @@ def benchmark_recon(config): device = "cpu" # Benchmark dataset - benchmark_dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) - psf = benchmark_dataset.psf.to(device) + dataset = config.dataset + if dataset == "DiffuserCam": + benchmark_dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) + psf = benchmark_dataset.psf.to(device) + crop = None + + elif dataset == "DigiCamCelebA": + + dataset = DigiCamCelebA( + data_dir=os.path.join(get_original_cwd(), config.files.dataset), + celeba_root=config.files.celeba_root, + psf_path=os.path.join(get_original_cwd(), config.files.psf), + downsample=config.files.downsample, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + simulation_config=config.simulation, + crop=config.files.crop, + ) + dataset.psf = dataset.psf.to(device) + psf = dataset.psf + crop = dataset.crop + + # train-test split + train_size = int((1 - config.files.test_size) * len(dataset)) + test_size = len(dataset) - train_size + _, benchmark_dataset = torch.utils.data.random_split( + dataset, [train_size, test_size], generator=generator + ) + if config.n_files is not None: + benchmark_dataset = Subset(benchmark_dataset, np.arange(config.n_files)) + else: + raise ValueError(f"Dataset {dataset} not supported") + + print(f"Number of files : {len(benchmark_dataset)}") + print(f"Data shape : {dataset[0][0].shape}") model_list = [] # list of algoritms to benchmark if "ADMM" in config.algorithms: @@ -81,14 +122,63 @@ def benchmark_recon(config): # model_list.append(("APGD", APGD(psf))) results = {} + output_dir = None + if config.save_idx is not None: + + assert np.max(config.save_idx) < len( + benchmark_dataset + ), "save_idx values must be smaller than dataset size" + + os.mkdir("GROUND_TRUTH") + for idx in config.save_idx: + ground_truth = benchmark_dataset[idx][1] + ground_truth_np = ground_truth.cpu().numpy()[0] + + if crop is not None: + ground_truth_np = ground_truth_np[ + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + + save_image( + ground_truth_np, + fp=os.path.join("GROUND_TRUTH", f"{idx}.png"), + ) # benchmark each model for different number of iteration and append result to results + # -- batchsize has to equal 1 as baseline models don't support batch processing + start_time = time.time() for model_name, model in model_list: - results[model_name] = [] - print(f"Running benchmark for {model_name}") + + if config.save_idx 
is not None: + # make directory for outputs + os.mkdir(model_name) + + results[model_name] = dict() for n_iter in n_iter_range: - result = benchmark(model, benchmark_dataset, batchsize=1, n_iter=n_iter) - result["n_iter"] = n_iter - results[model_name].append(result) + + print(f"Running benchmark for {model_name} with {n_iter} iterations") + + if config.save_idx is not None: + output_dir = os.path.join(model_name, str(n_iter)) + os.mkdir(output_dir) + + result = benchmark( + model, + benchmark_dataset, + batchsize=1, + n_iter=n_iter, + save_idx=config.save_idx, + output_dir=output_dir, + crop=crop, + ) + results[model_name][int(n_iter)] = result + + # -- save results as easy to read JSON + results_path = "results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=4) + proc_time = (time.time() - start_time) / 60 + print(f"Total processing time: {proc_time:.2f} min") # create folder to load results from trained algorithms result_dir = os.path.join(get_original_cwd(), "benchmark", "trained_results") @@ -113,11 +203,40 @@ def benchmark_recon(config): unrolled_results[model_name][metric] = result[metric] # Baseline results - baseline_results = { - "MSE": 0.0618, - "LPIPS_Alex": 0.4434, - "ReconstructionError": 13.70, - } + baseline_label = config.baseline + baseline_results = None + if dataset == "DiffuserCam": + # (Monakhova et al. 2019, https://arxiv.org/abs/1908.11502) + # -- ADMM (100) + if baseline_label == "MONAKHOVA 100iter": + baseline_results = { + "MSE": 0.0622, + "LPIPS_Alex": 0.5711, + "ReconstructionError": 13.62, + } + # -- ADMM (5) + elif baseline_label == "MONAKHOVA 5iter": + baseline_results = { + "MSE": 0.1041, + "LPIPS_Alex": 0.6309, + "ReconstructionError": 11.32, + } + # -- Le-ADMM (Unrolled 5) + elif baseline_label == "MONAKHOVA Unrolled 5iter": + baseline_results = { + "MSE": 0.0618, + "LPIPS_Alex": 0.4434, + "ReconstructionError": 13.70, + } + # -- Le-ADMM-U (Unrolled 5 + UNet post-denoiser) + elif baseline_label == "MONAKHOVA Unrolled 5iter + UNet": + baseline_results = { + "MSE": 0.0074, + "LPIPS_Alex": 0.1904, + "ReconstructionError": 22.14, + } + else: + raise ValueError(f"Baseline {baseline_label} not supported") # for each metrics plot the results comparing each model metrics_to_plot = ["SSIM", "PSNR", "MSE", "LPIPS_Vgg", "LPIPS_Alex", "ReconstructionError"] @@ -126,20 +245,21 @@ def benchmark_recon(config): # plot benchmarked algorithm for model_name in results.keys(): plt.plot( - [result["n_iter"] for result in results[model_name]], - [result[metric] for result in results[model_name]], + n_iter_range, + [results[model_name][n_iter][metric] for n_iter in n_iter_range], label=model_name, ) # plot baseline as horizontal dotted line - if metric in baseline_results.keys(): - plt.hlines( - baseline_results[metric], - 0, - max(n_iter_range), - linestyles="dashed", - label="Unrolled MONAKHOVA 5iter", - color="orange", - ) + if baseline_results is not None: + if metric in baseline_results.keys(): + plt.hlines( + baseline_results[metric], + 0, + max(n_iter_range), + linestyles="dashed", + label=baseline_label, + color="orange", + ) # plot unrolled algorithms results color_list = ["red", "green", "blue", "orange", "purple"] diff --git a/scripts/eval/compute_metrics_from_original.py b/scripts/eval/compute_metrics_from_original.py index 8dbae6c3..e4986e37 100644 --- a/scripts/eval/compute_metrics_from_original.py +++ b/scripts/eval/compute_metrics_from_original.py @@ -15,6 +15,7 @@ """ import hydra +import os from hydra.utils import to_absolute_path 
import numpy as np import matplotlib.pyplot as plt @@ -61,6 +62,9 @@ def compute_metrics(config): print("SSIM", ssim(img_resize, est)) print("LPIPS", lpips(img_resize, est)) + plt.savefig("comparison.png") + save = os.getcwd() + "/comparison.png" + print(f"Save comparison to {save}") plt.show() diff --git a/scripts/hardware/config_digicam.py b/scripts/hardware/config_digicam.py index cd8cab86..0807519b 100644 --- a/scripts/hardware/config_digicam.py +++ b/scripts/hardware/config_digicam.py @@ -60,6 +60,25 @@ def config_digicam(config): else: raise ValueError(f"Pattern {config.pattern} not supported") + # apply aperture + if config.aperture is not None: + + # aperture = np.zeros(shape, dtype=np.uint8) + # top_left = np.array(config.aperture.center) - np.array(config.aperture.shape) // 2 + # bottom_right = top_left + np.array(config.aperture.shape) + # aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + + apert_dim = np.array(config.aperture.shape) * np.array(pixel_pitch) + ap = rect_aperture( + apert_dim=apert_dim, + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=np.array(config.aperture.center) * pixel_pitch, + ) + aperture = ap.values + aperture[aperture > 0] = 1 + pattern = pattern * aperture + # save pattern if not config.pattern.endswith(".npy") and config.save: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -72,15 +91,6 @@ def config_digicam(config): print("Pattern min : ", pattern.min()) print("Pattern max : ", pattern.max()) - # apply aperture - if config.aperture is not None: - - aperture = np.zeros(shape, dtype=np.uint8) - top_left = np.array(config.aperture.center) - np.array(config.aperture.shape) // 2 - bottom_right = top_left + np.array(config.aperture.shape) - aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 - pattern = pattern * aperture - assert pattern is not None n_nonzero = np.count_nonzero(pattern) diff --git a/scripts/hardware/set_digicam_mask_distance.py b/scripts/hardware/set_digicam_mask_distance.py new file mode 100644 index 00000000..dcd0dd79 --- /dev/null +++ b/scripts/hardware/set_digicam_mask_distance.py @@ -0,0 +1,16 @@ +import hydra +from lensless.hardware.utils import set_mask_sensor_distance + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + + # set mask to sensor distance + set_mask_sensor_distance(config.z, rpi_username, rpi_hostname, max_distance=40) + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/measure/analyze_image.py b/scripts/measure/analyze_image.py index 2b55c206..51cea91f 100644 --- a/scripts/measure/analyze_image.py +++ b/scripts/measure/analyze_image.py @@ -114,11 +114,11 @@ def analyze_image(fp, gamma, width, bayer, lens, lensless, bg, rg, plot_width, s assert fp is not None, "Must pass file path." 
# initialize plotting axis - _, ax_rgb = plt.subplots(ncols=2, nrows=1, num="RGB", figsize=(15, 5)) + fig_rgb, ax_rgb = plt.subplots(ncols=2, nrows=1, num="RGB", figsize=(15, 5)) if lens: - _, ax_gray = plt.subplots(ncols=3, nrows=1, num="Grayscale", figsize=(15, 5)) + fig_gray, ax_gray = plt.subplots(ncols=3, nrows=1, num="Grayscale", figsize=(15, 5)) else: - _, ax_gray = plt.subplots(ncols=2, nrows=1, num="Grayscale", figsize=(15, 5)) + fig_gray, ax_gray = plt.subplots(ncols=2, nrows=1, num="Grayscale", figsize=(15, 5)) # load PSF/image img = load_image( @@ -136,16 +136,16 @@ def analyze_image(fp, gamma, width, bayer, lens, lensless, bg, rg, plot_width, s # plot RGB and grayscale ax = plot_image(img, gamma=gamma, normalize=True, ax=ax_rgb[0]) ax.set_title("RGB") + ax = pixel_histogram(img, ax=ax_rgb[1], nbits=nbits) + ax.set_title("Histogram") + fig_rgb.savefig("rgb_analysis.png") img_grey = rgb2gray(img[None, ...]) ax = plot_image(img_grey, gamma=gamma, normalize=True, ax=ax_gray[0]) ax.set_title("Grayscale") - - # plot histogram, - ax = pixel_histogram(img, ax=ax_rgb[1], nbits=nbits) - ax.set_title("Histogram") ax = pixel_histogram(img_grey, ax=ax_gray[1], nbits=nbits) ax.set_title("Histogram") + fig_gray.savefig("grey_analysis.png") if lens: # determine PSF width diff --git a/scripts/measure/analyze_measured_dataset.py b/scripts/measure/analyze_measured_dataset.py index 5b0b89f2..6137e176 100644 --- a/scripts/measure/analyze_measured_dataset.py +++ b/scripts/measure/analyze_measured_dataset.py @@ -14,6 +14,7 @@ import numpy as np import matplotlib.pyplot as plt import time +import tqdm @hydra.main(version_base=None, config_path="../../configs", config_name="analyze_dataset") @@ -21,7 +22,8 @@ def analyze_dataset(config): folder = config.dataset_path desired_range = config.desired_range - delete_saturate = config.delete_saturated + delete_bad = config.delete_bad + start_idx = config.start_idx assert ( folder is not None @@ -30,6 +32,9 @@ def analyze_dataset(config): # get all PNG files in folder files = sorted(glob.glob(os.path.join(folder, "*.png"))) print("Found {} files".format(len(files))) + if start_idx is not None: + files = files[start_idx:] + print("Starting at file {}".format(files[0])) if config.n_files is not None: files = files[: config.n_files] print("Analyzing first {} files".format(len(files))) @@ -37,8 +42,9 @@ def analyze_dataset(config): # loop over files for maximum value max_vals = [] n_bad_files = 0 + bad_files = [] start_time = time.time() - for fn in files: + for fn in tqdm.tqdm(files): im = np.array(Image.open(fn)) max_val = im.max() max_vals.append(max_val) @@ -47,10 +53,13 @@ def analyze_dataset(config): if max_val < desired_range[0] or max_val > desired_range[1]: # print("File {} has max value {}".format(fn, max_val)) n_bad_files += 1 + bad_files.append(fn) - if delete_saturate and max_val == 255: - os.remove(fn) - print("REMOVED file {}".format(fn)) + if delete_bad: + os.remove(fn) + print("REMOVED file {}".format(fn)) + else: + print("File {} has max value {}".format(fn, max_val)) proc_time = time.time() - start_time print("Went through {} files in {:.2f} seconds".format(len(files), proc_time)) @@ -60,6 +69,17 @@ def analyze_dataset(config): ) ) + # command line input on whether to delete bad files + if not delete_bad: + response = None + while response not in ["yes", "no"]: + response = input("Delete bad files: [yes|no] : ") + if response == "yes": + for _fn in bad_files: + os.remove(_fn) + else: + print("Not deleting bad files") + # plot histogram 
output_folder = os.getcwd() output_fp = os.path.join(output_folder, "max_vals.png") diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 4de5dc4a..601a503b 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -214,9 +214,9 @@ def collect_dataset(config): ) if down: - output = resize(output[None, ...], 1 / down, interpolation=cv2.INTER_CUBIC)[ - 0 - ] + output = resize( + output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC + )[0] # print range print(f"{output_fp}, range: {output.min()} - {output.max()}") diff --git a/scripts/measure/on_device_capture.py b/scripts/measure/on_device_capture.py index f0aae9b5..00c8c470 100644 --- a/scripts/measure/on_device_capture.py +++ b/scripts/measure/on_device_capture.py @@ -73,6 +73,10 @@ def capture(config): res = config.res nbits_out = config.nbits_out + assert ( + nbits_out in sensor_dict[sensor][SensorParam.BIT_DEPTH] + ), f"nbits_out must be one of {sensor_dict[sensor][SensorParam.BIT_DEPTH]} for sensor {sensor}" + # https://www.raspberrypi.com/documentation/accessories/camera.html#hardware-specification sensor_param = sensor_dict[sensor] assert exp <= sensor_param[SensorParam.MAX_EXPOSURE] @@ -96,6 +100,7 @@ def capture(config): assert down is None + # https://www.raspberrypi.com/documentation/computers/camera_software.html#raw-image-capture jpg_fn = fn + ".jpg" fn += ".dng" pic_command = [ @@ -107,6 +112,10 @@ def capture(config): f"{int(exp * 1e6)}", "-o", f"{jpg_fn}", + # long exposure: https://www.raspberrypi.com/documentation/computers/camera_software.html#very-long-exposures + # -- setting awbgains caused issues + # "--awbgains 1,1", + # "--immediate" ] cmd = subprocess.Popen( diff --git a/scripts/measure/remote_capture.py b/scripts/measure/remote_capture.py index 411a0f4c..6777347b 100644 --- a/scripts/measure/remote_capture.py +++ b/scripts/measure/remote_capture.py @@ -34,7 +34,7 @@ import matplotlib.pyplot as plt import rawpy from lensless.hardware.utils import check_username_hostname -from lensless.hardware.sensor import SensorOptions +from lensless.hardware.sensor import SensorOptions, sensor_dict, SensorParam from lensless.utils.image import rgb2gray, print_image_info from lensless.utils.plot import plot_image, pixel_histogram from lensless.utils.io import save_image @@ -61,6 +61,13 @@ def liveview(config): source = config.capture.source plot = config.plot + assert ( + nbits_out in sensor_dict[sensor][SensorParam.BIT_DEPTH] + ), f"capture.nbits_out must be one of {sensor_dict[sensor][SensorParam.BIT_DEPTH]} for sensor {sensor}" + assert ( + config.capture.nbits in sensor_dict[sensor][SensorParam.BIT_DEPTH] + ), f"capture.nbits must be one of {sensor_dict[sensor][SensorParam.BIT_DEPTH]} for sensor {sensor}" + if config.save: if config.output is not None: # make sure output directory exists @@ -125,33 +132,40 @@ def liveview(config): # copy over DNG file remotefile = f"~/{remote_fn}.dng" - localfile = f"{fn}.dng" + localfile = os.path.join(save, f"{fn}.dng") print(f"\nCopying over picture as {localfile}...") os.system('scp "%s@%s:%s" %s' % (username, hostname, remotefile, localfile)) - raw = rawpy.imread(localfile) - - # https://letmaik.github.io/rawpy/api/rawpy.Params.html#rawpy.Params - # https://www.libraw.org/docs/API-datastruct-eng.html - if nbits_out > 8: - # only 8 or 16 bit supported by postprocess - if nbits_out != 16: - print("casting to 16 bit...") - output_bps = 16 - else: - if nbits_out != 8: - 
print("casting to 8 bit...") - output_bps = 8 - img = raw.postprocess( - adjust_maximum_thr=0, # default 0.75 - no_auto_scale=False, - gamma=(1, 1), - output_bps=output_bps, - bright=1, # default 1 - exp_shift=1, - no_auto_bright=True, - use_camera_wb=True, - use_auto_wb=False, # default is False? f both use_camera_wb and use_auto_wb are True, then use_auto_wb has priority. - ) + + img = load_image(localfile, verbose=True, bayer=bayer, nbits_out=nbits_out) + + # raw = rawpy.imread(localfile) + + # # https://letmaik.github.io/rawpy/api/rawpy.Params.html#rawpy.Params + # # https://www.libraw.org/docs/API-datastruct-eng.html + # if nbits_out > 8: + # # only 8 or 16 bit supported by postprocess + # if nbits_out != 16: + # print("casting to 16 bit...") + # output_bps = 16 + # else: + # if nbits_out != 8: + # print("casting to 8 bit...") + # output_bps = 8 + # img = raw.postprocess( + # adjust_maximum_thr=0, # default 0.75 + # no_auto_scale=False, + # # no_auto_scale=True, + # gamma=(1, 1), + # output_bps=output_bps, + # bright=1, # default 1 + # exp_shift=1, + # no_auto_bright=True, + # # use_camera_wb=True, + # # use_auto_wb=False, + # # -- gives better balance for PSF measurement + # use_camera_wb=False, + # use_auto_wb=True, # default is False? f both use_camera_wb and use_auto_wb are True, then use_auto_wb has priority. + # ) # print image properties print_image_info(img) diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index c84d5b92..80657793 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -60,9 +60,13 @@ def admm(config): else: org_data = data ax = plot_image(org_data, gamma=config["display"]["gamma"]) - ax.set_title("Original measurement") + ax.set_title("Raw data") plt.savefig(plib.Path(save) / "lensless.png") + # close axes + fig = plt.gcf() + plt.close(fig) + start_time = time.time() if not config.admm.unrolled: recon = ADMM(psf, **config.admm) diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index c9be1ee4..9365a262 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -43,9 +43,13 @@ DiffuserCamMirflickr, SimulatedFarFieldDataset, SimulatedDatasetTrainableMask, + DigiCamCelebA, + HITLDatasetTrainableMask, ) from torch.utils.data import Subset import lensless.hardware.trainable_mask +from lensless.hardware.slm import full2subpattern +from lensless.hardware.sensor import VirtualSensor from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray, is_grayscale from lensless.utils.simulation import FarFieldSimulator @@ -55,45 +59,57 @@ from lensless.utils.io import load_psf from lensless.utils.io import save_image from lensless.utils.plot import plot_image +from lensless import ADMM import matplotlib.pyplot as plt # A logger for this file log = logging.getLogger(__name__) -def simulate_dataset(config): +def simulate_dataset(config, generator=None): if config.torch_device == "cuda" and torch.cuda.is_available(): device = "cuda" else: device = "cpu" - # prepare PSF - psf_fp = os.path.join(get_original_cwd(), config.files.psf) - psf, _ = load_psf( - psf_fp, - downsample=config.files.downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - if config.files.diffusercam_psf: - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - psf = transform_BRG2RGB(torch.from_numpy(psf)) + # -- prepare PSF + psf = None + if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf": + psf_fp = 
os.path.join(get_original_cwd(), config.files.psf) + psf, _ = load_psf( + psf_fp, + downsample=config.files.downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + if config.files.diffusercam_psf: + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + psf = transform_BRG2RGB(torch.from_numpy(psf)) - # drop depth dimension - psf = psf.to(device) + # drop depth dimension + psf = psf.to(device) - # load dataset + else: + # training mask / PSF + # mask = prep_trainable_mask(config, psf, downsample=config.files.downsample) + mask = prep_trainable_mask(config, psf, downsample=config.simulation.downsample) + psf = mask.get_psf().to(device) + + # -- load dataset + pre_transform = None transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") if config.simulation.grayscale: transforms_list.append(transforms.Grayscale()) - transform = transforms.Compose(transforms_list) + if config.files.dataset == "mnist": + transform = transforms.Compose(transforms_list) train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "fashion_mnist": + transform = transforms.Compose(transforms_list) train_ds = datasets.FashionMNIST( root=data_path, train=True, download=True, transform=transform ) @@ -101,6 +117,7 @@ def simulate_dataset(config): root=data_path, train=False, download=True, transform=transform ) elif config.files.dataset == "cifar10": + transform = transforms.Compose(transforms_list) train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "CelebA": @@ -109,8 +126,22 @@ def simulate_dataset(config): assert os.path.isdir( data_path ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. 
@@ -118,9 +149,6 @@ def simulate_dataset(config):
     if config.simulation.grayscale and not is_grayscale(psf):
         psf = rgb2gray(psf)
 
-    # prepare mask
-    mask = prep_trainable_mask(config, psf, grayscale=config.simulation.grayscale)
-
     # check if gpu is available
     device_conv = config.torch_device
     if device_conv == "cuda" and torch.cuda.is_available():
@@ -135,11 +163,10 @@ def simulate_dataset(config):
         **config.simulation,
     )
 
     # create Pytorch dataset and dataloader
-    n_files = config.files.n_files
-    if n_files is not None:
-        train_ds = torch.utils.data.Subset(train_ds, np.arange(n_files))
-        test_ds = torch.utils.data.Subset(test_ds, np.arange(n_files))
+    crop = config.files.crop.copy() if config.files.crop is not None else None
     if mask is None:
         train_ds_prop = SimulatedFarFieldDataset(
             dataset=train_ds,
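``crop`` is copied out of the config before being handed to the datasets, so that later edits do not mutate the Hydra config object. Judging from how it is indexed further down (``crop["vertical"]``, ``crop["horizontal"]``), it holds row and column ranges; a hypothetical example of the structure and how such a region is cut out:

.. code:: python

    import numpy as np

    # hypothetical crop region, in (downsampled) pixel coordinates
    crop = {
        "vertical": [262, 371],    # row range kept for the loss
        "horizontal": [438, 585],  # column range kept for the loss
    }

    img = np.zeros((380, 507, 3))  # dummy (H, W, C) image
    img_cropped = img[
        crop["vertical"][0] : crop["vertical"][1],
        crop["horizontal"][0] : crop["horizontal"][1],
    ]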
@@ -147,6 +174,11 @@ def simulate_dataset(config):
             dataset_is_CHW=True,
             device_conv=device_conv,
             flip=config.simulation.flip,
+            vertical_shift=config.files.vertical_shift,
+            horizontal_shift=config.files.horizontal_shift,
+            crop=crop,
+            downsample=config.files.downsample,
+            pre_transform=pre_transform,
         )
         test_ds_prop = SimulatedFarFieldDataset(
             dataset=test_ds,
@@ -154,37 +186,128 @@ def simulate_dataset(config):
             dataset_is_CHW=True,
             device_conv=device_conv,
             flip=config.simulation.flip,
+            vertical_shift=config.files.vertical_shift,
+            horizontal_shift=config.files.horizontal_shift,
+            crop=crop,
+            downsample=config.files.downsample,
+            pre_transform=pre_transform,
         )
     else:
-        train_ds_prop = SimulatedDatasetTrainableMask(
-            dataset=train_ds,
-            mask=mask,
-            simulator=simulator,
-            dataset_is_CHW=True,
-            device_conv=device_conv,
-            flip=config.simulation.flip,
-        )
-        test_ds_prop = SimulatedDatasetTrainableMask(
-            dataset=test_ds,
-            mask=mask,
-            simulator=simulator,
-            dataset_is_CHW=True,
-            device_conv=device_conv,
-            flip=config.simulation.flip,
-        )
+        if config.measure is not None:
+
+            train_ds_prop = HITLDatasetTrainableMask(
+                rpi_username=config.measure.rpi_username,
+                rpi_hostname=config.measure.rpi_hostname,
+                celeba_root=config.files.celeba_root,
+                display_config=config.measure.display,
+                capture_config=config.measure.capture,
+                mask_center=config.trainable_mask.ap_center,
+                dataset=train_ds,
+                mask=mask,
+                simulator=simulator,
+                dataset_is_CHW=True,
+                device_conv=device_conv,
+                flip=config.simulation.flip,
+                vertical_shift=config.files.vertical_shift,
+                horizontal_shift=config.files.horizontal_shift,
+                crop=crop,
+                downsample=config.files.downsample,
+                pre_transform=pre_transform,
+            )
+
+            test_ds_prop = HITLDatasetTrainableMask(
+                rpi_username=config.measure.rpi_username,
+                rpi_hostname=config.measure.rpi_hostname,
+                celeba_root=config.files.celeba_root,
+                display_config=config.measure.display,
+                capture_config=config.measure.capture,
+                mask_center=config.trainable_mask.ap_center,
+                dataset=test_ds,
+                mask=mask,
+                simulator=simulator,
+                dataset_is_CHW=True,
+                device_conv=device_conv,
+                flip=config.simulation.flip,
+                vertical_shift=config.files.vertical_shift,
+                horizontal_shift=config.files.horizontal_shift,
+                crop=crop,
+                downsample=config.files.downsample,
+                pre_transform=pre_transform,
+            )
+
+        else:
+
+            train_ds_prop = SimulatedDatasetTrainableMask(
+                dataset=train_ds,
+                mask=mask,
+                simulator=simulator,
+                dataset_is_CHW=True,
+                device_conv=device_conv,
+                flip=config.simulation.flip,
+                vertical_shift=config.files.vertical_shift,
+                horizontal_shift=config.files.horizontal_shift,
+                crop=crop,
+                downsample=config.files.downsample,
+                pre_transform=pre_transform,
+            )
+            test_ds_prop = SimulatedDatasetTrainableMask(
+                dataset=test_ds,
+                mask=mask,
+                simulator=simulator,
+                dataset_is_CHW=True,
+                device_conv=device_conv,
+                flip=config.simulation.flip,
+                vertical_shift=config.files.vertical_shift,
+                horizontal_shift=config.files.horizontal_shift,
+                crop=crop,
+                downsample=config.files.downsample,
+                pre_transform=pre_transform,
+            )
 
     return train_ds_prop, test_ds_prop, mask
 
 
-def prep_trainable_mask(config, psf, grayscale=False):
+def prep_trainable_mask(config, psf=None, downsample=None):
     mask = None
+    color_filter = None
     if config.trainable_mask.mask_type is not None:
         mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)
 
         if config.trainable_mask.initial_value == "random":
-            initial_mask = torch.rand_like(psf)
+            if psf is not None:
+                initial_mask = torch.rand_like(psf)
+            else:
+                sensor = VirtualSensor.from_name(config.simulation.sensor, downsample=downsample)
+                resolution = sensor.resolution
+                initial_mask = torch.rand((1, *resolution, 3))
         elif config.trainable_mask.initial_value == "psf":
             initial_mask = psf.clone()
+        # if file ends with ".npy", initialize from a saved mask pattern
+        elif config.trainable_mask.initial_value.endswith("npy"):
+            pattern = np.load(os.path.join(get_original_cwd(), config.trainable_mask.initial_value))
+
+            initial_mask = full2subpattern(
+                pattern=pattern,
+                shape=config.trainable_mask.ap_shape,
+                center=config.trainable_mask.ap_center,
+                slm=config.trainable_mask.slm,
+            )
+            initial_mask = torch.from_numpy(initial_mask.astype(np.float32))
+
+            # prepare color filter if needed
+            from waveprop.devices import slm_dict
+            from waveprop.devices import SLMParam as SLMParam_wp
+
+            slm_param = slm_dict[config.trainable_mask.slm]
+            if (
+                config.trainable_mask.train_color_filter
+                and SLMParam_wp.COLOR_FILTER in slm_param.keys()
+            ):
+                color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
+                color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)
+
+                # add a small random perturbation
+                color_filter = color_filter + 0.1 * torch.rand_like(color_filter)
         else:
             raise ValueError(
                 f"Initial PSF value {config.trainable_mask.initial_value} not supported"
             )
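Initialization from a saved DigiCam pattern goes through ``full2subpattern``, which extracts the active aperture region from the full SLM pattern before it becomes the trainable variable. A minimal sketch with a made-up file name and aperture values (all numbers are placeholders, not values from the configs):

.. code:: python

    import numpy as np
    import torch
    from lensless.hardware.slm import full2subpattern

    pattern = np.load("mask_pattern.npy")  # hypothetical full-SLM pattern

    sub = full2subpattern(
        pattern=pattern,
        shape=[20, 30],   # aperture height/width in SLM pixels (placeholder)
        center=[59, 76],  # aperture center on the SLM (placeholder)
        slm="adafruit",
    )
    initial_mask = torch.from_numpy(sub.astype(np.float32))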
@@ -194,7 +317,11 @@ def prep_trainable_mask(config, psf, grayscale=False):
             initial_mask = rgb2gray(initial_mask)
 
         mask = mask_class(
-            initial_mask, optimizer="Adam", lr=config.trainable_mask.mask_lr, grayscale=grayscale
+            initial_mask,
+            optimizer="Adam",
+            downsample=downsample,
+            color_filter=color_filter,
+            **config.trainable_mask,
         )
 
     return mask
@@ -203,9 +330,19 @@ def prep_trainable_mask(config, psf, grayscale=False):
 
 @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
 def train_unrolled(config):
 
-    disp = config.display.disp
-    if disp < 0:
-        disp = None
+    # set seed
+    seed = config.seed
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    generator = torch.Generator().manual_seed(seed)
+
+    if config.start_delay is not None:
+        # wait before starting script (start_delay is in minutes)
+        delay = config.start_delay * 60
+        start_time = time.time() + delay
+        start_time = time.strftime("%H:%M:%S", time.localtime(start_time))
+        print(f"\nScript will start at {start_time}")
+        time.sleep(delay)
 
     save = config.save
     if save:
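Every source of randomness in a run (weight initialization, the train/test split, simulated noise) now hangs off the single ``config.seed`` above. The same pattern in isolation; note that bitwise-reproducible GPU runs would additionally need something like ``torch.use_deterministic_algorithms(True)``, which the script does not enable:

.. code:: python

    import numpy as np
    import torch

    seed = 0
    torch.manual_seed(seed)  # seeds torch's CPU (and CUDA) RNGs
    np.random.seed(seed)     # seeds numpy, used on the simulation side
    # dedicated generator so the dataset split does not depend on other RNG consumers
    generator = torch.Generator().manual_seed(seed)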
@@ -222,6 +359,7 @@ def train_unrolled(config):
     train_set = None
     test_set = None
     psf = None
+    crop = None
 
     if "DiffuserCam" in config.files.dataset:
         original_path = os.path.join(get_original_cwd(), config.files.dataset)
@@ -257,14 +395,118 @@ def train_unrolled(config):
 
         psf = dataset.psf
 
+    elif "celeba_adafruit" in config.files.dataset:
+
+        dataset = DigiCamCelebA(
+            data_dir=os.path.join(get_original_cwd(), config.files.dataset),
+            celeba_root=config.files.celeba_root,
+            psf_path=os.path.join(get_original_cwd(), config.files.psf),
+            downsample=config.files.downsample,
+            vertical_shift=config.files.vertical_shift,
+            horizontal_shift=config.files.horizontal_shift,
+            simulation_config=config.simulation,
+            crop=config.files.crop,
+        )
+        crop = dataset.crop
+        dataset.psf = dataset.psf.to(device)
+        log.info(f"Data shape : {dataset[0][0].shape}")
+
+        # train-test split
+        train_size = int((1 - config.files.test_size) * len(dataset))
+        test_size = len(dataset) - train_size
+        train_set, test_set = torch.utils.data.random_split(
+            dataset, [train_size, test_size], generator=generator
+        )
+        if config.files.n_files is not None:
+            train_set = Subset(train_set, np.arange(config.files.n_files))
+            test_set = Subset(test_set, np.arange(config.files.n_files))
+
+        # -- if learning mask
+        downsample = config.files.downsample * 4  # measured files are 4x downsampled
+        mask = prep_trainable_mask(config, dataset.psf, downsample=downsample)
+
+        if mask is not None:
+            # plot initial PSF
+            with torch.no_grad():
+                psf_np = mask.get_psf().detach().cpu().numpy()[0, ...]
+                if config.trainable_mask.grayscale:
+                    psf_np = psf_np[:, :, -1]
+
+            save_image(psf_np, os.path.join(save, "psf_initial.png"))
+            plot_image(psf_np, gamma=config.display.gamma)
+            plt.savefig(os.path.join(save, "psf_initial_plot.png"))
+
+            # save plot of measured PSF as well
+            psf_meas = dataset.psf.detach().cpu().numpy()[0, ...]
+            plot_image(psf_meas, gamma=config.display.gamma)
+            plt.savefig(os.path.join(save, "psf_meas_plot.png"))
+
+            with torch.no_grad():
+                psf = mask.get_psf().to(dataset.psf)
+
+        else:
+
+            psf = dataset.psf
+
+        # print info about PSF
+        log.info(f"PSF shape : {psf.shape}")
+        log.info(f"PSF min : {psf.min()}")
+        log.info(f"PSF max : {psf.max()}")
+        log.info(f"PSF dtype : {psf.dtype}")
+        log.info(f"PSF norm : {psf.norm()}")
+
     else:
-        train_set, test_set, mask = simulate_dataset(config)
+        train_set, test_set, mask = simulate_dataset(config, generator=generator)
         psf = train_set.psf
+        crop = train_set.crop
 
     assert train_set is not None
     assert psf is not None
 
+    # reconstruct a few test measurements with (non-unrolled) ADMM, as a baseline
+    with torch.no_grad():
+        if config.test_idx is not None:
+
+            log.info("Reconstructing a few images with ADMM...")
+
+            for i, _idx in enumerate(config.test_idx):
+
+                lensless, lensed = test_set[_idx]
+                recon = ADMM(psf)
+                recon.set_data(lensless.to(psf.device))
+                res = recon.apply(disp_iter=None, plot=False, n_iter=10)
+                res_np = res[0].cpu().numpy()
+                res_np = res_np / res_np.max()
+
+                lensed_np = lensed[0].cpu().numpy()
+
+                lensless_np = lensless[0].cpu().numpy()
+                save_image(lensless_np, f"lensless_raw_{_idx}.png")
+
+                # -- plot lensed and res on top of each other
+                if config.training.crop_preloss:
+
+                    res_np = res_np[
+                        crop["vertical"][0] : crop["vertical"][1],
+                        crop["horizontal"][0] : crop["horizontal"][1],
+                    ]
+                    lensed_np = lensed_np[
+                        crop["vertical"][0] : crop["vertical"][1],
+                        crop["horizontal"][0] : crop["horizontal"][1],
+                    ]
+                    if i == 0:
+                        log.info(f"Cropped shape : {res_np.shape}")
+
+                save_image(res_np, f"lensless_recon_{_idx}.png")
+                save_image(lensed_np, f"lensed_{_idx}.png")
+
+                plt.figure()
+                plt.imshow(lensed_np, alpha=0.4)
+                plt.imshow(res_np, alpha=0.7)
+                plt.savefig(f"overlay_lensed_recon_{_idx}.png")
+
     log.info(f"Train test size : {len(train_set)}")
     log.info(f"Test test size : {len(test_set)}")
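The block above reuses the library's plain ``ADMM`` reconstruction as a pre-training baseline. Stripped of the config plumbing, the core pattern looks as follows (random tensors stand in for the PSF and measurement, in the same 4D layout indexed with ``[0, ...]`` above):

.. code:: python

    import torch
    from lensless import ADMM

    psf = torch.rand(1, 380, 507, 3)       # stand-in PSF
    lensless = torch.rand(1, 380, 507, 3)  # stand-in measurement

    recon = ADMM(psf)
    recon.set_data(lensless)
    res = recon.apply(disp_iter=None, plot=False, n_iter=10)  # 10 ADMM iterations, no plots
    res_np = res[0].cpu().numpy()
    res_np = res_np / res_np.max()  # normalize before saving / overlaying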
@@ -274,14 +516,26 @@ def train_unrolled(config):
     pre_process, pre_process_name = create_process_network(
         config.reconstruction.pre_process.network,
         config.reconstruction.pre_process.depth,
+        nc=config.reconstruction.pre_process.nc,
         device=device,
     )
+    pre_proc_delay = config.reconstruction.pre_process.delay
 
     # Load post process model
     post_process, post_process_name = create_process_network(
         config.reconstruction.post_process.network,
         config.reconstruction.post_process.depth,
+        nc=config.reconstruction.post_process.nc,
         device=device,
     )
+    post_proc_delay = config.reconstruction.post_process.delay
+
+    if config.reconstruction.post_process.train_last_layer:
+        # only fine-tune the final layer of the post-processor
+        for name, param in post_process.named_parameters():
+            if "m_tail" in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
 
     # create reconstruction algorithm
     if config.reconstruction.method == "unrolled_fista":
         recon = UnrolledFISTA(
             psf,
             n_iter=config.reconstruction.unrolled_fista.n_iter,
             tk=config.reconstruction.unrolled_fista.tk,
             pad=True,
             learn_tk=config.reconstruction.unrolled_fista.learn_tk,
-            pre_process=pre_process,
-            post_process=post_process,
+            pre_process=pre_process if pre_proc_delay is None else None,
+            post_process=post_process if post_proc_delay is None else None,
+            skip_unrolled=config.reconstruction.skip_unrolled,
         ).to(device)
     elif config.reconstruction.method == "unrolled_admm":
         recon = UnrolledADMM(
             psf,
             n_iter=config.reconstruction.unrolled_admm.n_iter,
             mu1=config.reconstruction.unrolled_admm.mu1,
             mu2=config.reconstruction.unrolled_admm.mu2,
             mu3=config.reconstruction.unrolled_admm.mu3,
             tau=config.reconstruction.unrolled_admm.tau,
-            pre_process=pre_process,
-            post_process=post_process,
+            pre_process=pre_process if pre_proc_delay is None else None,
+            post_process=post_process if post_proc_delay is None else None,
+            skip_unrolled=config.reconstruction.skip_unrolled,
         ).to(device)
     else:
         raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")
@@ -315,10 +571,10 @@ def train_unrolled(config):
     if config.reconstruction.post_process.network is not None:
         algorithm_name += "_" + post_process_name
 
-    # print number of parameters
-    n_param = sum(p.numel() for p in recon.parameters())
+    # print number of trainable parameters
+    n_param = sum(p.numel() for p in recon.parameters() if p.requires_grad)
     if mask is not None:
-        n_param += sum(p.numel() for p in mask.parameters())
+        n_param += sum(p.numel() for p in mask.parameters() if p.requires_grad)
     log.info(f"Training model with {n_param} parameters")
 
     log.info(f"Setup time : {time.time() - start_time} s")
@@ -330,21 +586,30 @@ def train_unrolled(config):
         test_dataset=test_set,
         mask=mask,
         batch_size=config.training.batch_size,
+        eval_batch_size=config.training.eval_batch_size,
         loss=config.loss,
         lpips=config.lpips,
         l1_mask=config.trainable_mask.L1_strength,
-        optimizer=config.optimizer.type,
-        optimizer_lr=config.optimizer.lr,
-        slow_start=config.training.slow_start,
+        optimizer=config.optimizer,
         skip_NAN=config.training.skip_NAN,
         algorithm_name=algorithm_name,
         metric_for_best_model=config.training.metric_for_best_model,
         save_every=config.training.save_every,
         gamma=config.display.gamma,
         logger=log,
+        crop=crop if config.training.crop_preloss else None,
+        pre_process=pre_process,
+        pre_process_delay=pre_proc_delay,
+        pre_process_freeze=config.reconstruction.pre_process.freeze,
+        pre_process_unfreeze=config.reconstruction.pre_process.unfreeze,
+        post_process=post_process,
+        post_process_delay=post_proc_delay,
+        post_process_freeze=config.reconstruction.post_process.freeze,
+        post_process_unfreeze=config.reconstruction.post_process.unfreeze,
+        clip_grad=config.training.clip_grad,
     )
 
-    trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
+    trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)
 
     log.info(f"Results saved in {save}")
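Both ``train_last_layer`` and the corrected parameter count rely on PyTorch's ``requires_grad`` flag. A small self-contained illustration of the pattern (the two-layer network is made up for the example; in the script the trainable tail is DRUNet's ``m_tail``):

.. code:: python

    import torch

    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),  # "body", to be frozen
        torch.nn.Conv2d(8, 3, 3, padding=1),  # stand-in for the trainable tail
    )

    # freeze everything except the last layer
    for name, param in net.named_parameters():
        param.requires_grad = name.startswith("1.")

    # count only the parameters that will actually be updated
    n_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"Training model with {n_param} parameters")  # 8 * 3 * 3 * 3 + 3 = 219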
diff --git a/scripts/sim/digicam_psf.py b/scripts/sim/digicam_psf.py
index d0e0636b..9b665c27 100644
--- a/scripts/sim/digicam_psf.py
+++ b/scripts/sim/digicam_psf.py
@@ -8,10 +8,11 @@
 from slm_controller import slm
 from lensless.utils.io import save_image, get_dtype, load_psf
 from lensless.utils.plot import plot_image
+from lensless.utils.image import gamma_correction
 from lensless.hardware.sensor import VirtualSensor
 from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
 from waveprop.devices import slm_dict
-from PIL import Image
+from waveprop.devices import SLMParam as SLMParam_wp
 
 
 @hydra.main(version_base=None, config_path="../../configs", config_name="sim_digicam_psf")
@@ -27,7 +28,7 @@ def digicam_psf(config):
     ap_shape = np.array(config.digicam.ap_shape)
     rotate_angle = config.digicam.rotate
     slm_param = slm_dict[config.digicam.slm]
-    sensor = VirtualSensor.from_name(config.digicam.sensor)
+    sensor = VirtualSensor.from_name(config.digicam.sensor, downsample=config.digicam.downsample)
 
     # simulation parameters
     scene2mask = config.sim.scene2mask
@@ -76,23 +77,50 @@ def digicam_psf(config):
     start_time = time.time()
 
     slm_vals = pattern_sub / 255.0
 
+    # prepare color filter
+    if SLMParam_wp.COLOR_FILTER in slm_param.keys():
+        color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
+        if config.use_torch:
+            color_filter = torch.from_numpy(color_filter.copy()).to(
+                device=torch_device, dtype=dtype
+            )
+        else:
+            color_filter = color_filter.astype(dtype)
+
     if config.digicam.slm == "adafruit":
 
         # flatten color channel along rows
         slm_vals = slm_vals.reshape((-1, slm_vals.shape[-1]), order="F")
 
+    # save extracted mask values
+    np.save(os.path.join(output_folder, "mask_vals.npy"), slm_vals)
+
     if config.use_torch:
         slm_vals = torch.from_numpy(slm_vals).to(device=torch_device, dtype=dtype)
     else:
         slm_vals = slm_vals.astype(dtype)
 
+    # -- get mask
     mask = get_programmable_mask(
         vals=slm_vals,
         sensor=sensor,
         slm_param=slm_param,
         rotate=rotate_angle,
         flipud=config.sim.flipud,
+        color_filter=color_filter,
     )
 
+    # shift mask to align simulated PSF with the measured one
+    if config.digicam.vertical_shift is not None:
+        if config.use_torch:
+            mask = torch.roll(mask, config.digicam.vertical_shift, dims=1)
+        else:
+            mask = np.roll(mask, config.digicam.vertical_shift, axis=1)
+
+    if config.digicam.horizontal_shift is not None:
+        if config.use_torch:
+            mask = torch.roll(mask, config.digicam.horizontal_shift, dims=2)
+        else:
+            mask = np.roll(mask, config.digicam.horizontal_shift, axis=2)
+
     # -- plot mask
     if config.use_torch:
         mask_np = mask.cpu().detach().numpy()
@@ -127,17 +155,37 @@ def digicam_psf(config):
         else:
             print("Could not load PSF image from: ", fp_psf)
 
-    fp = os.path.join(output_folder, "psf_plot.png")
+    fp = os.path.join(output_folder, "sim_psf_plot.png")
+    fig = plt.figure(frameon=False)
+    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
+    ax.set_axis_off()
+    fig.add_axes(ax)
+    ax.imshow(psf_in_np)
+    ax.set_xticks([])
+    ax.set_yticks([])
+    plt.savefig(fp)
+
     if psf_meas is not None:
-        _, ax = plt.subplots(1, 2)
-        ax[0].imshow(psf_in_np)
-        ax[0].set_title("Simulated")
-        plot_image(psf_meas, gamma=config.digicam.gamma, normalize=True, ax=ax[1])
-        # ax[1].imshow(psf_meas)
-        ax[1].set_title("Measured")
-        plt.savefig(fp)
-    else:
-        plt.imshow(psf_in_np)
+
+        fig = plt.figure(frameon=False)
+        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
+        ax.set_axis_off()
+        fig.add_axes(ax)
+        plot_image(psf_meas, gamma=config.digicam.gamma, normalize=True, ax=ax)
+        # remove axis values
+        ax.set_xticks([])
+        ax.set_yticks([])
+        plt.savefig(os.path.join(output_folder, "meas_psf_plot.png"))
+
+        # plot overlaid
+        fp = os.path.join(output_folder, "psf_overlay.png")
+        psf_meas_norm = psf_meas[0] / np.max(psf_meas)
+        # psf_meas_norm = gamma_correction(psf_meas_norm, gamma=config.digicam.gamma)
+        psf_in_np_norm = psf_in_np / np.max(psf_in_np)
+
+        plt.figure()
+        plt.imshow(psf_in_np_norm, alpha=0.7)
+        plt.imshow(psf_meas_norm, alpha=0.7)
         plt.savefig(fp)
 
     # save PSF as png
diff --git a/setup.py b/setup.py
index 392fc7fe..df3de5fa 100644
--- a/setup.py
+++ b/setup.py
@@ -24,14 +24,14 @@
         "Programming Language :: Python :: 3",
         "Operating System :: OS Independent",
     ],
-    python_requires=">=3.8.1",
+    python_requires=">=3.8.1, <=3.11.9",
    install_requires=[
         "opencv-python>=4.5.1.48",
-        "numpy>=1.22, <=1.23.5",
+        "numpy>=1.22",
         "scipy>=1.7.0",
         "image>=1.5.33",
         "matplotlib>=3.4.2",
-        "rawpy>=0.16.0",
+        "rawpy>=0.16.0",  # not yet compatible with Python 3.12
         "paramiko>=3.2.0",
         "hydra-core",
     ],
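The new ``vertical_shift``/``horizontal_shift`` options in ``digicam_psf.py`` align the simulated PSF with the measured one by circularly shifting the simulated mask. The same operation in isolation, on a dummy mask with placeholder shift values:

.. code:: python

    import numpy as np
    import torch

    mask = torch.rand(1, 380, 507, 3)  # dummy (depth, H, W, C) mask

    vertical_shift = -15   # placeholder; set from config.digicam.vertical_shift
    horizontal_shift = 10  # placeholder; set from config.digicam.horizontal_shift

    mask = torch.roll(mask, vertical_shift, dims=1)    # shift rows
    mask = torch.roll(mask, horizontal_shift, dims=2)  # shift columns

    # numpy equivalent of the row shift
    mask_np = np.roll(mask.numpy(), vertical_shift, axis=1)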