From b27b973751ad283963b11ec14b4c94a193617e05 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 7 Aug 2024 16:44:23 +0000 Subject: [PATCH] Clean up files and update CHANGELOG. --- CHANGELOG.rst | 15 ++++- README.rst | 2 +- configs/benchmark_diffusercam_mirflickr.yaml | 51 ++++++++------- configs/benchmark_digicam_celeba.yaml | 49 +++++--------- .../benchmark_digicam_mirflickr_multi.yaml | 32 ++++----- .../benchmark_digicam_mirflickr_single.yaml | 48 ++++---------- configs/benchmark_tapecam_mirflickr.yaml | 48 +++++++------- configs/digicam_example.yaml | 2 +- configs/finetune_tape_for_diffuser.yaml | 3 +- configs/recon_digicam_mirflickr.yaml | 13 ++-- configs/sim_digicam_psf.yaml | 2 +- configs/telegram_demo_iccp2024.yaml | 1 - docs/source/reconstruction.rst | 16 +++++ lensless/__init__.py | 1 + lensless/recon/multi_wiener.py | 65 +++++++------------ lensless/recon/trainable_inversion.py | 2 +- lensless/recon/trainable_recon.py | 2 + lensless/recon/utils.py | 7 -- scripts/recon/train_learning_based.py | 11 +--- 19 files changed, 163 insertions(+), 207 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7e4f9caf..625b51cf 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,17 @@ Added - Option to pass background image to ``utils.io.load_data``. - Option to set image resolution with ``hardware.utils.display`` function. +- Auxiliary of reconstructing output from pre-processor (not working). +- Option to set focal range for MultiLensArray. +- Option to remove deadspace modelling for programmable mask. +- Compensation branch for unrolled ADMM: https://ieeexplore.ieee.org/abstract/document/9546648 +- Multi-Wiener deconvolution network: https://opg.optica.org/oe/fulltext.cfm?uri=oe-31-23-39088&id=541387 +- Option to skip pre-processor and post-processor at inference time. +- Option to set different learning rate schedules, e.g. ADAMW, exponential decay, Cosine decay with warmup. 
+- Various augmentations for training: random flipping, random rotate, and random shifts. Latter two don't work well since new regions appear that throw off PSF/LSI modeling. +- HFSimulated object for simulating lensless data from ground-truth and PSF. +- Option to set cache directory for Hugging Face datasets. +- Option to initialize training with another model. Changed ~~~~~~~ @@ -24,7 +35,9 @@ Changed Bugfix ~~~~~~ -- Nothing +- Computation of average metric in batches. +- Support for grayscale PSF for RealFFTConvolve2D. +- Calling model.eval() before inference, and model.train() before training. 1.0.7 - (2024-05-14) diff --git a/README.rst b/README.rst index 428fd07e..cb3181ec 100644 --- a/README.rst +++ b/README.rst @@ -44,7 +44,7 @@ The toolkit includes: * Camera assembly tutorials (`link `__). * Measurement scripts (`link `__). * Dataset preparation and loading tools, with `Hugging Face `__ integration (`slides `__ on uploading a dataset to Hugging Face with `this script `__). -* `Reconstruction algorithms `__ (e.g. FISTA, ADMM, unrolled algorithms, trainable inversion, pre- and post-processors). +* `Reconstruction algorithms `__ (e.g. FISTA, ADMM, unrolled algorithms, trainable inversion, multi-Wiener deconvolution network, pre- and post-processors). * `Training script `__ for learning-based reconstruction. * `Pre-trained models `__ that can be loaded from `Hugging Face `__, for example in `this script `__. * Mask `design `__ and `fabrication `__ tools. 
diff --git a/configs/benchmark_diffusercam_mirflickr.yaml b/configs/benchmark_diffusercam_mirflickr.yaml index 94242855..4968422d 100644 --- a/configs/benchmark_diffusercam_mirflickr.yaml +++ b/configs/benchmark_diffusercam_mirflickr.yaml @@ -5,7 +5,7 @@ defaults: dataset: HFDataset batchsize: 4 -device: "cuda:3" +device: "cuda:0" huggingface: repo: "bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM" @@ -20,39 +20,40 @@ huggingface: single_channel_psf: True algorithms: [ - # "ADMM", + "ADMM", - # ## - -- reconstructions trained on DiffuserCam - # "hf:diffusercam:mirflickr:U5+Unet8M", - # "hf:diffusercam:mirflickr:TrainInv+Unet8M", - # "hf:diffusercam:mirflickr:MMCN4M+Unet4M", - # "hf:diffusercam:mirflickr:MWDN8M", + ## -- reconstructions trained on DiffuserCam measured + "hf:diffusercam:mirflickr:U5+Unet8M", + "hf:diffusercam:mirflickr:TrainInv+Unet8M", + "hf:diffusercam:mirflickr:MMCN4M+Unet4M", + "hf:diffusercam:mirflickr:MWDN8M", "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", - # "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M", - # "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M", - # "hf:diffusercam:mirflickr:Unet2M+MWDN6M", - # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", - - # ## - -- reconstructions trained on other datasets/systems - # # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", - # # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", - # # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", - # # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", - # # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", - # # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", - # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M", + "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M", + "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M", + "hf:diffusercam:mirflickr:Unet2M+MWDN6M", + "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + # 
"hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_post", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_pre", + # ## -- reconstruction trained on DiffuserCam simulated + # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_tapecam", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_tapecam_post", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_tapecam_pre", - # # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam", - # # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_post", - # # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_pre", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi_post", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi_pre", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi", + + # ## -- reconstructions trained on other datasets/systems + # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", + # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", # "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", ] diff --git a/configs/benchmark_digicam_celeba.yaml b/configs/benchmark_digicam_celeba.yaml index bf5612a7..aa3e4a82 100644 --- a/configs/benchmark_digicam_celeba.yaml +++ b/configs/benchmark_digicam_celeba.yaml @@ -4,23 +4,25 @@ defaults: - _self_ -dataset: HFDataset # DiffuserCam, DigiCamCelebA, HFDataset +dataset: HFDataset batchsize: 10 -device: "cuda:2" +device: "cuda:0" algorithms: [ "ADMM", - # "hf:digicam:celeba_26k:U5+Unet8M_wave", - # "hf:digicam:celeba_26k:TrainInv+Unet8M_wave", 
- # "hf:digicam:celeba_26k:MWDN8M_wave", - # "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave", - # "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave", - # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave", + + ## -- reconstructions trained on measured data + "hf:digicam:celeba_26k:U5+Unet8M_wave", + "hf:digicam:celeba_26k:TrainInv+Unet8M_wave", + "hf:digicam:celeba_26k:MWDN8M_wave", + "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave", + "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave", + "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave", + "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave", + "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", + "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave", - # #-- reconstructions trained on other datasets/systems + # # -- reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", @@ -29,33 +31,12 @@ algorithms: [ # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", ] - -# ## -- reconstructions trained on other datasets/systems -# algorithms: [ -# "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", -# "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", -# "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", -# # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", -# ] - -# # algorithm configuration -# hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave: -# skip_post: True -# hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M: -# skip_post: True -# hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M: -# skip_post: True -# hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave: -# skip_post: True - save_idx: [0, 2, 3, 4, 9] n_iter_range: [100] # for ADMM huggingface: repo: bezzam/DigiCam-CelebA-26K - # cache_dir: /dev/shm - psf: psf_measured.png - # psf: psf_simulated_waveprop.png # 
psf_simulated_waveprop.png, psf_simulated.png, psf_measured.png + psf: psf_simulated_waveprop.png # psf_simulated_waveprop.png, psf_simulated.png, psf_measured.png split_seed: 0 test_size: 0.15 downsample: 2 diff --git a/configs/benchmark_digicam_mirflickr_multi.yaml b/configs/benchmark_digicam_mirflickr_multi.yaml index 0c9d5b8b..17ff3508 100644 --- a/configs/benchmark_digicam_mirflickr_multi.yaml +++ b/configs/benchmark_digicam_mirflickr_multi.yaml @@ -10,7 +10,6 @@ device: "cuda:0" huggingface: repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K" - cache_dir: /dev/shm psf: null # null for simulating PSF image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down @@ -22,34 +21,35 @@ huggingface: downsample: 1 algorithms: [ - # "ADMM", - ## -- reconstructions trained on other datasets/systems + "ADMM", + + ## -- reconstructions trained on measured data + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", + "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1", + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", + + # ## -- reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_aux1", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", # 
"hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", - "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips", - "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips_rotate10", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips_rotate10", ] +# # -- to only use output from unrolled +# hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_aux1: +# skip_post: True +# skip_pre: True -# -- to only use output from unrolled -hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_aux1: - skip_post: True - skip_pre: True - -# save_idx: [1, 2, 4, 5, 9] -save_idx: [24, 33, 61] +save_idx: [1, 2, 4, 5, 9, 24, 33, 61] n_iter_range: [100] # for ADMM # simulating PSF diff --git a/configs/benchmark_digicam_mirflickr_single.yaml b/configs/benchmark_digicam_mirflickr_single.yaml index 5f000ef8..7e921d4b 100644 --- a/configs/benchmark_digicam_mirflickr_single.yaml +++ b/configs/benchmark_digicam_mirflickr_single.yaml @@ -3,7 +3,6 @@ defaults: - benchmark - _self_ - dataset: HFDataset batchsize: 4 device: "cuda:0" @@ -23,16 +22,20 @@ huggingface: algorithms: [ - # "ADMM", - # "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave", - # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", - # "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:MWDN8M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave", - # "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + "ADMM", + + # -- reconstructions trained on measured data + "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave", + "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", + "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:MWDN8M_wave", + 
"hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10", # ## -- reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", @@ -42,31 +45,8 @@ algorithms: [ # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", - "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips", - "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10", - ] - -# algorithms: [ -# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=True) -# # "hf:digicam:mirflickr_single_25k:U10_wave", -# # "hf:digicam:mirflickr_single_25k:Unet8M_wave", -# # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", -# # "hf:digicam:mirflickr_single_25k:U10+Unet8M_wave", -# # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", -# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave" - -# ## -- below models need to set correct PSF simulation -# # ## - measured PSF (huggingface.psf=psf_measured.png) -# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_measured", -# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=False) -# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave_nodead", -# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=True) -# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M", -# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=False) -# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_nodead" -# ] save_idx: [1, 2, 4, 5, 9] n_iter_range: [100] # for 
ADMM diff --git a/configs/benchmark_tapecam_mirflickr.yaml b/configs/benchmark_tapecam_mirflickr.yaml index ee966c97..674e5069 100644 --- a/configs/benchmark_tapecam_mirflickr.yaml +++ b/configs/benchmark_tapecam_mirflickr.yaml @@ -23,25 +23,31 @@ huggingface: ## -- reconstructions trained with same dataset/system algorithms: [ -# "ADMM", -# "hf:tapecam:mirflickr:U5+Unet8M", -# "hf:tapecam:mirflickr:TrainInv+Unet8M", -# "hf:tapecam:mirflickr:MMCN4M+Unet4M", -# "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", -# "hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M", -# "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M", -# "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", + "ADMM", + # -- reconstructions trained on measured data + "hf:tapecam:mirflickr:U5+Unet8M", + "hf:tapecam:mirflickr:TrainInv+Unet8M", + "hf:tapecam:mirflickr:MMCN4M+Unet4M", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", + "hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M", + "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M", + "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", - - # single_channel_psf = True -# "hf:tapecam:mirflickr:MWDN8M", -# "hf:tapecam:mirflickr:Unet2M+MWDN6M", - - # # -- generalization # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", + + # # below models need `single_channel_psf = True` + # "hf:tapecam:mirflickr:MWDN8M", + # "hf:tapecam:mirflickr:Unet2M+MWDN6M", + + # ## -- reconstructions trained on other datasets/systems + # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", + # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_tapecam", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_tapecam_post", @@ 
-50,17 +56,9 @@ algorithms: [ # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_post", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_pre", - "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi_pre", - "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi", - -## -- reconstructions trained on other datasets/systems - # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", - # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", - # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi_pre", + # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi", ] save_idx: [1, 2, 4, 5, 9] -n_iter_range: [100] # for ADMM - +n_iter_range: [100] # for ADMM diff --git a/configs/digicam_example.yaml b/configs/digicam_example.yaml index 3c1a618b..6267e4ae 100644 --- a/configs/digicam_example.yaml +++ b/configs/digicam_example.yaml @@ -12,7 +12,7 @@ psf: null # if not provided, simulate with parameters below mask: fp: null # provide path, otherwise generate with seed seed: 0 - # defaults to configuration use for this dataset: https://huggingface.co/datasets/bezzam/DigiCam-Mirflickr-SingleMask-25K + # defaults to configuration used for this dataset: https://huggingface.co/datasets/bezzam/DigiCam-Mirflickr-SingleMask-25K # ie this config: configs/collect_mirflickr_singlemask.yaml shape: [54, 26] center: [57, 77] diff --git a/configs/finetune_tape_for_diffuser.yaml b/configs/finetune_tape_for_diffuser.yaml index a564ba6b..eec397db 100644 --- a/configs/finetune_tape_for_diffuser.yaml +++ b/configs/finetune_tape_for_diffuser.yaml @@ -15,7 +15,7 @@ files: single_channel_psf: True downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution flipud: True - flip_lensed: True # 
for measure data + flip_lensed: True # for measured data hf_simulated: True @@ -26,7 +26,6 @@ training: reconstruction: init: hf:tapecam:mirflickr:Unet4M+U5+Unet4M - # init: hf:diffusercam:mirflickr:Unet4M+U5+Unet4M optimizer: lr: 1e-5 diff --git a/configs/recon_digicam_mirflickr.yaml b/configs/recon_digicam_mirflickr.yaml index 3a321105..38143f0c 100644 --- a/configs/recon_digicam_mirflickr.yaml +++ b/configs/recon_digicam_mirflickr.yaml @@ -20,25 +20,24 @@ alignment: # - Learned reconstructions: see "lensless/recon/model_dict.py" -### dataset: mirflickr_single_25k +# --- dataset: mirflickr_single_25k # model: TrainInv+Unet8M_wave # model: MMCN4M+Unet4M_wave -model: MWDN8M_wave +# model: MWDN8M_wave # model: U5+Unet8M_wave # model: Unet4M+TrainInv+Unet4M_wave # model: Unet2M+MMCN+Unet2M_wave # model: Unet4M+U5+Unet4M_wave # model: Unet4M+U10+Unet4M_wave - -# ## dataset: mirflickr_multi_25k -# model: Unet4M+U5+Unet4M_wave +# --- dataset: mirflickr_multi_25k +model: Unet4M+U5+Unet4M_wave # # -- for ADMM with fixed parameters # model: admm # n_iter: 100 -device: cuda:2 -n_trials: 1 # more if you want to get average inference time +device: cuda:0 +n_trials: 1 # to get average inference time idx: 1 # index from test set to reconstruct save: True \ No newline at end of file diff --git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml index e97e90e4..3547d180 100644 --- a/configs/sim_digicam_psf.yaml +++ b/configs/sim_digicam_psf.yaml @@ -8,7 +8,7 @@ dtype: float32 torch_device: cuda requires_grad: False -# if repo not provided, check for local file +# if repo not provided, check for local file at `digicam.pattern` huggingface_repo: bezzam/DigiCam-CelebA-26K huggingface_mask_pattern: mask_pattern.npy huggingface_psf: psf_measured.png diff --git a/configs/telegram_demo_iccp2024.yaml b/configs/telegram_demo_iccp2024.yaml index f5cec434..5abba477 100644 --- a/configs/telegram_demo_iccp2024.yaml +++ b/configs/telegram_demo_iccp2024.yaml @@ -4,7 +4,6 @@ defaults: 
# for Telegram token: null -whitelist: [360264201] setup_fp: voronoi_setup.jpeg # usernames and IP address diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index 4674f327..c46fe538 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -90,6 +90,22 @@ :special-members: __init__ :show-inheritance: + Trainable Inversion + ~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.TrainableInversion + :members: forward + :special-members: __init__ + :show-inheritance: + + Multi-Wiener Deconvolution Network + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + .. autoclass:: lensless.MultiWiener + :members: forward + :special-members: __init__ + :show-inheritance: + Reconstruction Utilities ------------------------ diff --git a/lensless/__init__.py b/lensless/__init__.py index 748e503d..70990774 100644 --- a/lensless/__init__.py +++ b/lensless/__init__.py @@ -29,6 +29,7 @@ from .recon.unrolled_admm import UnrolledADMM from .recon.unrolled_fista import UnrolledFISTA from .recon.trainable_inversion import TrainableInversion + from .recon.multi_wiener import MultiWiener except Exception: pass diff --git a/lensless/recon/multi_wiener.py b/lensless/recon/multi_wiener.py index ac99e557..cb53d5a8 100644 --- a/lensless/recon/multi_wiener.py +++ b/lensless/recon/multi_wiener.py @@ -1,7 +1,10 @@ -""" -Adapted from source code by KC Lee. 
- -""" +# ############################################################################# +# multi_wiener.py +# =============== +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# Kyung Chul Lee +# ############################################################################# import torch @@ -20,12 +23,6 @@ def __init__(self, in_channels, out_channels, mid_channels=None): if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( - # nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), padding=1), - # nn.BatchNorm2d(mid_channels), - # nn.ReLU(inplace=True), - # nn.Conv2d(mid_channels, out_channels, kernel_size=(3, 3), padding=1), - # nn.BatchNorm2d(out_channels), - # nn.ReLU(inplace=True) nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), @@ -55,7 +52,7 @@ class Up(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - # or us ConvTranspose2d? https://github.com/milesial/Pytorch-UNet/blob/21d7850f2af30a9695bbeea75f3136aa538cfc4a/unet/unet_parts.py#L53 + # or use ConvTranspose2d? https://github.com/milesial/Pytorch-UNet/blob/21d7850f2af30a9695bbeea75f3136aa538cfc4a/unet/unet_parts.py#L53 self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) @@ -99,12 +96,25 @@ def __init__( skip_pre=False, ): """ + Constructor for Multi-Wiener Deconvolution Network (MWDN) as proposed in: + https://opg.optica.org/oe/fulltext.cfm?uri=oe-31-23-39088&id=541387 + Parameters ---------- in_channels : int Number of input channels. RGB or grayscale, i.e. 3 and 1 respectively. out_channels : int Number of output channels. RGB or grayscale, i.e. 3 and 1 respectively. + psf : :py:class:`~torch.Tensor` + Point spread function (PSF) that models forward propagation. + psf_channels : int + Number of channels in the PSF. Default is 1. + nc : list + Number of channels in the network. 
Default is [64, 128, 256, 512, 512]. + pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + Pre-processor applies before MWDN. Default is None. + skip_pre : bool + Skip pre-processing. Default is False. """ assert in_channels == 1 or in_channels == 3, "in_channels must be 1 or 3" @@ -119,10 +129,6 @@ def __init__( self.inc = DoubleConv(in_channels, nc[0]) self.down_layers = nn.ModuleList([Down(nc[i], nc[i + 1]) for i in range(len(nc) - 1)]) - # self.down1 = Down(64, 128) - # self.down2 = Down(128, 256) - # self.down3 = Down(256, 512) - # self.down4 = Down(512, 512) self.up_layers = [] n_prev = nc[-1] @@ -132,10 +138,6 @@ def __init__( self.up_layers.append(Up(n_in, n_out)) n_prev = n_out self.up_layers = nn.ModuleList(self.up_layers) - # self.up1 = Up(1024, 256) - # self.up2 = Up(512, 128) - # self.up3 = Up(256, 64) - # self.up4 = Up(128, 64) self.outc = OutConv(nc[0], out_channels) self.delta = nn.Parameter(torch.tensor(np.ones(5) * 0.01, dtype=torch.float32)) @@ -145,9 +147,6 @@ def __init__( self.inc0 = DoubleConv(psf_channels, nc[0]) self.psf_down = nn.ModuleList([Down(nc[i], nc[i + 1]) for i in range(len(nc) - 2)]) - # self.down11 = Down(64, 128) - # self.down22 = Down(128, 256) - # self.down33 = Down(256, 512) # padding H and W to next multiple of 8 img_shape = psf.shape[-3:-1] @@ -170,10 +169,11 @@ def __init__( def _prepare_process_block(self, process): """ Method for preparing the pre or post process block. + Parameters ---------- - process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional - Pre or post process block to prepare. + process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + Pre or post process block to prepare. """ if isinstance(process, torch.nn.Module): # If the post_process is a torch module, we assume it is a DruNet like network. 
@@ -233,35 +233,18 @@ def forward(self, batch, psfs=None, **kwargs): x_inter = [self.inc(batch)] for i in range(len(self.down_layers)): x_inter.append(self.down_layers[i](x_inter[-1])) - # x1 = self.inc(x) - # x2 = self.down1(x1) - # x3 = self.down2(x2) - # x4 = self.down3(x3) - # x5 = self.down4(x4) # -- multi-scale Wiener filtering psf_multi = [self.inc0(self.w * psf)] for i in range(len(self.psf_down)): psf_multi.append(self.psf_down[i](psf_multi[-1])) - # psf1 = self.inc0(self.w * psf) - # psf2 = self.down11(psf1) - # psf3 = self.down22(psf2) - # psf4 = self.down33(psf3) for i in range(len(psf_multi)): x_inter[i] = WieNer(x_inter[i], psf_multi[i], self.delta[i]) - # x4 = WieNer(x4, psf4, self.delta[3]) - # x3 = WieNer(x3, psf3, self.delta[2]) - # x2 = WieNer(x2, psf2, self.delta[1]) - # x1 = WieNer(x1, psf1, self.delta[0]) # upsample batch = self.up_layers[0](x_inter[-1], x_inter[-2]) for i in range(len(self.up_layers) - 1): batch = self.up_layers[i + 1](batch, x_inter[-i - 3]) - # x = self.up1(x5, x4) - # x = self.up2(x, x3) - # x = self.up3(x, x2) - # x = self.up4(x, x1) batch = self.outc(batch) # back to original shape diff --git a/lensless/recon/trainable_inversion.py b/lensless/recon/trainable_inversion.py index e9cf6df5..a4e82cf0 100644 --- a/lensless/recon/trainable_inversion.py +++ b/lensless/recon/trainable_inversion.py @@ -1,6 +1,6 @@ # ############################################################################# # trainable_inversion.py -# ================= +# ====================== # Authors : # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index abb4db53..77474730 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -89,6 +89,8 @@ def __init__( compensation : list, optional Number of channels for each intermediate output in compensation layer, as in "Robust 
Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021). Post-processor must be defined if compensation provided. + compensation_residual : bool, optional + Whether to use residual connection in compensation layer. """ assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" super(TrainableReconstructionAlgorithm, self).__init__( diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 9b62153a..ca25a704 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -133,13 +133,6 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3, residual=True, pa # -- not mentinoed in paper, but added more max-pooling for later residual layers, otherwise dimensions don't match self.residual_layers = nn.ModuleList( [ - # double_cnn_max_pool( - # in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1) - # ) - # B.sequential( - # B.ResBlock(in_channel, in_channel, bias=False, mode="CRC", padding=padding, stride=stride), - # B.downsample_maxpool(in_channel, nc[i], bias=False, mode=str(max_pool ** (i + 1)), padding=padding, stride=stride) - # ) if residual ResBlock( in_channel, in_channel, diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 093839aa..72128fc0 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -364,21 +364,12 @@ def train_learned(config): lensless = shift_with_pad(lensless, shift, axis=(1, 2)) lensed = shift_with_pad(lensed, shift, axis=(1, 2)) psf_recon = shift_with_pad(psf_recon, shift, axis=(1, 2)) - # lensless = torch.roll(lensless, tuple(shift), (1, 2)) - # lensed = torch.roll(lensed, tuple(shift), (1, 2)) - # psf_recon = torch.roll(psf_recon, tuple(shift), (1, 2)) shift = tuple(shift) if config.files.random_rotate or config.files.random_shifts: save_image(psf_recon[0].cpu().numpy(), f"psf_{_idx}.png") - # lensless[:, -1] = 0 - # lensless[:, :, -1] = 0 - # fake_shift = 
np.ones(2).astype(int) * 1 - # lensless = shift_with_pad(lensless, tuple(fake_shift), axis=(1, 2)) - # lensless = shift_with_pad(lensless, tuple(-1 * fake_shift), axis=(1, 2)) - recon = ADMM(psf_recon) recon.set_data(lensless.to(psf_recon.device)) @@ -658,7 +649,7 @@ def train_learned(config): use_wandb=True if config.wandb_project is not None else False, n_epoch=config.training.epoch, random_rotate=config.files.random_rotate, - # random_shift=config.files.random_shifts, + random_shift=config.files.random_shifts, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)