diff --git a/.github/workflows/python_pycsou.yml b/.github/workflows/python_pycsou.yml index d5cf1e91..61f89fa5 100644 --- a/.github/workflows/python_pycsou.yml +++ b/.github/workflows/python_pycsou.yml @@ -59,5 +59,5 @@ jobs: pip install -U pytest pip install -r recon_requirements.txt pip install -r mask_requirements.txt - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e pytest \ No newline at end of file diff --git a/.gitignore b/.gitignore index 3430ec87..ca9b0dfa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ data/* models/* *.png *.jpg +*.npy configs/telegram_demo_secret.yaml diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 847fa0f7..a0492898 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,29 @@ Unreleased Added ~~~~~ +- Trainable reconstruction can return intermediate outputs (between pre- and post-processing). +- Auto-download for DRUNet model. +- ``utils.dataset.DiffuserCamMirflickr`` helper class for Mirflickr dataset. + +Changed +~~~~~~~ + +- Better logic for saving best model. Based on desired metric rather than last epoch, and intermediate models can be saved. +- Optional normalization in ``utils.io.load_image``. + +Bugfix +~~~~~~ + +- Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS. +- Fix bad train/test split for DiffuserCamMirflickr in unrolled training. + + +1.0.5 - (2023-09-05) +-------------------- + +Added +~~~~~ + - Sensor module. - Single-script and Telegram demo. - Link and citation for JOSS. @@ -22,8 +45,15 @@ Added - Script for measuring arbitrary dataset (from Raspberry Pi). - Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fix postprocessing can be used. - Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``. +- Unified interface for dataset. See ``utils.dataset.DualDataset``. +- New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedFarFieldDataset``. +- New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``. - Support for unrolled loading and inference in the script ``admm.py``. -- Tikhonov reconstruction for coded aperture measurements (MLS / MURA). +- Tikhonov reconstruction for coded aperture measurements (MLS / MURA): numpy and Pytorch support. +- New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch. +- New ``TrainableMask`` and ``TrainablePSF`` class to train/fine-tune a mask from a dataset. +- New ``SimulatedDatasetTrainableMask`` class to train/fine-tune a mask for measurement. +- PyTorch support for ``lensless.utils.io.rgb2gray``. Changed diff --git a/README.rst b/README.rst index 23066e12..5a88de08 100644 --- a/README.rst +++ b/README.rst @@ -60,7 +60,8 @@ Python 3.9, as some Python library versions may not be available with earlier versions of Python. Moreover, its `end-of-life `__ is Oct 2025. -**Local machine** +*Local machine setup* +===================== Below are commands that worked for our configuration (Ubuntu 21.04), but there are certainly other ways to download a repository and @@ -83,16 +84,20 @@ install the library locally. 
# (optional) try reconstruction on local machine python scripts/recon/admm.py + # (optional) try reconstruction on local machine with GPU + python scripts/recon/admm.py -cn pytorch -Note (25-04-2023): for using reconstruction method based on Pycsou ``lensless.apgd.APGD``, -V2 has to be installed: + +Note (25-04-2023): for using the :py:class:`~lensless.recon.apgd.APGD` reconstruction method based on Pycsou +(now `Pyxu `__), a specific commit has +to be installed (as there was no release at the time of implementation): .. code:: bash - pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev + pip install git+https://github.com/matthieumeo/pycsou.git@38e9929c29509d350a7ff12c514e2880fdc99d6e If PyTorch is installed, you will need to be sure to have PyTorch 2.0 or higher, -as Pycsou V2 is not compatible with earlier versions of PyTorch. Moreover, +as Pycsou is not compatible with earlier versions of PyTorch. Moreover, Pycsou requires Python within `[3.9, 3.11) `__. @@ -102,7 +107,8 @@ Moreover, ``numba`` (requirement for Pycsou V2) may require an older version of pip install numpy==1.23.5 -**Raspberry Pi** +*Raspberry Pi setup* +==================== After `flashing your Raspberry Pi with SSH enabled `__, you need to set it up for `passwordless access `__. diff --git a/configs/adafruit.yaml b/configs/adafruit.yaml new file mode 100644 index 00000000..5399204e --- /dev/null +++ b/configs/adafruit.yaml @@ -0,0 +1,9 @@ +defaults: + - demo + - _self_ + +plot: True + +capture: + exp: 5.0 + awb_gains: [1, 1] diff --git a/configs/apgd_l1.yaml b/configs/apgd_l1.yaml index 5d0621cc..006b72aa 100644 --- a/configs/apgd_l1.yaml +++ b/configs/apgd_l1.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: # Proximal prior / regularization: nonneg, l1, null prox_penalty: l1 diff --git a/configs/apgd_l2.yaml b/configs/apgd_l2.yaml index 65a16405..0b50ba73 100644 --- a/configs/apgd_l2.yaml +++ b/configs/apgd_l2.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: diff_penalty: l2 diff_lambda: 0.0001 diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 5cd05d6c..1771ff8a 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -8,11 +8,13 @@ input: # File path for raw data data: data/raw_data/thumbs_up_rgb.png dtype: float32 + original: null # ground truth image torch: False torch_device: 'cpu' preprocess: + normalize: True # Downsampling factor along X and Y downsample: 4 # Image shape (height, width) for reconstruction. @@ -27,6 +29,7 @@ preprocess: single_psf: False # Whether to perform construction in grayscale. 
gray: False + bg_pix: [5, 25] # null to skip display: diff --git a/configs/demo.yaml b/configs/demo.yaml index c769d1a2..ddc0c528 100644 --- a/configs/demo.yaml +++ b/configs/demo.yaml @@ -26,6 +26,8 @@ display: psf: null # all black screen black: False + # all white screen + white: False capture: gamma: null # for visualization diff --git a/configs/diffusercam_mirflickr_single_admm.yaml b/configs/diffusercam_mirflickr_single_admm.yaml new file mode 100644 index 00000000..5055bf6f --- /dev/null +++ b/configs/diffusercam_mirflickr_single_admm.yaml @@ -0,0 +1,43 @@ +# python scripts/recon/admm.py -cn diffusercam_mirflickr_single_admm +defaults: + - defaults_recon + - _self_ + + +display: + gamma: null + +input: + # File path for recorded PSF + psf: data/DiffuserCam_Test/psf.tiff + # File path for raw data + data: data/DiffuserCam_Test/diffuser/im5.npy + dtype: float32 + original: data/DiffuserCam_Test/lensed/im5.npy + +torch: True +torch_device: 'cuda:0' + +preprocess: + downsample: 8 # factor for PSF, which is 4x resolution of image + normalize: False + +admm: + # Number of iterations + n_iter: 20 + # Hyperparameters + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + #Loading unrolled model + unrolled: True + # checkpoint_fp: pretrained_models/Pre_Unrolled_Post-DiffuserCam/model_weights.pt + checkpoint_fp: outputs/2023-09-11/22-06-49/recon.pt # pre unet and post drunet + pre_process_model: + network : UnetRes # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process_model: + network : DruNet # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + \ No newline at end of file diff --git a/configs/digicam.yaml b/configs/digicam.yaml new file mode 100644 index 00000000..d84b3a89 --- /dev/null +++ b/configs/digicam.yaml @@ -0,0 +1,23 @@ +rpi: + username: null + hostname: null + +device: adafruit +virtual: False +save: True + +# pattern: data/psf/adafruit_random_pattern_20230719.npy +pattern: random +# pattern: rect +# pattern: circ +min_val: 0 # if pattern: random, min for range(0,1) +rect_shape: [20, 10] # if pattern: rect +radius: 20 # if pattern: circ +center: [0, 0] + + +aperture: + center: [59,76] + shape: [19,26] + +z: 4 # mask to sensor distance diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml new file mode 100644 index 00000000..af55e03a --- /dev/null +++ b/configs/fine-tune_PSF.yaml @@ -0,0 +1,18 @@ +# python scripts/recon/train_unrolled.py -cn fine-tune_PSF +defaults: + - train_unrolledADMM + - _self_ + +#Trainable Mask +trainable_mask: + mask_type: TrainablePSF #Null or "TrainablePSF" + initial_value: psf + mask_lr: 1e-3 + L1_strength: 1.0 #False or float + +#Training +training: + save_every: 5 + +display: + gamma: 2.2 diff --git a/configs/mask_sim_single.yaml b/configs/mask_sim_single.yaml index f793d302..0d20efa5 100644 --- a/configs/mask_sim_single.yaml +++ b/configs/mask_sim_single.yaml @@ -8,6 +8,7 @@ files: #original: data/original/mnist_3.png save: True +use_torch: False simulation: object_height: 0.3 diff --git a/configs/recon_dataset.yaml b/configs/recon_dataset.yaml new file mode 100644 index 00000000..f474aed5 --- /dev/null +++ b/configs/recon_dataset.yaml @@ -0,0 +1,47 @@ +# python scripts/recon/dataset.py +defaults: + - defaults_recon + - _self_ + +torch: True +torch_device: 'cuda:0' + +input: + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + # 
https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + raw_data: data/celeba_adafruit_random_2mm_20230720_1K + +n_files: 25 # null for all files +output_folder: data/celeba_adafruit_recon + +# extraction region of interest +roi: null # top, left, bottom, right +# -- values for `data/celeba_adafruit_random_2mm_20230720_1K` +# roi: [10, 300, 560, 705] # down 4 +# roi: [6, 200, 373, 470] # down 6 +# roi: [5, 150, 280, 352] # down 8 + +preprocess: + flip: True + downsample: 6 + + # to have different data shape than PSF + data_dim: null + # data_dim: [48, 64] # down 64 + # data_dim: [506, 676] # down 6 + +display: + disp: -1 + plot: False + +algo: admm # "admm", "apgd", "null" to just copy over (resized) raw data + +apgd: + n_jobs: 1 # run in parallel as algo is slow + max_iter: 500 + +admm: + n_iter: 10 + +save: False \ No newline at end of file diff --git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml new file mode 100644 index 00000000..216455cd --- /dev/null +++ b/configs/sim_digicam_psf.yaml @@ -0,0 +1,38 @@ +# python scripts/sim/digicam_psf.py +hydra: + job: + chdir: True # change to output folder + +use_torch: False +dtype: float32 +torch_device: cuda +requires_grad: True + +digicam: + + slm: adafruit + sensor: rpi_hq + + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + pattern: data/psf/adafruit_random_pattern_20230719.npy + ap_center: [59, 76] + ap_shape: [19, 26] + rotate: -0.8 # rotation in degrees + + # optionally provide measured PSF for side-by-side comparison + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + gamma: 2 # for plotting measured + +sim: + + # whether SLM is fliped + flipud: True + + # in practice found waveprop=True or False doesn't make difference + waveprop: False + + # below are ignored if waveprop=False + scene2mask: 0.03 # [m] + mask2sensor: 0.002 # [m] + \ No newline at end of file diff --git a/configs/train_celeba_classifier.yaml b/configs/train_celeba_classifier.yaml new file mode 100644 index 00000000..11a391c8 --- /dev/null +++ b/configs/train_celeba_classifier.yaml @@ -0,0 +1,38 @@ +hydra: + job: + chdir: True # change to output folder + +seed: 0 + +data: + # -- path to original CelebA (parent directory) + original: /scratch/bezzam + + output_dir: "./vit-celeba" # basename for model output + + # -- raw + # https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + measured: data/celeba_adafruit_random_2mm_20230720_10K + raw: True + + # # -- reconstructed + # # run `python scripts/recon/dataset.py` to get a reconstructed dataset + # measured: null + # raw: False + + n_files: null # null to use all in measured_folder + test_size: 0.15 + attr: Male # "Male", "Smiling", etc + +augmentation: + + random_resize_crop: False + horizontal_flip: True # cannot be used with raw measurement! 
+ +train: + + prev: null # path to previously trained model + n_epochs: 4 + dropout: 0.1 + batch_size: 16 + learning_rate: 2e-4 diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml new file mode 100644 index 00000000..f4d6ba98 --- /dev/null +++ b/configs/train_pre-post-processing.yaml @@ -0,0 +1,24 @@ +# python scripts/recon/train_unrolled.py -cn train_pre-post-processing +defaults: + - train_unrolledADMM + - _self_ + +display: + disp: 400 + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: DruNet + depth: 4 + +training: + epoch: 50 + slow_start: 0.01 + +loss: l2 +lpips: 1.0 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml new file mode 100644 index 00000000..b4eef0ed --- /dev/null +++ b/configs/train_psf_from_scratch.yaml @@ -0,0 +1,18 @@ +# python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: /scratch/bezzam + downsample: 8 + +#Trainable Mask +trainable_mask: + mask_type: TrainablePSF #Null or "TrainablePSF" + initial_value: "random" + +simulation: + grayscale: False diff --git a/configs/unrolled_recon.yaml b/configs/train_unrolledADMM.yaml similarity index 57% rename from configs/unrolled_recon.yaml rename to configs/train_unrolledADMM.yaml index 621e3cfa..3871be0d 100644 --- a/configs/unrolled_recon.yaml +++ b/configs/train_unrolledADMM.yaml @@ -1,33 +1,24 @@ +# python scripts/recon/train_unrolled.py hydra: job: chdir: True # change to output folder -#Reconstruction algorithm -input: - # File path for recorded PSF - psf: data/DiffuserCam_Mirflickr_200_3011302021_11h43_seed11/psf.tiff - dtype: float32 +# Dataset +files: + dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + psf: data/psf.tiff + diffusercam_psf: True + n_files: null # null to use all for both train/test + downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution torch: True torch_device: 'cuda' -preprocess: - # Image shape (height, width) for reconstruction. - shape: null - # Whether image is raw bayer data. - bayer: False - blue_gain: null - red_gain: null - # Same PSF for all channels (sum) or unique PSF for RGB. - single_psf: False - # Whether to perform construction in grayscale. - gray: False - - display: # How many iterations to wait for intermediate plot. # Set to negative value for no intermediate plots. - disp: 400 + disp: 500 # Whether to plot results. plot: True # Gamma factor for plotting. @@ -48,24 +39,27 @@ reconstruction: learn_tk: True unrolled_admm: # Number of iterations - n_iter: 5 + n_iter: 20 # Hyperparameters mu1: 1e-4 mu2: 1e-4 mu3: 1e-4 tau: 2e-4 pre_process: - network : UnetRes # UnetRes or DruNet or null + network : null # UnetRes or DruNet or null depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet post_process: - network : UnetRes # UnetRes or DruNet or null + network : null # UnetRes or DruNet or null depth : 2 # depth of each up/downsampling layer. 
Ignore if network is DruNet -# Train Dataset - -files: - dataset: "DiffuserCam" # "mnist", "fashion_mnist", "cifar10", "CelebA", "DiffuserCam" - n_files: null # null to use all +#Trainable Mask +trainable_mask: + mask_type: Null #Null or "TrainablePSF" + # "random" (with shape of config.files.psf) or "psf" (using config.files.psf) + initial_value: psf + grayscale: False + mask_lr: 1e-3 + L1_strength: 1.0 #False or float target: "object_plane" # "original" or "object_plane" or "label" @@ -73,23 +67,25 @@ target: "object_plane" # "original" or "object_plane" or "label" simulation: grayscale: False # random variations - object_height: 0.6 # range for random height or scalar + object_height: 0.04 # range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) random_shift: False random_vflip: 0.5 random_hflip: 0.5 random_rotate: False # these distance parameters are typically fixed for a given PSF - # for tape_rgb psf # for DiffuserCam psf - scene2mask: 40e-2 # scene2mask: 10e-2 - mask2sensor: 4e-3 # mask2sensor: 9e-3 + # for DiffuserCam psf # for tape_rgb psf + scene2mask: 10e-2 # scene2mask: 40e-2 + mask2sensor: 9e-3 # mask2sensor: 4e-3 # see waveprop.devices sensor: "rpi_hq" - snr_db: 40 + snr_db: 10 # simulate different sensor resolution # output_dim: [24, 32] # [H, W] or null # Downsampling for PSF downsample: 8 # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability max_val: 255 #Training @@ -97,6 +93,8 @@ simulation: training: batch_size: 8 epoch: 50 + metric_for_best_model: null # e.g. LPIPS_Vgg, null does test loss + save_every: null #In case of instable training skip_NAN: True slow_start: False #float how much to reduce lr for first epoch @@ -104,7 +102,7 @@ training: optimizer: type: Adam - lr: 1e-6 + lr: 1e-4 loss: 'l2' # set lpips to false to deactivate. 
Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) diff --git a/digicam_requirements.txt b/digicam_requirements.txt new file mode 100644 index 00000000..fbbcaa30 --- /dev/null +++ b/digicam_requirements.txt @@ -0,0 +1 @@ +slm_controller @ git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index f8146fac..3eb1e15f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,4 +4,6 @@ docutils==0.16 # >0.17 doesn't render bullets numpy>=1.22 # so that default dtype are correctly rendered torch>=1.10 torchvision>=0.15.2 -torchmetrics>=0.11.4 \ No newline at end of file +torchmetrics>=0.11.4 +pyFFS>=2.2.3 # for waveprop +waveprop>=0.0.7 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index fc01f75b..60ee9e96 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,21 +21,25 @@ "torchmetrics.image", "scipy.ndimage", "pycsou.abc", + "pycsou.operator", "pycsou.operator.func", + "pycsou.operator.linop", "pycsou.opt.solver", "pycsou.opt.stop", "pycsou.runtime", "pycsou.util", "pycsou.util.ptype", "PIL", + "PIL.Image", "tqdm", "paramiko", "paramiko.ssh_exception", "perlin_numpy", - "waveprop", - "waveprop.fresnel", - "waveprop.rs", - "waveprop.noise", + "hydra", + "hydra.utils", + "scipy.special", + "matplotlib.cm", + "pyffs", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/docs/source/data.rst b/docs/source/data.rst index 768b46fb..50b323c6 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -39,6 +39,20 @@ use the correct PSF file for the data you're using! input.psf=data/psf/tape_rgb.png +Measured CelebA Dataset +----------------------- + +You can download 1K measurements of the CelebA dataset done with +our lensless camera and a random pattern on the Adafruit LCD +`here (1.2 GB) `__, +and a dataset with 10K measurements +`here (13.1 GB) `__. +They both correspond to the PSF which can be found `here `__ +(``adafruit_random_2mm_20231907.png`` which is the PSF of +``adafruit_random_pattern_20230719.npy`` measured with a mask to sensor +distance of 2 mm). + + DiffuserCam Lensless Mirflickr Dataset (DLMD) --------------------------------------------- diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst new file mode 100644 index 00000000..ad21defb --- /dev/null +++ b/docs/source/dataset.rst @@ -0,0 +1,61 @@ +Dataset objects (for training and testing) +========================================== + +The software below provides functionality (with PyTorch) to load +datasets for training and testing. + +.. automodule:: lensless.utils.dataset + +Abstract base class +------------------- + +All dataset objects derive from this abstract base class, which +lays out the notion of a dataset with pairs of images: one image +is lensed (simulated or measured), and the other is lensless (simulated +or measured). + +.. autoclass:: lensless.utils.dataset.DualDataset + :members: _get_images_pair + :special-members: __init__, __len__ + + +Simulated dataset objects +------------------------- + +These dataset objects can be used for training and testing with +simulated data. The main assumption is that the imaging system +is linear shift-invariant (LSI), and that the lensless image is +the result of a convolution of the lensed image with a point-spread +function (PSF). Check out `this Medium post `__ +for more details on the simulation procedure. 
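For intuition, this LSI forward model amounts to convolving the scene with the PSF. Below is a minimal NumPy sketch of that idea (illustrative only: the datasets in this module rely on ``waveprop``'s simulator, which additionally handles resizing, noise, and quantization, and the function name here is hypothetical).

.. code:: python

    import numpy as np
    from scipy.signal import fftconvolve

    def simulate_lensless(lensed, psf):
        # convolve each color channel of the (H, W, C) lensed image with the
        # corresponding channel of the (H, W, C) PSF; "same" keeps the sensor shape
        meas = np.stack(
            [
                fftconvolve(lensed[..., c], psf[..., c], mode="same")
                for c in range(lensed.shape[-1])
            ],
            axis=-1,
        )
        meas = np.clip(meas, a_min=0, a_max=None)
        return meas / meas.max()  # normalize to [0, 1]
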
+ +With simulated data, we can avoid the hassle of collecting a large +amount of data. However, it's important to note that the LSI assumption +can sometimes be too idealistic, in particular for large angles. + +Nevertheless, simulating data is the only option of learning the +mask / PSF. + +.. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask + :members: + :special-members: __init__ + + +Measured dataset objects +------------------------ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ diff --git a/docs/source/evaluation.rst b/docs/source/evaluation.rst index f3f381d2..0f2c9d93 100644 --- a/docs/source/evaluation.rst +++ b/docs/source/evaluation.rst @@ -23,8 +23,4 @@ .. automodule:: lensless.eval.benchmark - .. autoclass:: lensless.eval.benchmark.ParallelDataset - :members: - :special-members: __init__ - .. autofunction:: lensless.eval.benchmark.benchmark \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 94c236e6..3fba13d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,7 @@ Contents simulation data + dataset .. toctree:: :hidden: diff --git a/docs/source/mask.rst b/docs/source/mask.rst index 0ad8327e..036d0f12 100644 --- a/docs/source/mask.rst +++ b/docs/source/mask.rst @@ -29,5 +29,15 @@ ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: lensless.hardware.mask.FresnelZoneAperture + :members: + :special-members: __init__ + + Trainable Mask + ~~~~~~~~~~~~~~~~~~~~~ + .. autoclass:: lensless.hardware.trainable_mask.TrainableMask + :members: + :special-members: __init__ + + .. autoclass:: lensless.hardware.trainable_mask.TrainablePSF :members: :special-members: __init__ \ No newline at end of file diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index 27434c40..e5b927f4 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -55,7 +55,7 @@ Accelerated Proximal Gradient Descent (APGD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - .. autoclass:: lensless.APGD + .. autoclass:: lensless.recon.apgd.APGD :special-members: __init__ @@ -88,4 +88,22 @@ .. autoclass:: lensless.UnrolledADMM :members: batch_call :special-members: __init__ - :show-inheritance: \ No newline at end of file + :show-inheritance: + + + Reconstruction Utilities + ------------------------ + + .. autoclass:: lensless.recon.utils.Trainer + :members: + :special-members: __init__ + + .. autofunction:: lensless.recon.utils.load_drunet + + .. autofunction:: lensless.recon.utils.apply_denoiser + + .. autofunction:: lensless.recon.utils.get_drunet_function + + .. autofunction:: lensless.recon.utils.measure_gradient + + .. autofunction:: lensless.recon.utils.create_process_network diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index d5ecaa34..12739ad2 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -16,6 +16,18 @@ library is used with the following simulation steps: PyTorch support is available to speed up simulation on GPU, and to create Dataset and DataLoader objects for training and testing! 
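For example, a dataset returning ``(lensless, lensed)`` pairs can be wrapped in a standard PyTorch ``DataLoader``. A minimal sketch, with random tensors standing in for an actual ``DualDataset`` subclass and using the ``[batch, depth, width, height, color]`` convention noted in the changelog:

.. code:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # stand-in for a DualDataset subclass returning (lensless, lensed) pairs
    lensless = torch.rand(16, 1, 120, 160, 3)
    lensed = torch.rand(16, 1, 120, 160, 3)
    dataset = TensorDataset(lensless, lensed)

    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for lensless_batch, lensed_batch in loader:
        print(lensless_batch.shape)  # torch.Size([8, 1, 120, 160, 3])
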
+FarFieldSimulator +------------------ + +A wrapper around `waveprop.simulation.FarFieldSimulator `__ +is implemented as :py:class:`lensless.utils.simulation.FarFieldSimulator`. +It handles the conversion between the HWC and CHW dimension orderings so that the convention of LenslessPiCam can be maintained (namely HWC). + +.. autoclass:: lensless.utils.simulation.FarFieldSimulator + :members: + :special-members: __init__ + + Simulating 3D data ------------------ diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index b4aa6b79..885766f3 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -7,18 +7,12 @@ # ############################################################################# -import glob -import os -from lensless.utils.io import load_psf -from lensless.utils.image import resize -import numpy as np +from lensless.utils.dataset import DiffuserCamTestDataset from tqdm import tqdm -from lensless.utils.io import load_image - try: import torch - from torch.utils.data import Dataset, DataLoader + from torch.utils.data import DataLoader from torch.nn import MSELoss, L1Loss from torchmetrics import StructuralSimilarityIndexMeasure from torchmetrics.image import lpip, psnr @@ -28,207 +22,6 @@ ) -class ParallelDataset(Dataset): - """ - Dataset consisting of lensless and corresponding lensed image. - - It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images. - - """ - - def __init__( - self, - root_dir, - n_files=False, - background=None, - downsample=4, - flip=False, - transform_lensless=None, - transform_lensed=None, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - **kwargs, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - - root_dir : str - Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images. - n_files : int or None, optional - Metrics will be computed only on the first ``n_files`` images. If None, all images are used, by default False - background : :py:class:`~torch.Tensor` or None, optional - If not ``None``, background is removed from lensless images, by default ``None``. - downsample : int, optional - Downsample factor of the lensless images, by default 4. - flip : bool, optional - If ``True``, lensless images are flipped, by default ``False``. - transform_lensless : PyTorch Transform or None, optional - Transform to apply to the lensless images, by default None - transform_lensed : PyTorch Transform or None, optional - Transform to apply to the lensed images, by default None - lensless_fn : str, optional - Name of the folder containing the lensless images, by default "diffuser". - lensed_fn : str, optional - Name of the folder containing the lensed images, by default "lensed". - image_ext : str, optional - Extension of the images, by default "npy". - """ - - self.root_dir = root_dir - self.lensless_dir = os.path.join(root_dir, lensless_fn) - self.lensed_dir = os.path.join(root_dir, lensed_fn) - self.image_ext = image_ext.lower() - - files = glob.glob(os.path.join(self.lensless_dir, "*." 
+ image_ext)) - if n_files: - files = files[:n_files] - self.files = [os.path.basename(fn) for fn in files] - - if len(self.files) == 0: - raise FileNotFoundError( - f"No files found in {self.lensless_dir} with extension {image_ext}" - ) - - self.background = background - self.downsample = downsample / 4 - self.flip = flip - self.transform_lensless = transform_lensless - self.transform_lensed = transform_lensed - - def __len__(self): - return len(self.files) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - - if self.image_ext == "npy": - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = np.load(lensless_fp) - lensed = np.load(lensed_fp) - else: - # more standard image formats: png, jpg, tiff, etc. - lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) - lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) - lensless = load_image(lensless_fp) - lensed = load_image(lensed_fp) - - # convert to float - if lensless.dtype == np.uint8: - lensless = lensless.astype(np.float32) / 255 - lensed = lensed.astype(np.float32) / 255 - else: - # 16 bit - lensless = lensless.astype(np.float32) / 65535 - lensed = lensed.astype(np.float32) / 65535 - - if self.downsample != 1.0: - lensless = resize(lensless, factor=1 / self.downsample) - lensed = resize(lensed, factor=1 / self.downsample) - - lensless = torch.from_numpy(lensless) - lensed = torch.from_numpy(lensed) - - # If [H, W, C] -> [D, H, W, C] - if len(lensless.shape) == 3: - lensless = lensless.unsqueeze(0) - if len(lensed.shape) == 3: - lensed = lensed.unsqueeze(0) - - if self.background is not None: - lensless = lensless - self.background - - # flip image x and y if needed - if self.flip: - lensless = torch.rot90(lensless, dims=(-3, -2)) - lensed = torch.rot90(lensed, dims=(-3, -2)) - if self.transform_lensless: - lensless = self.transform_lensless(lensless) - - if self.transform_lensed: - lensed = self.transform_lensed(lensed) - - return lensless, lensed - - -class DiffuserCamTestDataset(ParallelDataset): - """ - Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. - """ - - def __init__( - self, - data_dir="data", - n_files=200, - downsample=8, - ): - """ - Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of DiffuserCam - Lensless Mirflickr Dataset (DLMD). - - Parameters - ---------- - data_dir : str, optional - The path to the folder containing the DiffuserCam_Test dataset, by default "data" - n_files : int, optional - Number of image pair to load in the dataset , by default 200 - downsample : int, optional - Downsample factor of the lensless images, by default 8 - """ - # download dataset if necessary - main_dir = data_dir - data_dir = os.path.join(data_dir, "DiffuserCam_Test") - if not os.path.isdir(data_dir): - print("No dataset found for benchmarking.") - try: - from torchvision.datasets.utils import download_and_extract_archive - except ImportError: - exit() - msg = "Do you want to download the sample dataset (3.5GB)?" 
- - # default to yes if no input is given - valid = input("%s (Y/n) " % msg).lower() != "n" - if valid: - url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" - filename = "DiffuserCam_Test.zip" - download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) - - psf_fp = os.path.join(data_dir, "psf.tiff") - psf, background = load_psf( - psf_fp, - downsample=downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - - # transform from BGR to RGB - from torchvision import transforms - - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - self.psf = transform_BRG2RGB(torch.from_numpy(psf)) - - super().__init__( - data_dir, - n_files, - background, - downsample, - flip=False, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - lensless_fn="diffuser", - lensed_fn="lensed", - image_ext="npy", - ) - - def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): """ Compute multiple metrics for a reconstruction algorithm. @@ -300,7 +93,20 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): if metric == "ReconstructionError": metrics_values[metric] += model.reconstruction_error().cpu().item() else: - metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + if "LPIPS" in metric: + if prediction.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric] += ( + metrics[metric]( + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + else: + metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() model.reset() diff --git a/lensless/hardware/aperture.py b/lensless/hardware/aperture.py new file mode 100644 index 00000000..37e8e37b --- /dev/null +++ b/lensless/hardware/aperture.py @@ -0,0 +1,379 @@ +# ############################################################################# +# aperture.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +from enum import Enum + +import numpy as np +from lensless.utils.image import rgb2gray + + +class ApertureOptions(Enum): + RECT = "rect" + SQUARE = "square" + LINE = "line" + CIRC = "circ" + + @staticmethod + def values(): + return [shape.value for shape in ApertureOptions] + + +class Aperture: + def __init__(self, shape, pixel_pitch): + """ + Class for defining VirtualSLM. + + :param shape: (height, width) in number of cell. + :type shape: tuple(int) + :param pixel_pitch: Pixel pitch (height, width) in meters. 
+ :type pixel_pitch: tuple(float) + """ + assert np.all(shape) > 0 + assert np.all(pixel_pitch) > 0 + self._shape = shape + self._pixel_pitch = pixel_pitch + self._values = np.zeros((3,) + shape, dtype=np.uint8) + + @property + def size(self): + return np.prod(self._shape) + + @property + def shape(self): + return self._shape + + @property + def pixel_pitch(self): + return self._pixel_pitch + + @property + def center(self): + return np.array([self.height / 2, self.width / 2]) + + @property + def dim(self): + return np.array(self._shape) * np.array(self._pixel_pitch) + + @property + def height(self): + return self.dim[0] + + @property + def width(self): + return self.dim[1] + + @property + def values(self): + return self._values + + @property + def grayscale_values(self): + return rgb2gray(self._values) + + def at(self, physical_coord, value=None): + """ + Get/set values of VirtualSLM at physical coordinate in meters. + + :param physical_coord: Physical coordinates to get/set VirtualSLM values. + :type physical_coord: int, float, slice tuples + :param value: [Optional] values to set, otherwise return values at + specified coordinates. Defaults to None + :type value: int, float, :py:class:`~numpy.ndarray`, optional + :return: If getter is used, values at those coordinates + :rtype: ndarray + """ + idx = prepare_index_vals(physical_coord, self._pixel_pitch) + if value is None: + # getter + return self._values[idx] + else: + # setter + self._values[idx] = value + + def __getitem__(self, key): + return self._values[key] + + def __setitem__(self, key, value): + self._values[key] = value + + def plot(self, show_tick_labels=False): + """ + Plot Aperture. + + :param show_tick_labels: Whether to show cell number along x- and y-axis, defaults to False + :type show_tick_labels: bool, optional + :return: The axes of the plot. + :rtype: Axes + """ + # prepare mask data for `imshow`, expects the input data array size to be (width, height, 3) + Z = self.values.transpose(1, 2, 0) + + # plot + import matplotlib.pyplot as plt + + _, ax = plt.subplots() + extent = [ + -0.5 * self._pixel_pitch[1], + (self._shape[1] - 0.5) * self._pixel_pitch[1], + (self._shape[0] - 0.5) * self._pixel_pitch[0], + -0.5 * self._pixel_pitch[0], + ] + ax.imshow(Z, extent=extent) + ax.grid(which="major", axis="both", linestyle="-", color="0.5", linewidth=0.25) + + x_ticks = np.arange(-0.5, self._shape[1], 1) * self._pixel_pitch[1] + ax.set_xticks(x_ticks) + if show_tick_labels: + x_tick_labels = (np.arange(-0.5, self._shape[1], 1) + 0.5).astype(int) + else: + x_tick_labels = [None] * len(x_ticks) + ax.set_xticklabels(x_tick_labels) + + y_ticks = np.arange(-0.5, self._shape[0], 1) * self._pixel_pitch[0] + ax.set_yticks(y_ticks) + if show_tick_labels: + y_tick_labels = (np.arange(-0.5, self._shape[0], 1) + 0.5).astype(int) + else: + y_tick_labels = [None] * len(y_ticks) + ax.set_yticklabels(y_tick_labels) + return ax + + +def rect_aperture(slm_shape, pixel_pitch, apert_dim, center=None): + """ + Create and return VirtualSLM object with rectangular aperture of desired dimensions. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param apert_dim: Dimensions (height, width) of aperture in meters. + :type apert_dim: tuple(float) + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. 
+ Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :raises ValueError: If aperture does extend over the boarder of the SLM. + :return: VirtualSLM object with cells programmed to desired rectangular aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert np.all(apert_dim) > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + apert_dim = np.array(apert_dim) + top_left = center - apert_dim / 2 + bottom_right = top_left + apert_dim + if ( + top_left[0] < 0 + or top_left[1] < 0 + or bottom_right[0] >= slm.dim[0] + or bottom_right[1] >= slm.dim[1] + ): + raise ValueError( + f"Aperture ({top_left[0]}:{bottom_right[0]}, " + f"{top_left[1]}:{bottom_right[1]}) extends past valid " + f"VirtualSLM dimensions {slm.dim}" + ) + slm.at( + physical_coord=np.s_[top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]], + value=255, + ) + + return slm + + +def line_aperture(slm_shape, pixel_pitch, length, vertical=True, center=None): + """ + Create and return VirtualSLM object with a line aperture of desired length. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param length: Length of aperture in meters. + :type length: float + :param vertical: Orient line vertically, defaults to True. + :type vertical: bool, optional + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired line aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # call `create_rect_aperture` + apert_dim = (length, pixel_pitch[1]) if vertical else (pixel_pitch[0], length) + return rect_aperture(slm_shape, pixel_pitch, apert_dim, center) + + +def square_aperture(slm_shape, pixel_pitch, side, center=None): + """ + Create and return VirtualSLM object with a square aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param side: Side length of square aperture in meters. + :type side: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired square aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + return rect_aperture(slm_shape, pixel_pitch, (side, side), center) + + +def circ_aperture(slm_shape, pixel_pitch, radius, center=None): + """ + Create and return VirtualSLM object with a circle aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. 
+ :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param radius: Radius of aperture in meters. + :type radius: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired circle aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert radius > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + i, j = np.meshgrid( + np.arange(slm.dim[0], step=slm.pixel_pitch[0]), + np.arange(slm.dim[1], step=slm.pixel_pitch[1]), + sparse=True, + indexing="ij", + ) + x2 = (i - center[0]) ** 2 + y2 = (j - center[1]) ** 2 + slm[:] = 255 * (x2 + y2 < radius**2) + return slm + + +def _cell_slice(_slice, cell_m): + """ + Convert slice indexing in meters to slice indexing in cells. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param _slice: Original slice in meters. + :type _slice: slice + :param cell_m: Dimension of cell in meters. + :type cell_m: float + :return: The new slice + :rtype: slice + """ + start = None if _slice.start is None else _m_to_cell_idx(_slice.start, cell_m) + stop = _m_to_cell_idx(_slice.stop, cell_m) if _slice.stop is not None else None + step = _m_to_cell_idx(_slice.step, cell_m) if _slice.step is not None else None + return slice(start, stop, step) + + +def _m_to_cell_idx(val, cell_m): + """ + Convert location to cell index. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param val: Location in meters. + :type val: float + :param cell_m: Dimension of cell in meters. + :type cell_m: float + :return: The cell index. + :rtype: int + """ + return int(val / cell_m) + + +def prepare_index_vals(key, pixel_pitch): + """ + Convert indexing object in meters to indexing object in cell indices. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param key: Indexing operation in meters. + :type key: int, float, slice, or list + :param pixel_pitch: Pixel pitch (height, width) in meters. + :type pixel_pitch: tuple(float) + :raises ValueError: If the key is of the wrong type. + :raises NotImplementedError: If key is of size 3, individual channels can't + be indexed. + :raises ValueError: If the key has the wrong dimensions. + :return: The new indexing object. + :rtype: tuple[slice, int] | tuple[slice, slice] | tuple[slice, ...] 
+ """ + if isinstance(key, (float, int)): + idx = slice(None), _m_to_cell_idx(key, pixel_pitch[0]) + + elif isinstance(key, slice): + idx = slice(None), _cell_slice(key, pixel_pitch[0]) + + elif len(key) == 2: + idx = [slice(None)] + for k, _slice in enumerate(key): + + if isinstance(_slice, slice): + idx.append(_cell_slice(_slice, pixel_pitch[k])) + + elif isinstance(_slice, (float, int)): + idx.append(_m_to_cell_idx(_slice, pixel_pitch[k])) + + else: + raise ValueError("Invalid key.") + idx = tuple(idx) + + elif len(key) == 3: + raise NotImplementedError("Cannot index individual channels.") + + else: + raise ValueError("Invalid key.") + return idx diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index 9cde01b2..f9597bf5 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -32,7 +32,13 @@ from waveprop.noise import add_shot_noise from lensless.hardware.sensor import VirtualSensor from lensless.utils.image import resize -from lensless.utils.image import rgb2bayer, bayer2rgb + +try: + import torch + + torch_available = True +except ImportError: + torch_available = False class Mask(abc.ABC): @@ -296,12 +302,23 @@ def simulate(self, obj, snr_db=20): # Convolve image n_channels = obj.shape[-1] - meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) + + if torch_available and isinstance(obj, torch.Tensor): + P = torch.from_numpy(P).float() + Q = torch.from_numpy(Q).float() + meas = torch.dstack( + [torch.linalg.multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)] + ).float() + else: + meas = np.dstack([multi_dot([P, obj[:, :, c], Q.T]) for c in range(n_channels)]) # Add noise if snr_db is not None: meas = add_shot_noise(meas, snr_db=snr_db) + if torch_available and isinstance(obj, torch.Tensor): + meas = meas.to(obj) + return meas diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index 36a5adda..08a00a05 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -170,6 +170,8 @@ def __init__( else: self.size = self.pixel_size * self.resolution + self.pitch = self.size / self.resolution + self.image_shape = self.resolution if self.color: self.image_shape = np.append(self.image_shape, 3) @@ -298,6 +300,7 @@ def downsample(self, factor): assert factor > 1, "Downsample factor must be greater than 1." 
self.pixel_size = self.pixel_size * factor + self.pitch = self.pitch * factor self.resolution = (self.resolution / factor).astype(int) self.size = self.pixel_size * self.resolution self.image_shape = self.resolution diff --git a/lensless/hardware/slm.py b/lensless/hardware/slm.py new file mode 100644 index 00000000..572ae4a7 --- /dev/null +++ b/lensless/hardware/slm.py @@ -0,0 +1,298 @@ +# ############################################################################# +# slm.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import os +import numpy as np +from lensless.hardware.utils import check_username_hostname +from lensless.utils.io import get_dtype, get_ctypes +from slm_controller.hardware import SLMParam, slm_devices +from waveprop.spherical import spherical_prop +from waveprop.color import ColorSystem +from waveprop.rs import angular_spectrum +from waveprop.slm import get_centers, get_color_filter +from waveprop.devices import SLMParam as SLMParam_wp +from scipy.ndimage import rotate as rotate_func + + +try: + import torch + from torchvision import transforms + + torch_available = True +except ImportError: + torch_available = False + + +SUPPORTED_DEVICE = { + "adafruit": "~/slm-controller/examples/adafruit_slm.py", + "nokia": "~/slm-controller/examples/nokia_slm.py", + "holoeye": "~/slm-controller/examples/holoeye_slm.py", +} + + +def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): + """ + Set LCD pattern on Raspberry Pi. + + This function assumes that `slm-controller `_ + is installed on the Raspberry Pi. + + Parameters + ---------- + pattern : :py:class:`~numpy.ndarray` + Pattern to set on programmable mask. + device : str + Name of device to set pattern on. Supported devices: "adafruit", "nokia", "holoeye". + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. + + """ + + client = check_username_hostname(rpi_username, rpi_hostname) + + # get path to python executable on Raspberry Pi + rpi_python = "~/slm-controller/slm_controller_env/bin/python" + assert ( + device in SUPPORTED_DEVICE.keys() + ), f"Device {device} not supported. Supported devices: {SUPPORTED_DEVICE.keys()}" + script = SUPPORTED_DEVICE[device] + + # check that pattern is correct shape + expected_shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not slm_devices[device][SLMParam.MONOCHROME]: + expected_shape = (3, *expected_shape) + assert ( + pattern.shape == expected_shape + ), f"Pattern shape {pattern.shape} does not match expected shape {expected_shape}" + + # save pattern + pattern_fn = "tmp_pattern.npy" + local_path = os.path.join(os.getcwd(), pattern_fn) + np.save(local_path, pattern) + + # copy pattern to Raspberry Pi + remote_path = f"~/{pattern_fn}" + print(f"PUTTING {local_path} to {remote_path}") + + os.system('scp %s "%s@%s:%s" ' % (local_path, rpi_username, rpi_hostname, remote_path)) + # # -- not sure why this doesn't work... 
permission denied + # sftp = client.open_sftp() + # sftp.put(local_path, remote_path, confirm=True) + # sftp.close() + + # run script on Raspberry Pi to set mask pattern + command = f"{rpi_python} {script} --file_path {remote_path}" + print(f"COMMAND : {command}") + _stdin, _stdout, _stderr = client.exec_command(command) + print(_stdout.read().decode()) + client.close() + + os.remove(local_path) + + +def get_programmable_mask( + vals, + sensor, + slm_param, + rotate=None, + flipud=False, + nbits=8, +): + """ + Get mask as a numpy or torch array. Return same type. + + Parameters + ---------- + vals : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Values to set on programmable mask. + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. + slm_param : dict + SLM parameters. + rotate : float, optional + Rotation angle in degrees. + flipud : bool, optional + Flip mask vertically. + nbits : int, optional + Number of bits/levels to quantize mask to. + + """ + + use_torch = False + if torch_available: + use_torch = isinstance(vals, torch.Tensor) + dtype = vals.dtype + + # -- prepare SLM mask + n_active_slm_pixels = vals.shape + n_color_filter = np.prod(slm_param["color_filter"].shape[:2]) + pixel_pitch = slm_param[SLMParam_wp.PITCH] + centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch) + + if SLMParam_wp.COLOR_FILTER in slm_param.keys(): + color_filter = slm_param[SLMParam_wp.COLOR_FILTER] + if flipud: + color_filter = np.flipud(color_filter) + + cf = get_color_filter( + slm_dim=n_active_slm_pixels, + color_filter=color_filter, + shift=0, + flat=True, + ) + + else: + + # monochrome + cf = None + + d1 = sensor.pitch + _height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int) + + if use_torch: + mask = torch.zeros((n_color_filter,) + tuple(sensor.resolution)).to(vals) + slm_vals_flat = vals.flatten() + else: + mask = np.zeros((n_color_filter,) + tuple(sensor.resolution), dtype=dtype) + slm_vals_flat = vals.reshape(-1) + + for i, _center in enumerate(centers): + + _center_pixel = (_center / d1 + sensor.resolution / 2).astype(int) + _center_top_left_pixel = ( + _center_pixel[0] - np.floor(_height_pixel / 2).astype(int), + _center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int), + ) + + if cf is not None: + _rect = np.tile(cf[i][:, np.newaxis, np.newaxis], (1, _height_pixel, _width_pixel)) + else: + _rect = np.ones((1, _height_pixel, _width_pixel)) + + if use_torch: + _rect = torch.tensor(_rect).to(slm_vals_flat) + + mask[ + :, + _center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel, + _center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel, + ] = ( + slm_vals_flat[i] * _rect + ) + + # quantize mask + if use_torch: + mask = mask / torch.max(mask) + mask = torch.round(mask * (2**nbits - 1)) / (2**nbits - 1) + else: + mask = mask / np.max(mask) + mask = np.round(mask * (2**nbits - 1)) / (2**nbits - 1) + + # rotate + if rotate is not None: + if use_torch: + mask = transforms.functional.rotate(mask, angle=rotate) + else: + mask = rotate_func(mask, axes=(2, 1), angle=rotate, reshape=False) + + return mask + + +def get_intensity_psf( + mask, + waveprop=False, + sensor=None, + scene2mask=None, + mask2sensor=None, + color_system=None, +): + """ + Get intensity PSF from mask pattern. Return same type of data. + + Parameters + ---------- + mask : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Mask pattern. + waveprop : bool, optional + Whether to use wave propagation to compute PSF. 
Default is False, + namely to return squared intensity of mask pattern as the PSF (i.e., + no wave propagation and just shadow of pattern). + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. Not used if ``waveprop=False``. + scene2mask : float + Distance from scene to mask. Not used if ``waveprop=False``. + mask2sensor : float + Distance from mask to sensor. Not used if ``waveprop=False``. + color_system : :py:class:`~waveprop.color.ColorSystem`, optional + Color system. Not used if ``waveprop=False``. + + """ + if color_system is None: + color_system = ColorSystem.rgb() + + is_torch = False + device = None + if torch_available: + is_torch = isinstance(mask, torch.Tensor) + device = mask.device + + dtype = mask.dtype + ctype, _ = get_ctypes(dtype, is_torch) + + if is_torch: + psfs = torch.zeros(mask.shape, dtype=ctype, device=device) + else: + psfs = np.zeros(mask.shape, dtype=ctype) + + if waveprop: + + assert sensor is not None, "sensor must be specified" + assert scene2mask is not None, "scene2mask must be specified" + assert mask2sensor is not None, "mask2sensor must be specified" + + assert ( + len(color_system.wv) == mask.shape[0] + ), "Number of wavelengths must match number of color channels" + + # spherical wavefronts to mask + spherical_wavefront = spherical_prop( + in_shape=sensor.resolution, + d1=sensor.pitch, + wv=color_system.wv, + dz=scene2mask, + return_psf=True, + is_torch=True, + device=device, + dtype=dtype, + ) + u_in = spherical_wavefront * mask + + # free space propagation to sensor + for i, wv in enumerate(color_system.wv): + psfs[i], _, _ = angular_spectrum( + u_in=u_in[i], + wv=wv, + d1=sensor.pitch, + dz=mask2sensor, + dtype=dtype, + device=device, + ) + + else: + + psfs = mask + + # -- intensity PSF + if is_torch: + psf_in = torch.square(torch.abs(psfs)) + else: + psf_in = np.square(np.abs(psfs)) + + return psf_in diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py new file mode 100644 index 00000000..9bc70bc8 --- /dev/null +++ b/lensless/hardware/trainable_mask.py @@ -0,0 +1,102 @@ +# ############################################################################# +# trainable_mask.py +# ================== +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + +import abc +import torch +from lensless.utils.image import is_grayscale + + +class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): + """ + Abstract class for defining trainable masks. + + The following abstract methods need to be defined: + + - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`: returning the PSF of the mask. + - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.project`: projecting the mask parameters to a valid space (should be a subspace of [0,1]). + + """ + + def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): + """ + Base constructor. Derived constructor may define new state variables + + Parameters + ---------- + initial_mask : :py:class:`~torch.Tensor` + Initial mask parameters. 
+ optimizer : str, optional + Optimizer to use for updating the mask parameters, by default "Adam" + lr : float, optional + Learning rate for the mask parameters, by default 1e-3 + """ + super().__init__() + self._mask = torch.nn.Parameter(initial_mask) + self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) + self._counter = 0 + + @abc.abstractmethod + def get_psf(self): + """ + Abstract method for getting the PSF of the mask. Should be fully compatible with pytorch autograd. + + Returns + ------- + :py:class:`~torch.Tensor` + The PSF of the mask. + """ + raise NotImplementedError + + def update_mask(self): + """Update the mask parameters. Acoording to externaly updated gradiants.""" + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self.project() + self._counter += 1 + + @abc.abstractmethod + def project(self): + """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1]).""" + raise NotImplementedError + + +class TrainablePSF(TrainableMask): + """ + Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically. + + Parameters + ---------- + grayscale : bool, optional + Whether mask should be returned as grayscale when calling :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`. + Otherwise PSF will be returned as RGB. By default False. + """ + + def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs): + super().__init__(initial_mask, optimizer, lr, **kwargs) + assert ( + len(initial_mask.shape) == 4 + ), "Mask must be of shape (depth, height, width, channels)" + self.grayscale = grayscale + self._is_grayscale = is_grayscale(initial_mask) + if grayscale: + assert self._is_grayscale, "Mask must be grayscale" + + def get_psf(self): + if self._is_grayscale: + if self.grayscale: + # simulation in grayscale + return self._mask + else: + # replicate to 3 channels + return self._mask.expand(-1, -1, -1, 3) + else: + # assume RGB + return self._mask + + def project(self): + self._mask.data = torch.clamp(self._mask, 0, 1) diff --git a/lensless/hardware/utils.py b/lensless/hardware/utils.py index a0c0d573..97b384f6 100644 --- a/lensless/hardware/utils.py +++ b/lensless/hardware/utils.py @@ -2,6 +2,7 @@ import os import socket import subprocess +import time import paramiko from paramiko.ssh_exception import AuthenticationException, BadHostKeyException, SSHException @@ -65,7 +66,7 @@ def check_username_hostname(username, hostname, timeout=10): except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: raise ValueError(f"Could not connect to {username}@{hostname}\n{e}") - return username, hostname + return client def get_distro(): @@ -92,3 +93,62 @@ def get_distro(): # Just major version shown, replace it with the full version RELEASE_DATA["VERSION"] = " ".join([DEBIAN_VERSION] + version_split[1:]) return f"{RELEASE_DATA['NAME']} {RELEASE_DATA['VERSION']}" + + +def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): + """ + Set the distance between the mask and sensor. + + This functions assumes that `StepperDriver `_ is installed. + is downloaded on the Raspberry Pi. + + Parameters + ---------- + distance : float + Distance in mm. Positive values move the mask away from the sensor. + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. 
+ """ + + MAX_DISTANCE = 16 # mm + timeout = 5 + + client = check_username_hostname(rpi_username, rpi_hostname) + assert motor in [0, 1] + assert distance >= 0, "Distance must be non-negative" + assert distance < MAX_DISTANCE, f"Distance must be less than {MAX_DISTANCE} mm" + + # assumes that `StepperDriver` is in home directory + rpi_python = "python3" + script = "~/StepperDriver/Python/serial_motors.py" + + # reset to zero + print("Resetting to zero distance...") + try: + command = f"{rpi_python} {script} {motor} REV {MAX_DISTANCE * 1000}" + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + except socket.timeout: # socket.timeout + pass + + client.close() + time.sleep(5) # TODO reduce this time + client = check_username_hostname(rpi_username, rpi_hostname) + + # set to desired distance + if distance != 0: + print(f"Setting distance to {distance} mm...") + distance_um = distance * 1000 + if distance_um >= 0: + command = f"{rpi_python} {script} {motor} FWD {distance_um}" + else: + command = f"{rpi_python} {script} {motor} REV {-1 * distance_um}" + print(f"COMMAND : {command}") + try: + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + print(_stdout.read().decode()) + except socket.timeout: # socket.timeout + client.close() + + client.close() diff --git a/lensless/recon/apgd.py b/lensless/recon/apgd.py index 2ae5a69d..327c32de 100644 --- a/lensless/recon/apgd.py +++ b/lensless/recon/apgd.py @@ -11,7 +11,9 @@ import inspect import numpy as np from typing import Optional +from lensless.utils.image import resize from lensless.recon.rfft_convolve import RealFFTConvolve2D as Convolver +import cv2 import pycsou.abc as pyca import pycsou.operator.func as func @@ -20,6 +22,7 @@ import pycsou.runtime as pycrt import pycsou.util as pycu import pycsou.util.ptype as pyct +import pycsou.operator.linop as pycl class APGDPriors: @@ -95,6 +98,7 @@ def __init__( rel_error=None, lipschitz_tight=True, lipschitz_tol=1.0, + img_shape=None, **kwargs ): """ @@ -132,27 +136,52 @@ def __init__( Whether to use tight Lipschitz constant or not. Default is True. lipschitz_tol : float, optional Tolerance to compute Lipschitz constant. Default is 1. + img_shape : tuple, optional + Shape of measurement (H, W, C). If None, assume shape of PSF. 
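For example, a measurement smaller than the PSF can be passed so that the forward operator combines convolution and downsampling. A sketch (file paths and iteration count are placeholders):

.. code:: python

    import numpy as np
    from lensless.recon.apgd import APGD

    psf = np.load("psf.npy")            # (depth, H, W, C), placeholder path
    data = np.load("measurement.npy")   # (H', W', C) with H' <= H and W' <= W, placeholder path

    recon = APGD(psf=psf, max_iter=300, img_shape=data.shape)
    recon.set_data(data)
    img_est = recon.apply(plot=False)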
""" assert isinstance(psf, np.ndarray), "PSF must be a numpy array" - # PSF and data are the same size / shape self._original_shape = psf.shape - self._original_size = psf.size - self._apgd = None - self._gen = None - - super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) self._stop_crit = stop.MaxIter(max_iter) if rel_error is not None: self._stop_crit = self._stop_crit | stop.RelError(eps=rel_error) self._disp = disp - # Convolution operator + # Convolution (and optional downsampling) operator + if img_shape is not None: + + meas_shape = np.array(img_shape[:2]) + rec_shape = np.array(self._original_shape[1:3]) + assert np.all(meas_shape <= rec_shape), "Image shape must be smaller than PSF shape" + self.downsampling_factor = np.round(rec_shape / meas_shape).astype(int) + + # new PSF shape, must be integer multiple of image shape + new_shape = tuple(np.array(meas_shape) * self.downsampling_factor) + (psf.shape[-1],) + psf_re = resize(psf.copy(), shape=new_shape, interpolation=cv2.INTER_CUBIC) + + # combine operations + conv = RealFFTConvolve2D(psf_re, dtype=dtype) + ds = pycl.SubSample( + psf_re.shape, + slice(None), + slice(0, -1, self.downsampling_factor[0]), + slice(0, -1, self.downsampling_factor[1]), + slice(None), + ) + + self._H = ds * conv + + super(APGD, self).__init__(psf_re, dtype, n_iter=max_iter, **kwargs) + + else: + self.downsampling_factor = 1 + self._H = RealFFTConvolve2D(psf, dtype=dtype) + + super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) - self._H = RealFFTConvolve2D(self._psf, dtype=dtype) self._H.lipschitz(tol=lipschitz_tol, tight=lipschitz_tight) # initialize solvers which will be created when data is set @@ -192,9 +221,25 @@ def set_data(self, data): 3D (RGB). """ - super(APGD, self).set_data( - np.repeat(data, self._original_shape[-4], axis=0) - ) # we repeat the data for each depth to match the size of the PSF + + # super(APGD, self).set_data( + # np.repeat(data, self._original_shape[-4], axis=0) + # ) # we repeat the data for each depth to match the size of the PSF + + data = np.repeat(data, self._original_shape[-4], axis=0) # repeat for each depth + assert isinstance(data, np.ndarray) + assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." + + assert np.all( + self._psf_shape[-3:-1] == (np.array(data.shape)[-3:-1] * self.downsampling_factor) + ), "PSF and data shape mismatch" + + if len(data.shape) == 3: + self._data = data[None, None, ...] + elif len(data.shape) == 4: + self._data = data[None, ...] + else: + self._data = data """ Set up problem """ # Cost function @@ -220,13 +265,15 @@ def reset(self): if self._initial_est is not None: self._image_est = self._initial_est else: - self._image_est = np.zeros(self._original_size, dtype=self._dtype) + self._image_est = np.zeros(np.prod(self._psf_shape), dtype=self._dtype) def _update(self, iter): res = next(self._apgd.steps()) self._image_est[:] = res["x"] def _form_image(self): - image = self._image_est.reshape(self._original_shape) + image = self._image_est.reshape(self._psf_shape) image[image < 0] = 0 + if np.any(self._psf_shape != self._original_shape): + image = resize(image, shape=self._original_shape) return image diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 58200f2a..444e3b0a 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -10,7 +10,7 @@ ============== The core algorithmic component of ``LenslessPiCam`` is the abstract -class :py:class:`~lensless.ReconstructionAlgorithm`. 
The three reconstruction +class :py:class:`~lensless.ReconstructionAlgorithm`. The five reconstruction strategies available in ``LenslessPiCam`` derive from this class: - :py:class:`~lensless.GradientDescent`: projected gradient descent with a @@ -25,6 +25,14 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction long as it is compatible with Pycsou, namely derives from one of `DiffFunc `_ or `ProxFunc `_. +- :py:class:`~lensless.UnrolledFISTA`: unrolled FISTA with a non-negativity constraint. +- :py:class:`~lensless.UnrolledADMM`: unrolled ADMM with a non-negativity constraint and a total variation (TV) regularizer [1]_. + +Note that the unrolled algorithms derive from the abstract class +:py:class:`~lensless.TrainableReconstructionAlgorithm`, which itself derives from +:py:class:`~lensless.ReconstructionAlgorithm` while adding functionality +for training on batches and adding trainable pre- and post-processing +blocks. New reconstruction algorithms can be conveniently implemented by deriving from the abstract class and defining the following abstract @@ -154,6 +162,7 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction import pathlib as plib import matplotlib.pyplot as plt from lensless.utils.plot import plot_image +from lensless.utils.io import get_dtype from lensless.recon.rfft_convolve import RealFFTConvolve2D try: @@ -232,16 +241,7 @@ def __init__( self._psf_shape = np.array(self._psf.shape) # set dtype - if dtype is None: - if self.is_torch: - dtype = torch.float32 - else: - dtype = np.float32 - else: - if self.is_torch: - dtype = torch.float32 if dtype == "float32" else torch.float64 - else: - dtype = np.float32 if dtype == "float32" else np.float64 + dtype = get_dtype(dtype, self.is_torch) if self.is_torch: if dtype: @@ -404,6 +404,28 @@ def get_image_estimate(self): """Get current image estimate as [Batch, Depth, Height, Width, Channels].""" return self._form_image() + def _set_psf(self, psf): + """ + Set PSF. + + Parameters + ---------- + psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + PSF to set. + """ + assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)." + assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be rgb (3) or grayscale (1)" + assert self._psf.shape == psf.shape, "new PSF must have same shape as old PSF" + assert isinstance(psf, type(self._psf)), "new PSF must have same type as old PSF" + + self._psf = psf + self._convolver = RealFFTConvolve2D( + psf, + dtype=self._convolver._psf.dtype, + pad=self._convolver.pad, + norm=self._convolver.norm, + ) + def _progress(self): """ Optional method for printing progress update, e.g. 
relative improvement @@ -491,7 +513,9 @@ def apply( if (plot or save) and disp_iter is not None: if ax is None: - ax = plot_image(self._get_numpy_data(self._image_est[0]), gamma=gamma) + img = self._form_image() + ax = plot_image(self._get_numpy_data(img[0]), gamma=gamma) + else: ax = None disp_iter = n_iter + 1 diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index 5c867cd3..34cca96a 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -57,6 +57,9 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs): self._is_rgb = psf.shape[3] == 3 assert self._is_rgb or psf.shape[3] == 1 + # save normalization + self.norm = norm + # set dtype if dtype is None: if self.is_torch: diff --git a/lensless/recon/tikhonov.py b/lensless/recon/tikhonov.py index 84a88011..fb9a182d 100644 --- a/lensless/recon/tikhonov.py +++ b/lensless/recon/tikhonov.py @@ -2,8 +2,8 @@ # tikhonov.py # ================= # Authors : -# Aaron FARGEON [aa.fargeon@gmail.com] # Eric BEZZAM [ebezzam@gmail.com] +# Aaron FARGEON [aa.fargeon@gmail.com] # ############################################################################# """ @@ -20,6 +20,13 @@ import numpy as np from numpy.linalg import multi_dot +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + class CodedApertureReconstruction: """ @@ -32,7 +39,7 @@ def __init__(self, mask, image_shape, P=None, Q=None, lmbd=3e-4): """ Parameters ---------- - mask : py:class:`~lensless.hardware.mask.CodedAperture` + mask : py:class:`lensless.hardware.mask.CodedAperture` Coded aperture mask object. image_shape : (`array-like` or `tuple`) The shape of the image to reconstruct. @@ -67,46 +74,97 @@ def apply(self, img): Parameters ---------- - img : :py:class:`~numpy.ndarray` + img : :py:class:`~numpy.ndarray` or :py:class:`torch.Tensor` Lensless capture measurement. Must be 3D even if grayscale. Returns ------- - :py:class:`~numpy.ndarray` + :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Reconstructed image, in the same format as the measurement. """ - assert len(img.shape) == 3, "Object should be a 3D array (HxWxC) even if grayscale." - - # Empty matrix for reconstruction - n_channels = img.shape[-1] - x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) - - # Applying reconstruction for each channel - for c in range(n_channels): - - # SVD of left matrix - UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) - VL = VLh.T - DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) - singLsq = np.square(SL) - - # SVD of right matrix - UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) - VR = VRh.T - DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) - singRsq = np.square(SR) - - # Applying analytical reconstruction - Yc = img[:, :, c] - inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( - np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) - ) - x_est[:, :, c] = multi_dot([VL, inner, VR.T]) - - # Non-negativity constraint: setting all negative values to 0 - x_est = x_est.clip(min=0) - - # Normalizing the image - x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + assert ( + len(img.shape) == 3 + ), "Object should be a 3D array or tensor (HxWxC) even if grayscale." 
+ + if torch_available and isinstance(img, torch.Tensor): + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = torch.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + self.P = torch.from_numpy(self.P).float() + self.Q = torch.from_numpy(self.Q).float() + + # Applying reconstruction for each channel + for c in range(n_channels): + Yc = img[:, :, c] + + # SVD of left matrix + UL, SL, VLh = torch.linalg.svd(self.P) + VL = VLh.T + DL = torch.cat( + ( + torch.diag(SL), + torch.zeros([self.P.shape[0] - SL.size(0), SL.size(0)], device=SL.device), + ) + ) + singLsq = SL**2 + + # SVD of right matrix + UR, SR, VRh = torch.linalg.svd(self.Q) + VR = VRh.T + DR = torch.cat( + ( + torch.diag(SR), + torch.zeros([self.Q.shape[0] - SR.size(0), SR.size(0)], device=SR.device), + ) + ) + singRsq = SR**2 + + # Applying analytical reconstruction + inner = torch.linalg.multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + torch.outer(singLsq, singRsq) + torch.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = torch.linalg.multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = torch.clamp(x_est, min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) + + else: + + # Empty matrix for reconstruction + n_channels = img.shape[-1] + x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels]) + + # Applying reconstruction for each channel + for c in range(n_channels): + + # SVD of left matrix + UL, SL, VLh = np.linalg.svd(self.P, full_matrices=True) + VL = VLh.T + DL = np.concatenate((np.diag(SL), np.zeros([self.P.shape[0] - SL.size, SL.size]))) + singLsq = np.square(SL) + + # SVD of right matrix + UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=True) + VR = VRh.T + DR = np.concatenate((np.diag(SR), np.zeros([self.Q.shape[0] - SR.size, SR.size]))) + singRsq = np.square(SR) + + # Applying analytical reconstruction + Yc = img[:, :, c] + inner = multi_dot([DL.T, UL.T, Yc, UR, DR]) / ( + np.outer(singLsq, singRsq) + np.full(x_est.shape[0:2], self.lmbd) + ) + x_est[:, :, c] = multi_dot([VL, inner, VR.T]) + + # Non-negativity constraint: setting all negative values to 0 + x_est = x_est.clip(min=0) + + # Normalizing the image + x_est = (x_est - x_est.min()) / (x_est.max() - x_est.min()) return x_est diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index c7129a3b..82fd883d 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -5,8 +5,10 @@ # Yohann PERRON [yohann.perron@gmail.com] # ############################################################################# -import abc +import pathlib as plib +from matplotlib import pyplot as plt from lensless.recon.recon import ReconstructionAlgorithm +from lensless.utils.plot import plot_image try: import torch @@ -24,7 +26,6 @@ class TrainableReconstructionAlgorithm(ReconstructionAlgorithm, torch.nn.Module) * ``_update``: updating state variables at each iterations. * ``reset``: reset state variables. * ``_form_image``: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping. - * ``batch_call``: method for performing iterative reconstruction on a batch of images. 
One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the @@ -155,7 +156,15 @@ def batch_call(self, batch): return image_est def apply( - self, disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True + self, + disp_iter=10, + plot_pause=0.2, + plot=True, + save=False, + gamma=None, + ax=None, + reset=True, + output_intermediate=False, ): """ Method for performing iterative reconstruction. Contrary to non-trainable reconstruction @@ -180,6 +189,8 @@ def apply( Gamma correction factor to apply for plots. Default is None. ax : :py:class:`~matplotlib.axes.Axes`, optional `Axes` object to fill for plotting/saving, default is to create one. + output_intermediate : bool, optional + Whether to output intermediate reconstructions after preprocessing and before postprocessing. Returns ------- @@ -190,8 +201,11 @@ def apply( returning if `plot` or `save` is True. """ + pre_processed_image = None if self.pre_process is not None: self._data = self.pre_process(self._data, self.pre_process_param) + if output_intermediate: + pre_processed_image = self._data[0, ...].clone() im = super(TrainableReconstructionAlgorithm, self).apply( n_iter=self._n_iter, @@ -203,6 +217,30 @@ def apply( ax=ax, reset=reset, ) + + # remove plot if returned + if plot: + im, _ = im + + # post process data + pre_post_process_image = None if self.post_process is not None: + # apply post process + if output_intermediate: + pre_post_process_image = im.clone() im = self.post_process(im, self.post_process_param) - return im + + if plot: + ax = plot_image(self._get_numpy_data(im[0]), ax=ax, gamma=gamma) + ax.set_title( + "Final reconstruction after {} iterations and post process".format(self._n_iter) + ) + if save: + plt.savefig(plib.Path(save) / "final.png") + + if output_intermediate: + return im, pre_processed_image, pre_post_process_image + elif plot: + return im, ax + else: + return im diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 7fad0400..2409dd80 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -1,15 +1,36 @@ +# ############################################################################# +# utils.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import json +import math +import numpy as np +import matplotlib.pyplot as plt +import time +from hydra.utils import get_original_cwd +import os import torch +from lensless.eval.benchmark import benchmark +from lensless.hardware.trainable_mask import TrainableMask +from tqdm import tqdm from lensless.recon.drunet.network_unet import UNetRes +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image -def load_drunet(model_path, n_channels=3, requires_grad=False): +def load_drunet(model_path=None, n_channels=3, requires_grad=False): """ Load a pre-trained Drunet model. Parameters ---------- - model_path : str - Path to pre-trained model. + model_path : str, optional + Path to pre-trained model. Download if not provided. n_channels : int Number of channels in input image. requires_grad : bool @@ -17,10 +38,29 @@ def load_drunet(model_path, n_channels=3, requires_grad=False): Returns ------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Loaded model. 
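A sketch of combining the loader with ``get_drunet_function`` to obtain a denoising callable. Note that the default model path is resolved with Hydra's ``get_original_cwd``, so this assumes the code runs inside a Hydra application (otherwise pass ``model_path`` explicitly):

.. code:: python

    from lensless.recon.utils import load_drunet, get_drunet_function

    model = load_drunet(requires_grad=False)           # downloads drunet_color.pth if missing
    denoiser = get_drunet_function(model, device="cpu")
    # `denoiser(image, noise_level)` can then be used as a post-processing step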
""" + if model_path is None: + model_path = os.path.join(get_original_cwd(), "models", "drunet_color.pth") + if not os.path.exists(model_path): + try: + from torchvision.datasets.utils import download_url + except ImportError: + exit() + msg = "Do you want to download the pretrained DRUNet model (130MB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + output_path = os.path.join(get_original_cwd(), "models") + if valid: + url = "https://drive.switch.ch/index.php/s/jTdeMHom025RFRQ/download" + filename = "drunet_color.pth" + download_url(url, output_path, filename=filename) + + assert os.path.exists(model_path), f"Model path {model_path} does not exist" + model = UNetRes( in_nc=n_channels + 1, out_nc=n_channels, @@ -45,11 +85,11 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Parameters ---------- - model : :py:class:`~torch.nn.Module` + model : :py:class:`torch.nn.Module` Drunet compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image both in CHW format. - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Input image. - noise_level : float or :py:class:`~torch.Tensor` + noise_level : float or :py:class:`torch.Tensor` Noise level in the image. device : str Device to use for computation. Can be "cpu" or "cuda". @@ -58,7 +98,7 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") Returns ------- - image : :py:class:`~torch.Tensor` + image : :py:class:`torch.Tensor` Reconstructed image. """ # convert from NDHWC to NCHW @@ -108,7 +148,7 @@ def get_drunet_function(model, device="cpu", mode="inference"): Parameters ---------- - model : torch.nn.Module + model : :py:class:`torch.nn.Module` DruNet like denoiser model device : str Device to use for computation. Can be "cpu" or "cuda". @@ -129,3 +169,479 @@ def process(image, noise_level): return image return process + + +def measure_gradient(model): + """ + Helper function to measure L2 norm of the gradient of a model. + + Parameters + ---------- + model : :py:class:`torch.nn.Module` + Model to measure gradient of. + + Returns + ------- + Float + L2 norm of the gradient of the model. + """ + total_norm = 0.0 + for p in model.parameters(): + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm**0.5 + return total_norm + + +def create_process_network(network, depth, device="cpu"): + """ + Helper function to create a process network. + + Parameters + ---------- + network : str + Name of network to use. Can be "DruNet" or "UnetRes". + depth : int + Depth of network. + device : str + Device to use for computation. Can be "cpu" or "cuda". Defaults to "cpu". + + Returns + ------- + :py:class:`torch.nn.Module` + New process network. Already trained for Drunet. 
+ """ + if network == "DruNet": + from lensless.recon.utils import load_drunet + + process = load_drunet(requires_grad=True).to(device) + process_name = "DruNet" + elif network == "UnetRes": + from lensless.recon.drunet.network_unet import UNetRes + + n_channels = 3 + process = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=depth, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ).to(device) + process_name = "UnetRes_d" + str(depth) + else: + process = None + process_name = None + + return (process, process_name) + + +class Trainer: + def __init__( + self, + recon, + train_dataset, + test_dataset, + test_size=0.15, + mask=None, + batch_size=4, + loss="l2", + lpips=None, + l1_mask=None, + optimizer="Adam", + optimizer_lr=1e-6, + slow_start=None, + skip_NAN=False, + algorithm_name="Unknown", + metric_for_best_model=None, + save_every=None, + gamma=None, + ): + """ + Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__. + + The train and test metrics at the end of each epoch can be found in ``self.metrics``, + with "LOSS" being the train loss. The test loss can be found in "MSE" (if loss is "l2") or + "MAE" (if loss is "l1"). If ``lpips`` is not None, the LPIPS loss is also added + to the train loss, such that the test loss can be computed as "MSE" + ``lpips`` * "LPIPS_Vgg" + (or "MAE" + ``lpips`` * "LPIPS_Vgg"). + + Parameters + ---------- + recon : :py:class:`lensless.TrainableReconstructionAlgorithm` + Reconstruction algorithm to train. + train_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for training. + test_dataset : :py:class:`torch.utils.data.Dataset` + Dataset to use for testing. + test_size : float, optional + If test_dataset is None, fraction of the train dataset to use for testing, by default 0.15. + mask : TrainableMask, optional + Trainable mask to use for training. If none, training with fix psf, by default None. + batch_size : int, optional + Batch size to use for training, by default 4. + loss : str, optional + Loss function to use for training "l1" or "l2", by default "l2". + lpips : float, optional + the weight of the lpips(VGG) in the total loss. If None ignore. By default None. + l1_mask : float, optional + the weight of the l1 norm of the mask in the total loss. If None ignore. By default None. + optimizer : str, optional + Optimizer to use durring training. Available : "Adam". By default "Adam". + optimizer_lr : float, optional + Learning rate for the optimizer, by default 1e-6. + slow_start : float, optional + Multiplicative factor to reduce the learning rate during the first two epochs. If None, ignored. Default is None. + skip_NAN : bool, optional + Whether to skip update if any gradiant are NAN (True) or to throw an error(False), by default False + algorithm_name : str, optional + Algorithm name for logging, by default "Unknown". + metric_for_best_model : str, optional + Metric to use for saving the best model. If None, will default to evaluation loss. Default is None. + save_every : int, optional + Save model every ``save_every`` epochs. If None, just save best model. + gamma : float, optional + Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None. 
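A training sketch, assuming ``recon`` (a ``TrainableReconstructionAlgorithm``) and ``dataset`` (a compatible lensless/lensed dataset) have been prepared beforehand; all hyperparameter values are illustrative:

.. code:: python

    from lensless.recon.utils import Trainer

    trainer = Trainer(
        recon=recon,
        train_dataset=dataset,
        test_dataset=None,      # 15% of the train set is held out for evaluation
        batch_size=4,
        loss="l2",
        lpips=1.0,              # requires the `lpips` package
        optimizer_lr=1e-4,
        metric_for_best_model="LPIPS_Vgg",
        save_every=5,
    )
    trainer.train(n_epoch=10, save_pt="outputs")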
+ + + """ + self.device = recon._psf.device + + self.recon = recon + + assert train_dataset is not None + if test_dataset is None: + assert test_size < 1.0 and test_size > 0.0 + # split train dataset + train_size = int((1 - test_size) * len(train_dataset)) + test_size = len(train_dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split( + train_dataset, [train_size, test_size] + ) + print(f"Train size : {train_size}, Test size : {test_size}") + + self.train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=(self.device != "cpu"), + ) + self.test_dataset = test_dataset + self.lpips = lpips + self.skip_NAN = skip_NAN + + if mask is not None: + assert isinstance(mask, TrainableMask) + self.mask = mask + self.use_mask = True + else: + self.use_mask = False + + self.l1_mask = l1_mask + self.gamma = gamma + + # loss + if loss == "l2": + self.Loss = torch.nn.MSELoss() + elif loss == "l1": + self.Loss = torch.nn.L1Loss() + else: + raise ValueError(f"Unsuported loss : {loss}") + + # Lpips loss + if lpips: + try: + import lpips + + self.Loss_lpips = lpips.LPIPS(net="vgg").to(self.device) + except ImportError: + return ImportError( + "lpips package is need for LPIPS loss. Install using : pip install lpips" + ) + + # optimizer + if optimizer == "Adam": + # the parameters of the base model and non torch.Module process must be added separatly + parameters = [{"params": recon.parameters()}] + self.optimizer = torch.optim.Adam(parameters, lr=optimizer_lr) + else: + raise ValueError(f"Unsuported optimizer : {optimizer}") + # Scheduler + if slow_start: + + def learning_rate_function(epoch): + if epoch == 0: + return slow_start + elif epoch == 1: + return math.sqrt(slow_start) + else: + return 1 + + else: + + def learning_rate_function(epoch): + return 1 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=learning_rate_function + ) + + self.metrics = { + "LOSS": [], # train loss + "MSE": [], + "MAE": [], + "LPIPS_Vgg": [], + "LPIPS_Alex": [], + "PSNR": [], + "SSIM": [], + "ReconstructionError": [], + "n_iter": self.recon._n_iter, + "algorithm": algorithm_name, + "metric_for_best_model": metric_for_best_model, + "best_epoch": 0, + "best_eval_score": 0 + if metric_for_best_model == "PSNR" or metric_for_best_model == "SSIM" + else np.inf, + } + if metric_for_best_model is not None: + assert metric_for_best_model in self.metrics.keys() + self.save_every = save_every + + # Backward hook that detect NAN in the gradient and print the layer weights + if not self.skip_NAN: + + def detect_nan(grad): + if torch.isnan(grad).any(): + print(grad, flush=True) + for name, param in recon.named_parameters(): + if param.requires_grad: + print(name, param) + raise ValueError("Gradient is NaN") + return grad + + for param in recon.parameters(): + if param.requires_grad: + param.register_hook(detect_nan) + if param.requires_grad: + param.register_hook(detect_nan) + + def train_epoch(self, data_loader, disp=-1): + """ + Train for one epoch. + + Parameters + ---------- + data_loader : :py:class:`torch.utils.data.DataLoader` + Data loader to use for training. + disp : int + Display interval, if -1, no display + + Returns + ------- + float + Mean loss of the epoch. 
+ """ + mean_loss = 0.0 + i = 1.0 + pbar = tqdm(data_loader) + for X, y in pbar: + # send to device + X = X.to(self.device) + y = y.to(self.device) + + # update psf according to mask + if self.use_mask: + self.recon._set_psf(self.mask.get_psf()) + + # forward pass + y_pred = self.recon.batch_call(X.to(self.device)) + # normalizing each output + eps = 1e-12 + y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps + y_pred = y_pred / y_pred_max + + # normalizing y + y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps + y = y / y_max + + if i % disp == 1: + img_pred = y_pred[0, 0].cpu().detach().numpy() + img_truth = y[0, 0].cpu().detach().numpy() + + plt.imshow(img_pred) + plt.savefig(f"y_pred_{i-1}.png") + plt.imshow(img_truth) + plt.savefig(f"y_{i-1}.png") + + self.optimizer.zero_grad(set_to_none=True) + # convert to CHW for loss and remove depth + y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) + y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) + + loss_v = self.Loss(y_pred, y) + if self.lpips: + + if y_pred.shape[1] == 1: + # if only one channel, repeat for LPIPS + y_pred = y_pred.repeat(1, 3, 1, 1) + y = y.repeat(1, 3, 1, 1) + + # value for LPIPS needs to be in range [-1, 1] + loss_v = loss_v + self.lpips * torch.mean( + self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) + ) + if self.use_mask and self.l1_mask: + loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask)) + loss_v.backward() + + torch.nn.utils.clip_grad_norm_(self.recon.parameters(), 1.0) + + # if any gradient is NaN, skip training step + if self.skip_NAN: + is_NAN = False + for param in self.recon.parameters(): + if torch.isnan(param.grad).any(): + is_NAN = True + break + if is_NAN: + print("NAN detected in gradiant, skipping training step") + i += 1 + continue + self.optimizer.step() + + # update mask + if self.use_mask: + self.mask.update_mask() + + mean_loss += (loss_v.item() - mean_loss) * (1 / i) + pbar.set_description(f"loss : {mean_loss}") + i += 1 + + return mean_loss + + def evaluate(self, mean_loss, save_pt): + """ + Evaluate the reconstruction algorithm on the test dataset. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. + save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + """ + if self.test_dataset is None: + return + # benchmarking + current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10) + + # update metrics with current metrics + self.metrics["LOSS"].append(mean_loss) + for key in current_metrics: + self.metrics[key].append(current_metrics[key]) + + if save_pt: + # save dictionary metrics to file with json + with open(os.path.join(save_pt, "metrics.json"), "w") as f: + json.dump(self.metrics, f) + + # check best metric + if self.metrics["metric_for_best_model"] is None: + eval_loss = current_metrics["MSE"] + if self.lpips is not None: + eval_loss += self.lpips * current_metrics["LPIPS_Vgg"] + if self.use_mask and self.l1_mask: + eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy())) + return eval_loss + else: + return current_metrics[self.metrics["metric_for_best_model"]] + + def on_epoch_end(self, mean_loss, save_pt, epoch): + """ + Called at the end of each epoch. + + Parameters + ---------- + mean_loss : float + Mean loss of the last epoch. + save_pt : str + Path to save metrics dictionary to. If None, no logging of metrics. + epoch : int + Current epoch. 
+ """ + if save_pt is None: + # Use current directory + save_pt = os.getcwd() + + # save model + # self.save(path=save_pt, include_optimizer=False) + epoch_eval_metric = self.evaluate(mean_loss, save_pt) + new_best = False + if ( + self.metrics["metric_for_best_model"] == "PSNR" + or self.metrics["metric_for_best_model"] == "SSIM" + ): + if epoch_eval_metric > self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + else: + if epoch_eval_metric < self.metrics["best_eval_score"]: + self.metrics["best_eval_score"] = epoch_eval_metric + new_best = True + + if new_best: + self.metrics["best_epoch"] = epoch + self.save(path=save_pt, include_optimizer=False, epoch="BEST") + + if self.save_every is not None and epoch % self.save_every == 0: + self.save(path=save_pt, include_optimizer=False, epoch=epoch) + + def train(self, n_epoch=1, save_pt=None, disp=-1): + """ + Train the reconstruction algorithm. + + Parameters + ---------- + n_epoch : int, optional + Number of epochs to train for, by default 1 + save_pt : str, optional + Path to save metrics dictionary to. If None, use current directory, by default None + disp : int, optional + Display interval, if -1, no display. Default is -1. + """ + + start_time = time.time() + + self.evaluate(-1, save_pt) + for epoch in range(n_epoch): + print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") + mean_loss = self.train_epoch(self.train_dataloader, disp=disp) + # offset because of evaluate before loop + self.on_epoch_end(mean_loss, save_pt, epoch + 1) + self.scheduler.step() + + print(f"Train time : {time.time() - start_time} s") + + def save(self, epoch, path="recon", include_optimizer=False): + # create directory if it does not exist + if not os.path.exists(path): + os.makedirs(path) + # save mask + if self.use_mask: + torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) + torch.save( + self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt") + ) + + psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] + psf_np = psf_np.squeeze() # remove (potential) singleton color channel + save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) + plot_image(psf_np, gamma=self.gamma) + plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + + # save optimizer + if include_optimizer: + torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt")) + # save recon + torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py new file mode 100644 index 00000000..a5a2e8a9 --- /dev/null +++ b/lensless/utils/dataset.py @@ -0,0 +1,598 @@ +# ############################################################################# +# dataset.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + +import numpy as np +import glob +import os +import torch +from abc import abstractmethod +from torch.utils.data import Dataset +from torchvision import transforms +from lensless.utils.simulation import FarFieldSimulator +from lensless.utils.io import load_image, load_psf +from lensless.utils.image import resize + + +class DualDataset(Dataset): + """ + Abstract class for defining a dataset of paired lensed and lensless images. 
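A minimal subclass only needs ``__len__`` and ``_get_images_pair``; the sketch below uses random data and is purely illustrative:

.. code:: python

    import numpy as np
    from lensless.utils.dataset import DualDataset

    class RandomPairs(DualDataset):
        """Toy dataset returning random (lensless, lensed) pairs in [0, 1]."""

        def __init__(self, n_files=10, **kwargs):
            super().__init__(**kwargs)
            self.n_files = n_files

        def __len__(self):
            return self.n_files if self.indices is None else len(self.indices)

        def _get_images_pair(self, idx):
            lensless = np.random.rand(32, 48, 3).astype(np.float32)
            lensed = np.random.rand(32, 48, 3).astype(np.float32)
            return lensless, lensed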
+ """ + + def __init__( + self, + indices=None, + # psf_path=None, + background=None, + # background_pix=(0, 15), + downsample=1, + flip=False, + transform_lensless=None, + transform_lensed=None, + **kwargs, + ): + """ + Dataset consisting of lensless and corresponding lensed image. + + Parameters + ---------- + indices : range or int or None + Indices of the images to use in the dataset (if integer, it should be interpreted as range(indices)), by default None. + psf_path : str + Path to the PSF of the imaging system, by default None. + background : :py:class:`~torch.Tensor` or None, optional + If not ``None``, background is removed from lensless images, by default ``None``. If PSF is provided, background is estimated from the PSF. + background_pix : tuple, optional + Pixels to use for background estimation, by default (0, 15). + downsample : int, optional + Downsample factor of the lensless images, by default 1. + flip : bool, optional + If ``True``, lensless images are flipped, by default ``False``. + transform_lensless : PyTorch Transform or None, optional + Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + transform_lensed : PyTorch Transform or None, optional + Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + """ + if isinstance(indices, int): + indices = range(indices) + self.indices = indices + self.background = background + self.downsample = downsample + self.flip = flip + self.transform_lensless = transform_lensless + self.transform_lensed = transform_lensed + + # self.psf = None + # if psf_path is not None: + # psf, background = load_psf( + # psf_path, + # downsample=downsample, + # return_float=True, + # return_bg=True, + # bg_pix=background_pix, + # ) + # if self.background is None: + # self.background = background + # self.psf = torch.from_numpy(psf) + # if self.transform_lensless is not None: + # self.psf = self.transform_lensless(self.psf) + + @abstractmethod + def __len__(self): + """ + Abstract method to get the length of the dataset. It should take into account the indices parameter. + """ + raise NotImplementedError + + @abstractmethod + def _get_images_pair(self, idx): + """ + Abstract method to get the lensed and lensless images. Should return a pair (lensless, lensed) of numpy arrays with values in [0,1]. + + Parameters + ---------- + idx : int + images index + """ + raise NotImplementedError + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.item() + + if self.indices is not None: + idx = self.indices[idx] + lensless, lensed = self._get_images_pair(idx) + + if isinstance(lensless, np.ndarray): + # expected case + if self.downsample != 1.0: + lensless = resize(lensless, factor=1 / self.downsample) + lensed = resize(lensed, factor=1 / self.downsample) + + lensless = torch.from_numpy(lensless) + lensed = torch.from_numpy(lensed) + else: + # torch tensor + # This mean get_images_pair returned a torch tensor. 
This isn't recommended, if possible get_images_pair should return a numpy array + # In this case it should also have applied the downsampling + pass + + # If [H, W, C] -> [D, H, W, C] + if len(lensless.shape) == 3: + lensless = lensless.unsqueeze(0) + if len(lensed.shape) == 3: + lensed = lensed.unsqueeze(0) + + if self.background is not None: + lensless = lensless - self.background + + # flip image x and y if needed + if self.flip: + lensless = torch.rot90(lensless, dims=(-3, -2)) + lensed = torch.rot90(lensed, dims=(-3, -2)) + if self.transform_lensless: + lensless = self.transform_lensless(lensless) + if self.transform_lensed: + lensed = self.transform_lensed(lensed) + + return lensless, lensed + + +class SimulatedFarFieldDataset(DualDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset. :py:class:`lensless.utils.simulation.FarFieldSimulator` is used for simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + + """ + + def __init__( + self, + dataset, + simulator, + pre_transform=None, + dataset_is_CHW=False, + flip=False, + **kwargs, + ): + """ + Parameters + ---------- + + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + pre_transform : PyTorch Transform or None, optional + Transform to apply to the images before simulation, by default ``None``. Note that this transform is applied on HCW images (different from torchvision). + dataset_is_CHW : bool, optional + If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``. + flip : bool, optional + If True, images are flipped beffore the simulation, by default ``False``. 
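A sketch of wrapping a torchvision dataset, where ``simulator`` is assumed to be an already-constructed :py:class:`~lensless.utils.simulation.FarFieldSimulator` with a PSF and ``is_torch=True`` (the dataset choice is illustrative):

.. code:: python

    from torchvision import datasets, transforms
    from lensless.utils.dataset import SimulatedFarFieldDataset

    mnist = datasets.MNIST(root="data", download=True, transform=transforms.ToTensor())
    sim_ds = SimulatedFarFieldDataset(dataset=mnist, simulator=simulator, dataset_is_CHW=True)
    lensless, lensed = sim_ds[0]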
+ """ + + # we do the flipping before the simualtion + super(SimulatedFarFieldDataset, self).__init__(flip=False, **kwargs) + + assert isinstance(dataset, Dataset) + self.dataset = dataset + self.n_files = len(dataset) + + self.dataset_is_CHW = dataset_is_CHW + self._pre_transform = pre_transform + self.flip_pre_sim = flip + + # check simulator + assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator" + assert simulator.is_torch, "Simulator should be a pytorch simulator" + assert simulator.fft_shape is not None, "Simulator should have a psf" + self.sim = simulator + + @property + def psf(self): + return self.sim.get_psf() + + def get_image(self, index): + return self.dataset[index] + + def _get_images_pair(self, index): + # load image + img, _ = self.get_image(index) + # convert to HWC for simulator and transform + if self.dataset_is_CHW: + img = img.moveaxis(-3, -1) + if self.flip_pre_sim: + img = torch.rot90(img, dims=(-3, -2)) + if self._pre_transform is not None: + img = self._pre_transform(img) + + lensless, lensed = self.sim.propagate_image(img, return_object_plane=True) + + if lensed.shape[-1] == 1 and lensless.shape[-1] == 3: + # copy to 3 channels + lensed = lensed.repeat(1, 1, 3) + assert ( + lensed.shape[-1] == lensless.shape[-1] + ), "Lensed and lensless should have same number of channels" + + return lensless, lensed + + def __len__(self): + if self.indices is None: + return self.n_files + else: + return len([x for x in self.indices if x < self.n_files]) + + +class MeasuredDatasetSimulatedOriginal(DualDataset): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on the screen. + Unlike :py:class:`lensless.utils.dataset.MeasuredDataset`, the ground-truth lensed image is simulated using a :py:class:`lensless.utils.simulation.FarFieldSimulator` + object rather than measured with a lensed camera. + """ + + def __init__( + self, + root_dir, + simulator, + lensless_fn="diffuser", + original_fn="lensed", + image_ext="npy", + original_ext=None, + downsample=1, + **kwargs, + ): + """ + Dataset consisting of lensless image captured from a screen and the corresponding image shown on screen. + + Parameters + ---------- + root_dir : str + Path to the test dataset. It is expected to contain two folders: one of lensless images and one of original images. + simulator : :py:class:`lensless.utils.simulatorFarFieldSimulator` + Simulator to use for the projection of the original image to object space. The PSF **should not** be specified, and it is expect to have ``is_torch = True``. + lensless_fn : str, optional + Name of the folder containing the lensless images, by default "diffuser". + lensed_fn : str, optional + Name of the folder containing the lensed images, by default "lensed". + image_ext : str, optional + Extension of the images, by default "npy". + original_ext : str, optional + Extension of the original image if different from lenless, by default None. + downsample : int, optional + Downsample factor of the lensless images, by default 1. 
+ """ + super(MeasuredDatasetSimulatedOriginal, self).__init__(downsample=1, **kwargs) + self.pre_downsample = downsample + + self.root_dir = root_dir + self.lensless_dir = os.path.join(root_dir, lensless_fn) + self.original_dir = os.path.join(root_dir, original_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.original_dir) + + self.image_ext = image_ext.lower() + self.original_ext = original_ext.lower() if original_ext is not None else image_ext.lower() + + files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) + files.sort() + self.files = [os.path.basename(fn) for fn in files] + + if len(self.files) == 0: + raise FileNotFoundError( + f"No files found in {self.lensless_dir} with extension {image_ext}" + ) + + # check simulator + assert isinstance(simulator, FarFieldSimulator), "Simulator should be a FarFieldSimulator" + assert simulator.is_torch, "Simulator should be a pytorch simulator" + assert simulator.fft_shape is None, "Simulator should not have a psf" + self.sim = simulator + + def __len__(self): + if self.indices is None: + return len(self.files) + else: + return len([i for i in self.indices if i < len(self.files)]) + + def _get_images_pair(self, idx): + if self.image_ext == "npy" or self.image_ext == "npz": + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + original_fp = os.path.join(self.original_dir, self.files[idx]) + lensless = np.load(lensless_fp) + lensless = resize(lensless, factor=1 / self.downsample) + original = np.load(original_fp[:-3] + self.original_ext) + else: + # more standard image formats: png, jpg, tiff, etc. + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + original_fp = os.path.join(self.original_dir, self.files[idx]) + lensless = load_image(lensless_fp, downsample=self.pre_downsample) + original = load_image( + original_fp[:-3] + self.original_ext, downsample=self.pre_downsample + ) + + # convert to float + if lensless.dtype == np.uint8: + lensless = lensless.astype(np.float32) / 255 + original = original.astype(np.float32) / 255 + else: + # 16 bit + lensless = lensless.astype(np.float32) / 65535 + original = original.astype(np.float32) / 65535 + + # convert to torch + lensless = torch.from_numpy(lensless) + original = torch.from_numpy(original) + + # project original image to lensed space + with torch.no_grad(): + lensed = self.sim.propagate_image() + + return lensless, lensed + + +class MeasuredDataset(DualDataset): + """ + Dataset consisting of lensless and corresponding lensed image. + It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images. + Unless the setup is perfectly calibrated, one should expect to have to use ``transform_lensed`` to adjust the alignment and rotation. + """ + + def __init__( + self, + root_dir, + lensless_fn="diffuser", + lensed_fn="lensed", + image_ext="npy", + **kwargs, + ): + """ + Dataset consisting of lensless and corresponding lensed image. Default parameters are for the + `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_. + + Parameters + ---------- + root_dir : str + Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images. + lensless_fn : str, optional + Name of the folder containing the lensless images, by default "diffuser". + lensed_fn : str, optional + Name of the folder containing the lensed images, by default "lensed". + image_ext : str, optional + Extension of the images, by default "npy". 
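A usage sketch (the folder path is a placeholder; it is expected to contain ``diffuser/`` and ``lensed/`` subfolders of ``.npy`` files):

.. code:: python

    from lensless.utils.dataset import MeasuredDataset

    dataset = MeasuredDataset(
        root_dir="data/DiffuserCam_Test",   # placeholder path
        lensless_fn="diffuser",
        lensed_fn="lensed",
        image_ext="npy",
        downsample=2,
    )
    lensless, lensed = dataset[0]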
+ """ + + super(MeasuredDataset, self).__init__(**kwargs) + + self.root_dir = root_dir + self.lensless_dir = os.path.join(root_dir, lensless_fn) + self.lensed_dir = os.path.join(root_dir, lensed_fn) + assert os.path.isdir(self.lensless_dir) + assert os.path.isdir(self.lensed_dir) + + self.image_ext = image_ext.lower() + + files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext)) + files.sort() + self.files = [os.path.basename(fn) for fn in files] + + if len(self.files) == 0: + raise FileNotFoundError( + f"No files found in {self.lensless_dir} with extension {image_ext}" + ) + + def __len__(self): + if self.indices is None: + return len(self.files) + else: + return len([i for i in self.indices if i < len(self.files)]) + + def _get_images_pair(self, idx): + if self.image_ext == "npy" or self.image_ext == "npz": + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) + lensless = np.load(lensless_fp) + lensed = np.load(lensed_fp) + + else: + # more standard image formats: png, jpg, tiff, etc. + lensless_fp = os.path.join(self.lensless_dir, self.files[idx]) + lensed_fp = os.path.join(self.lensed_dir, self.files[idx]) + lensless = load_image(lensless_fp) + lensed = load_image(lensed_fp) + + # convert to float + if lensless.dtype == np.uint8: + lensless = lensless.astype(np.float32) / 255 + lensed = lensed.astype(np.float32) / 255 + else: + # 16 bit + lensless = lensless.astype(np.float32) / 65535 + lensed = lensed.astype(np.float32) / 65535 + + return lensless, lensed + + +class DiffuserCamMirflickr(MeasuredDataset): + """ + Helper class for DiffuserCam Mirflickr dataset. + + Note that image colors are in BGR format: https://github.com/Waller-Lab/LenslessLearning/blob/master/utils.py#L432 + """ + + def __init__( + self, + dataset_dir, + psf_path, + downsample=2, + **kwargs, + ): + + psf, background = load_psf( + psf_path, + downsample=downsample * 4, # PSF is 4x the resolution of the images + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + self.allowed_idx = np.arange(2, 25001) + + super().__init__( + root_dir=dataset_dir, + background=background, + downsample=downsample, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser_images", + lensed_fn="ground_truth_lensed", + image_ext="npy", + **kwargs, + ) + + def _get_images_pair(self, idx): + + assert idx >= self.allowed_idx.min(), f"idx should be >= {self.allowed_idx.min()}" + assert idx <= self.allowed_idx.max(), f"idx should be <= {self.allowed_idx.max()}" + + fn = f"im{idx}.npy" + lensless_fp = os.path.join(self.lensless_dir, fn) + lensed_fp = os.path.join(self.lensed_dir, fn) + lensless = np.load(lensless_fp) + lensed = np.load(lensed_fp) + + return lensless, lensed + + +class DiffuserCamTestDataset(MeasuredDataset): + """ + Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking. + """ + + def __init__( + self, + data_dir=None, + n_files=None, + downsample=2, + ): + """ + Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of + `DiffuserCam Lensless Mirflickr Dataset (DLMD) `_. + + Parameters + ---------- + data_dir : str, optional + The path to ``DiffuserCam_Test`` dataset, by default looks inside the ``data`` folder. 
+ n_files : int, optional + Number of image pairs to load in the dataset , by default use all. + downsample : int, optional + Downsample factor of the lensless images, by default 2. Note that the PSF has a resolution of 4x of the images. + """ + + # download dataset if necessary + if data_dir is None: + data_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "data", "DiffuserCam_Test" + ) + if not os.path.isdir(data_dir): + main_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data") + print("No dataset found for benchmarking.") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the sample dataset (3.5GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download" + filename = "DiffuserCam_Test.zip" + download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True) + + psf_fp = os.path.join(data_dir, "psf.tiff") + psf, background = load_psf( + psf_fp, + downsample=downsample * 4, # PSF is 4x the resolution of the images + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + + # transform from BGR to RGB + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + + self.psf = transform_BRG2RGB(torch.from_numpy(psf)) + + if n_files is None: + indices = None + else: + indices = range(n_files) + + super().__init__( + root_dir=data_dir, + indices=indices, + background=background, + downsample=downsample, + flip=False, + transform_lensless=transform_BRG2RGB, + transform_lensed=transform_BRG2RGB, + lensless_fn="diffuser", + lensed_fn="lensed", + image_ext="npy", + ) + + +class SimulatedDatasetTrainableMask(SimulatedFarFieldDataset): + """ + Dataset of propagated images (through simulation) from a Torch Dataset with learnable mask. + The `waveprop `_ package is used for the simulation, + assuming a far-field propagation and a shift-invariant system with a single point spread function (PSF). + To ensure autograd compatibility, the dataloader should have ``num_workers=0``. + """ + + def __init__( + self, + mask, + dataset, + simulator, + **kwargs, + ): + """ + Parameters + ---------- + + mask : :py:class:`lensless.hardware.trainable_mask.TrainableMask` + Mask to use for simulation. Should be a 4D tensor with shape [1, H, W, C]. Simulation of multi-depth data is not supported yet. + dataset : :py:class:`torch.utils.data.Dataset` + Dataset to propagate. Should output images with shape [H, W, C] unless ``dataset_is_CHW`` is ``True`` (and therefore images have the dimension ordering of [C, H, W]). + simulator : :py:class:`lensless.utils.simulation.FarFieldSimulator` + Simulator object used on images from ``dataset``. Waveprop simulator to use for the simulation. It is expected to have ``is_torch = True``. + """ + + self._mask = mask + + temp_psf = self._mask.get_psf() + test_sim = FarFieldSimulator(psf=temp_psf, **simulator.params) + assert ( + test_sim.conv_dim == simulator.conv_dim + ).all(), "PSF shape should match simulator shape" + assert ( + not simulator.quantize + ), "Simulator should not perform quantization to maintain differentiability. 
Please set quantize=False" + + super(SimulatedDatasetTrainableMask, self).__init__(dataset, simulator, **kwargs) + + def _get_images_pair(self, index): + # update psf + psf = self._mask.get_psf() + self.sim.set_point_spread_function(psf) + + # return simulated images + return super()._get_images_pair(index) diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 7d2c65b3..748aaf50 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -1,5 +1,5 @@ # ############################################################################# -# image_utils.py +# image.py # ================= # Authors : # Eric BEZZAM [ebezzam@gmail.com] @@ -14,6 +14,7 @@ try: import torch import torchvision.transforms as tf + from torchvision.transforms.functional import rgb_to_grayscale torch_available = True except ImportError: @@ -76,16 +77,33 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC): return np.clip(resized, min_val, max_val) +def is_grayscale(img): + """ + Check if image is RGB. Assuming image is of shape ([depth,] height, width, color). + + Parameters + ---------- + img : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Image array. + + Returns + ------- + bool + Whether image is RGB. + """ + return img.shape[-1] == 1 + + def rgb2gray(rgb, weights=None, keepchanneldim=True): """ Convert RGB array to grayscale. Parameters ---------- - rgb : :py:class:`~numpy.ndarray` + rgb : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` ([Depth,] Height, Width, Channel) image. weights : :py:class:`~numpy.ndarray` - [Optional] (3,) weights to convert from RGB to grayscale. + [Optional] (3,) weights to convert from RGB to grayscale. Only used for NumPy arrays. keepchanneldim : bool Whether to keep the channel dimension. Default is True. @@ -95,22 +113,53 @@ def rgb2gray(rgb, weights=None, keepchanneldim=True): Grayscale image of dimension ([depth,] height, width [, 1]). 
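Both NumPy arrays and PyTorch tensors are handled; a quick sketch:

.. code:: python

    import numpy as np
    import torch
    from lensless.utils.image import rgb2gray

    gray_np = rgb2gray(np.random.rand(4, 6, 3))       # (4, 6, 1), NumPy path with fixed weights

    rgb_t = torch.rand(2, 4, 6, 3)                    # ([depth,] H, W, C)
    gray_t = rgb2gray(rgb_t, keepchanneldim=False)    # (2, 4, 6), uses torchvision's rgb_to_grayscale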
""" - if weights is None: - weights = np.array([0.299, 0.587, 0.114]) - assert len(weights) == 3 - - if len(rgb.shape) == 4: - image = np.tensordot(rgb, weights, axes=((3,), 0)) - elif len(rgb.shape) == 3: - image = np.tensordot(rgb, weights, axes=((2,), 0)) - else: - raise ValueError("Input must be at least 3D.") - if keepchanneldim: - return image[..., np.newaxis] - else: + use_torch = False + if torch_available: + if torch.is_tensor(rgb): + use_torch = True + + if use_torch: + + # move channel dimension to third to last + if len(rgb.shape) == 4: + rgb = rgb.permute(0, 3, 1, 2) + elif len(rgb.shape) == 3: + rgb = rgb.permute(2, 0, 1) + else: + raise ValueError("Input must be at least 3D.") + + image = rgb_to_grayscale(rgb) + + # move channel dimension to last + if len(rgb.shape) == 4: + image = image.permute(0, 2, 3, 1) + elif len(rgb.shape) == 3: + image = image.permute(1, 2, 0) + + if not keepchanneldim: + image = image.squeeze(-1) + return image + else: + + if weights is None: + weights = np.array([0.299, 0.587, 0.114]) + assert len(weights) == 3 + + if len(rgb.shape) == 4: + image = np.tensordot(rgb, weights, axes=((3,), 0)) + elif len(rgb.shape) == 3: + image = np.tensordot(rgb, weights, axes=((2,), 0)) + else: + raise ValueError("Input must be at least 3D.") + + if keepchanneldim: + return image[..., np.newaxis] + else: + return image + def gamma_correction(vals, gamma=2.2): """ diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 57c4f740..1b2b234f 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -1,3 +1,11 @@ +# ############################################################################# +# io.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + import warnings from PIL import Image import cv2 @@ -6,7 +14,7 @@ from lensless.utils.plot import plot_image from lensless.hardware.constants import RPI_HQ_CAMERA_BLACK_LEVEL, RPI_HQ_CAMERA_CCM_MATRIX -from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray +from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val def load_image( @@ -22,6 +30,11 @@ def load_image( nbits_out=None, as_4d=False, downsample=None, + bg=None, + return_float=False, + shape=None, + dtype=None, + normalize=True, ): """ Load image as numpy array. @@ -53,6 +66,17 @@ def load_image( height, width, color). downsample : int, optional Downsampling factor. Recommended for image reconstruction. + bg : array_like + Background level to subtract. + return_float : bool + Whether to return image as float array, or unsigned int. + shape : tuple, optional + Shape (H, W, C) to resize to. + dtype : str, optional + Data type of returned data. Default is to use that of input. + normalize : bool, default True + If ``return_float``, whether to normalize data to maximum value of 1. 
+ Returns ------- img : :py:class:`~numpy.ndarray` @@ -103,6 +127,8 @@ def load_image( if len(img.shape) == 3: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + original_dtype = img.dtype + if flip: img = np.flipud(img) img = np.fliplr(img) @@ -110,14 +136,40 @@ def load_image( if verbose: print_image_info(img) + if bg is not None: + + # if bg is float vector, turn into int-valued vector + if bg.max() <= 1 and img.dtype not in [np.float32, np.float64]: + bg = bg * get_max_val(img) + + img = img - bg + img = np.clip(img, a_min=0, a_max=img.max()) + if as_4d: if len(img.shape) == 3: img = img[np.newaxis, :, :, :] elif len(img.shape) == 2: img = img[np.newaxis, :, :, np.newaxis] - if downsample is not None: - img = resize(img, factor=1 / downsample) + if downsample is not None or shape is not None: + if downsample is not None: + factor = 1 / downsample + else: + factor = None + img = resize(img, factor=factor, shape=shape) + + if return_float: + if dtype is None: + dtype = np.float32 + assert dtype == np.float32 or dtype == np.float64 + img = img.astype(dtype) + if normalize: + img /= img.max() + + else: + if dtype is None: + dtype = original_dtype + img = img.astype(dtype) return img @@ -212,6 +264,7 @@ def load_psf( ) original_dtype = psf.dtype + max_val = get_max_val(psf) psf = np.array(psf, dtype=dtype) if use_3d: @@ -274,6 +327,7 @@ def load_psf( if return_float: # psf /= psf.max() psf /= np.linalg.norm(psf.ravel()) + bg /= max_val else: psf = psf.astype(original_dtype) @@ -286,6 +340,7 @@ def load_psf( def load_data( psf_fp, data_fp, + return_float=True, downsample=None, bg_pix=(5, 25), plot=True, @@ -300,6 +355,7 @@ def load_data( shape=None, torch=False, torch_device="cpu", + normalize=False, ): """ Load data for image reconstruction. @@ -310,6 +366,8 @@ def load_data( Full path to PSF file. data_fp : str Full path to measurement file. + return_float : bool, optional + Whether to return PSF as float array, or unsigned int. downsample : int or float Downsampling factor. bg_pix : tuple, optional @@ -336,6 +394,8 @@ def load_data( Whether to sum RGB channels into single PSF, same across channels. Done in "Learned reconstructions for practical mask-based lensless imaging" of Kristina Monakhova et. al. + normalize : bool default True + Whether to normalize data to maximum value of 1. Returns ------- @@ -365,7 +425,7 @@ def load_data( psf, bg = load_psf( psf_fp, downsample=downsample, - return_float=True, + return_float=return_float, bg_pix=bg_pix, return_bg=True, flip=flip, @@ -379,21 +439,22 @@ def load_data( ) # load and process raw measurement - data = load_image(data_fp, flip=flip, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain) - data = np.array(data, dtype=dtype) - - data -= bg - data = np.clip(data, a_min=0, a_max=data.max()) - - if len(data.shape) == 3: - data = data[np.newaxis, :, :, :] - elif len(data.shape) == 2: - data = data[np.newaxis, :, :, np.newaxis] + data = load_image( + data_fp, + flip=flip, + bayer=bayer, + blue_gain=blue_gain, + red_gain=red_gain, + bg=bg, + as_4d=True, + return_float=return_float, + shape=shape, + normalize=normalize, + ) if data.shape != psf.shape: # in DiffuserCam dataset, images are already reshaped data = resize(data, shape=psf.shape) - data /= np.linalg.norm(data.ravel()) if data.shape[3] > 1 and psf.shape[3] == 1: warnings.warn( @@ -454,3 +515,58 @@ def save_image(img, fp, max_val=255): img = Image.fromarray(img) img.save(fp) + + +def get_dtype(dtype=None, is_torch=False): + """ + Get dtype for numpy or torch. 
+ + Parameters + ---------- + dtype : str, optional + "float32" or "float64", Default is "float32". + is_torch : bool, optional + Whether to return torch dtype. + """ + if dtype is None: + dtype = "float32" + assert dtype == "float32" or dtype == "float64" + + if is_torch: + import torch + + if dtype is None: + if is_torch: + dtype = torch.float32 + else: + dtype = np.float32 + else: + if is_torch: + dtype = torch.float32 if dtype == "float32" else torch.float64 + else: + dtype = np.float32 if dtype == "float32" else np.float64 + + return dtype + + +def get_ctypes(dtype, is_torch): + if not is_torch: + if dtype == np.float32 or dtype == np.complex64: + return np.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return np.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) + else: + import torch + + if dtype == np.float32 or dtype == np.complex64: + return torch.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return torch.complex128, np.complex128 + elif dtype == torch.float32 or dtype == torch.complex64: + return torch.complex64, np.complex64 + elif dtype == torch.float64 or dtype == torch.complex128: + return torch.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) diff --git a/lensless/utils/simulation.py b/lensless/utils/simulation.py new file mode 100644 index 00000000..b77fabcb --- /dev/null +++ b/lensless/utils/simulation.py @@ -0,0 +1,163 @@ +# ############################################################################# +# simulation.py +# ================= +# Authors : +# Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + +from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp +import torch + + +class FarFieldSimulator(FarFieldSimulator_wp): + """ + LenslessPiCam-compatible wrapper for :py:class:`~waveprop.simulation.FarFieldSimulator` (source code on `GitHub `__). + """ + + def __init__( + self, + object_height, + scene2mask, + mask2sensor, + sensor, + psf=None, + output_dim=None, + snr_db=None, + max_val=255, + device_conv="cpu", + random_shift=False, + is_torch=False, + quantize=True, + **kwargs + ): + """ + Parameters + ---------- + psf : np.ndarray or torch.Tensor, optional. + Point spread function. If not provided, return image at object plane. + object_height : float or (float, float) + Height of object in meters. Or range of values to randomly sample from. + scene2mask : float + Distance from scene to mask in meters. + mask2sensor : float + Distance from mask to sensor in meters. + sensor : str + Sensor name. + snr_db : float, optional + Signal-to-noise ratio in dB, by default None. + max_val : int, optional + Maximum value of image, by default 255. + device_conv : str, optional + Device to use for convolution (when using pytorch), by default "cpu". + random_shift : bool, optional + Whether to randomly shift the image, by default False. + is_torch : bool, optional + Whether to use pytorch, by default False. + quantize : bool, optional + Whether to quantize image, by default True. 
+ """ + + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # drop depth dimension, and convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" + + super().__init__( + object_height, + scene2mask, + mask2sensor, + sensor, + psf, + output_dim, + snr_db, + max_val, + device_conv, + random_shift, + is_torch, + quantize, + **kwargs + ) + + if self.is_torch: + assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + assert ( + self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 + ), "PSF must have 1 or 3 channels" + + # save all the parameters in a dict + self.params = { + "object_height": object_height, + "scene2mask": scene2mask, + "mask2sensor": mask2sensor, + "sensor": sensor, + "output_dim": output_dim, + "snr_db": snr_db, + "max_val": max_val, + "device_conv": device_conv, + "random_shift": random_shift, + "is_torch": is_torch, + "quantize": quantize, + } + self.params.update(kwargs) + + def get_psf(self): + if self.is_torch: + # convert CHW to HWC + return self.psf.movedim(0, -1).unsqueeze(0) + else: + return self.psf[None, ...] + + # needs different name from parent class + def set_point_spread_function(self, psf): + """ + Set point spread function. + + Parameters + ---------- + psf : np.ndarray or torch.Tensor + Point spread function. + """ + assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" + + if torch.is_tensor(psf): + # convert HWC to CHW + psf = psf[0].movedim(-1, 0) + assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" + else: + psf = psf[0] + assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" + + return super().set_psf(psf) + + def propagate_image(self, obj, return_object_plane=False): + """ + Parameters + ---------- + obj : np.ndarray or torch.Tensor + Single image to propagate of format HWC. + return_object_plane : bool, optional + Whether to return object plane, by default False. 
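The wrapper's main job is bookkeeping between LenslessPiCam's `(depth, height, width, channels)` convention and waveprop's channel-first convention for PyTorch. The PSF reshaping it applies can be checked in isolation; this is only an illustration of the `movedim` calls above, not part of the library:

```python
# Illustration only: the PSF reshaping performed by the wrapper when is_torch=True.
import torch

psf_lensless = torch.rand(1, 120, 160, 3)        # LenslessPiCam format: (depth, H, W, C)
psf_waveprop = psf_lensless[0].movedim(-1, 0)    # drop depth, channels first: (3, 120, 160)
assert psf_waveprop.shape == (3, 120, 160)

# and back again, as in get_psf()
psf_back = psf_waveprop.movedim(0, -1).unsqueeze(0)   # (1, 120, 160, 3)
assert torch.equal(psf_back, psf_lensless)
```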
+ """ + + assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels" + + if self.is_torch: + # channel in first dimension as expected by waveprop for pytorch + obj = obj.moveaxis(-1, 0) + res = super().propagate(obj, return_object_plane) + if isinstance(res, tuple): + res = res[0].moveaxis(-3, -1), res[1].moveaxis(-3, -1) + else: + res = res.moveaxis(-3, -1) + return res + else: + # TODO: not tested, but normally don't need to move dimensions for numpy + res = super().propagate(obj, return_object_plane) + return res diff --git a/lensless/version.py b/lensless/version.py index 92192eed..68cdeee4 100644 --- a/lensless/version.py +++ b/lensless/version.py @@ -1 +1 @@ -__version__ = "1.0.4" +__version__ = "1.0.5" diff --git a/mask_requirements.txt b/mask_requirements.txt index ee87c51f..9e9c28a4 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,3 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.4 \ No newline at end of file +waveprop>=0.0.8 \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index 5d142936..0b90adf2 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -2,11 +2,11 @@ jedi==0.18.0 lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 -hydra-core click>=8.0.1 -waveprop>=0.0.3 # for simulation +waveprop>=0.0.8 # for simulation # Library for learning algorithm -torch >= 1.8.0 +torch >= 2.0.0 torchvision +torchmetrics lpips \ No newline at end of file diff --git a/scripts/classify/train_celeba_vit.py b/scripts/classify/train_celeba_vit.py new file mode 100644 index 00000000..79a32e44 --- /dev/null +++ b/scripts/classify/train_celeba_vit.py @@ -0,0 +1,330 @@ +""" +Fine-tune ViT on CelebA dataset measured with lensless camera. +Original tutorial: https://huggingface.co/blog/fine-tune-vit + +First, set-up HuggingFace libraries: +``` +pip install datasets transformers +``` + +Raw measurement datasets can be download from SwitchDrive. +This will be done by the script if the dataset is not found. +``` +# 10K measurements (13.1 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_10K + +# 1K measurements (1.2 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_1K +``` + +Note that the CelebA dataset also needs to be available locally! +It can be download here: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + +In order to classify on reconstructed outputs, the following +script needs to be run to create the dataset of reconstructed +images: +``` +# reconstruct with ADMM +python scripts/recon/dataset.py algo=admm \ +input.raw_data=path/to/raw/data +``` + +To classify on raw downsampled images, the same script can be +used, e.g. with the following command (`algo=null` for no reconstruction): +``` +python scripts/recon/dataset.py algo=null \ +input.raw_data=path/to/raw/data \ +preprocess.data_dim=[48,64] +``` + +Other hyperparameters for classification can be found in +`configs/train_celeba_classifier.yaml`. 
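For context, the core of the script pairs each measured PNG with a CelebA attribute label and builds a stratified train/test split with the HuggingFace `datasets` API. A condensed sketch under assumed paths (the CelebA root is a placeholder and "Male" stands in for the attribute chosen via `data.attr`):

```python
# Condensed sketch (not the full script): pair measured files with a CelebA
# attribute and make a stratified split. Paths and the "Male" attribute are
# placeholders; the script reads them from its Hydra config.
import glob
import pandas as pd
import torchvision.datasets as dset
import torchvision.transforms as transforms
from datasets import Dataset

measured_files = sorted(glob.glob("data/celeba_adafruit_random_2mm_20230720_1K/*.png"))
celeba = dset.CelebA(root="path/to/celeba_root", split="all", download=False,
                     transform=transforms.ToTensor())
label_idx = celeba.attr_names.index("Male")                       # example attribute
labels = celeba.attr[:, label_idx][: len(measured_files)].tolist()

ds = Dataset.from_pandas(pd.DataFrame({"labels": labels, "image_file_path": measured_files}))
ds = ds.class_encode_column("labels")
ds = ds.train_test_split(test_size=0.15, stratify_by_column="labels", seed=0, shuffle=True)
print(ds["train"].num_rows, ds["test"].num_rows)
```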
+ +""" + +import warnings +from transformers import ViTImageProcessor, ViTForImageClassification +from transformers import TrainingArguments, Trainer, TrainerCallback +import numpy as np +import torch +import os +from hydra.utils import to_absolute_path +import glob +import hydra +import random +from datasets import load_metric +from PIL import Image +import pandas as pd +import time +import torchvision.transforms as transforms +import torchvision.datasets as dset +from datasets import Dataset +from copy import deepcopy +from torchvision.transforms import ( + CenterCrop, + Compose, + Normalize, + RandomHorizontalFlip, + RandomResizedCrop, + Resize, + ToTensor, +) + + +class CustomCallback(TrainerCallback): + def __init__(self, trainer) -> None: + super().__init__() + self._trainer = trainer + + def on_epoch_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_step_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_train_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + +@hydra.main(version_base=None, config_path="../../configs", config_name="train_celeba_classifier") +def train_celeba_classifier(config): + + seed = config.seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + # check how many measured files + measured_dataset = to_absolute_path(config.data.measured) + if not os.path.isdir(measured_dataset): + print(f"No dataset found at {measured_dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the CelebA dataset measured with a random Adafruit LCD pattern (13.1 GB)?" 
+ + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/9NNGCJs3DoBDGlY/download" + filename = "celeba_adafruit_random_2mm_20230720_10K.zip" + download_and_extract_archive( + url, os.path.dirname(measured_dataset), filename=filename, remove_finished=True + ) + measured_files = sorted(glob.glob(os.path.join(measured_dataset, "*.png"))) + print(f"Found {len(measured_files)} files in {measured_dataset}") + + if config.data.n_files is not None: + n_files = config.data.n_files + measured_files = measured_files[: config.data.n_files] + print(f"Using {len(measured_files)} files") + n_files = len(measured_files) + + # create dataset split + attr = config.data.attr + ds = dset.CelebA( + root=config.data.original, + split="all", + download=False, + transform=transforms.ToTensor(), + ) + label_idx = ds.attr_names.index(attr) + labels = ds.attr[:, label_idx][:n_files] + + # make dataset with measured data and corresponding labels + df = pd.DataFrame( + { + "labels": labels, + "image_file_path": measured_files, + } + ) + ds = Dataset.from_pandas(df) + ds = ds.class_encode_column("labels") + + # -- train / test split + test_size = config.data.test_size + ds = ds.train_test_split( + test_size=test_size, stratify_by_column="labels", seed=seed, shuffle=True + ) + + # prepare dataset + model_name_or_path = "google/vit-base-patch16-224-in21k" + processor = ViTImageProcessor.from_pretrained(model_name_or_path) + + # -- processors for train and val + image_mean, image_std = processor.image_mean, processor.image_std + size = processor.size["height"] + + normalize = Normalize(mean=image_mean, std=image_std) + # _train_transforms = Compose( + # [ + # # RandomResizedCrop( + # # size, + # # scale=(0.9, 1.0), + # # ratio=(0.9, 1.1), + # # ), + # Resize(size), + # CenterCrop(size), + # RandomHorizontalFlip(), + # ToTensor(), + # normalize, + # ] + # ) + _train_transforms = [] + if config.augmentation.random_resize_crop: + _train_transforms.append( + RandomResizedCrop( + size, + scale=(0.9, 1.0), + ratio=(0.9, 1.1), + ) + ) + _train_transforms.append( + Resize(size), + CenterCrop(size), + ) + if config.augmentation.horizontal_flip: + if config.data.raw: + warnings.warn("Horizontal flip is not supported for raw data, Skipping!") + else: + _train_transforms.append(RandomHorizontalFlip()) + _train_transforms.append( + ToTensor(), + normalize, + ) + _train_transforms = Compose(_train_transforms) + + _val_transforms = Compose( + [ + Resize(size), + CenterCrop(size), + ToTensor(), + normalize, + ] + ) + + def train_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _train_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + def val_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _val_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + # transform dataset + ds["train"].set_transform(train_transforms) + ds["test"].set_transform(val_transforms) + + # data collator + def collate_fn(batch): + return { + "pixel_values": torch.stack([x["pixel_values"] for x in batch]), + "labels": torch.tensor([x["labels"] for x in batch]), + } + + # evaluation metric + metric = load_metric("accuracy") + + def compute_metrics(p): + return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids) + + # load model + if 
config.train.prev is not None: + model_path = to_absolute_path(config.train.prev) + else: + model_path = model_name_or_path + + labels = ds["train"].features["labels"].names + model = ViTForImageClassification.from_pretrained( + model_path, + num_labels=len(labels), + id2label={str(i): c for i, c in enumerate(labels)}, + label2id={c: str(i) for i, c in enumerate(labels)}, + hidden_dropout_prob=config.train.dropout, + attention_probs_dropout_prob=config.train.dropout, + ) + + # configure training + output_dir = ( + config.data.output_dir + f"-{config.data.attr}" + os.path.basename(measured_dataset) + ) + + training_args = TrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=config.train.batch_size, + evaluation_strategy="steps", + eval_steps=100, + save_steps=100, + num_train_epochs=config.train.n_epochs, + fp16=True, + logging_steps=10, + learning_rate=config.train.learning_rate, + save_total_limit=2, + remove_unused_columns=False, # important to keep False + push_to_hub=False, + report_to="tensorboard", + load_best_model_at_end=True, + ) + + trainer = Trainer( + model=model, + args=training_args, + data_collator=collate_fn, + compute_metrics=compute_metrics, + tokenizer=processor, + train_dataset=ds["train"], + eval_dataset=ds["test"], + ) + trainer.add_callback(CustomCallback(trainer)) # add accuracy on train set + + # train + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + + start_time = time.time() + train_results = trainer.train() + trainer.save_model() + trainer.log_metrics("train", train_results.metrics) + trainer.save_metrics("train", train_results.metrics) + trainer.save_state() + + # evaluate + metrics = trainer.evaluate(ds["train"]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + metrics = trainer.evaluate(ds["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + print(f"Training took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + train_celeba_classifier() diff --git a/scripts/demo.py b/scripts/demo.py index 32b26e42..760b663a 100644 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -18,7 +18,9 @@ @hydra.main(version_base=None, config_path="../configs", config_name="demo") def demo(config): - RPI_USERNAME, RPI_HOSTNAME = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + RPI_USERNAME = config.rpi.username + RPI_HOSTNAME = config.rpi.hostname display_fp = to_absolute_path(config.fp) if config.save: diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index de6a1c68..89a31309 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -20,9 +20,10 @@ import json import os import pathlib as plib -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless.utils.dataset import DiffuserCamTestDataset try: import torch @@ -43,9 +44,7 @@ def benchmark_recon(config): device = "cpu" # Benchmark dataset - benchmark_dataset = DiffuserCamTestDataset( - data_dir=os.path.join(get_original_cwd(), "data"), n_files=n_files, downsample=downsample - ) + benchmark_dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) psf = benchmark_dataset.psf.to(device) 
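With the dataset class moved to `lensless.utils.dataset`, a standalone benchmark run reduces to a few lines. A sketch assuming PyTorch and the DiffuserCam test files (downloaded automatically by the dataset class); the ADMM hyperparameters are left at their library defaults rather than taken from the benchmark config:

```python
# Sketch of a standalone benchmark with the relocated DiffuserCamTestDataset
# (test data is downloaded automatically). ADMM hyperparameters use library defaults.
import torch
from lensless import ADMM
from lensless.eval.benchmark import benchmark
from lensless.utils.dataset import DiffuserCamTestDataset

device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = DiffuserCamTestDataset(n_files=10, downsample=8)
psf = dataset.psf.to(device)

recon = ADMM(psf)
results = benchmark(recon, dataset, batchsize=1)
print(results)   # dict of averaged metrics, e.g. MSE, PSNR, SSIM, LPIPS
```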
model_list = [] # list of algoritms to benchmark diff --git a/scripts/hardware/config_digicam.py b/scripts/hardware/config_digicam.py new file mode 100644 index 00000000..cd8cab86 --- /dev/null +++ b/scripts/hardware/config_digicam.py @@ -0,0 +1,101 @@ +import warnings +import hydra +from datetime import datetime +import numpy as np +from slm_controller import slm +from slm_controller.hardware import SLMParam, slm_devices +import matplotlib.pyplot as plt + +from lensless.hardware.slm import set_programmable_mask +from lensless.hardware.aperture import rect_aperture, circ_aperture +from lensless.hardware.utils import set_mask_sensor_distance + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + device = config.device + + shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not slm_devices[device][SLMParam.MONOCHROME]: + shape = (3, *shape) + pixel_pitch = slm_devices[device][SLMParam.PIXEL_PITCH] + + # set mask to sensor distance + if config.z is not None and not config.virtual: + set_mask_sensor_distance(config.z, rpi_username, rpi_hostname) + + center = np.array(config.center) * pixel_pitch + + # create random pattern + pattern = None + if config.pattern.endswith(".npy"): + pattern = np.load(config.pattern) + elif config.pattern == "random": + rng = np.random.RandomState(1) + # pattern = rng.randint(low=0, high=np.iinfo(np.uint8).max, size=shape, dtype=np.uint8) + pattern = rng.uniform(low=config.min_val, high=1, size=shape) + pattern = (pattern * np.iinfo(np.uint8).max).astype(np.uint8) + + elif config.pattern == "rect": + rect_shape = config.rect_shape + apert_dim = rect_shape[0] * pixel_pitch[0], rect_shape[1] * pixel_pitch[1] + ap = rect_aperture( + apert_dim=apert_dim, + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + elif config.pattern == "circ": + ap = circ_aperture( + radius=config.radius * pixel_pitch[0], + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + else: + raise ValueError(f"Pattern {config.pattern} not supported") + + # save pattern + if not config.pattern.endswith(".npy") and config.save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pattern_fn = f"{device}_{config.pattern}_pattern_{timestamp}.npy" + np.save(pattern_fn, pattern) + print(f"Saved pattern to {pattern_fn}") + + print("Pattern shape : ", pattern.shape) + print("Pattern dtype : ", pattern.dtype) + print("Pattern min : ", pattern.min()) + print("Pattern max : ", pattern.max()) + + # apply aperture + if config.aperture is not None: + + aperture = np.zeros(shape, dtype=np.uint8) + top_left = np.array(config.aperture.center) - np.array(config.aperture.shape) // 2 + bottom_right = top_left + np.array(config.aperture.shape) + aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + pattern = pattern * aperture + + assert pattern is not None + + n_nonzero = np.count_nonzero(pattern) + print(f"Nonzero pixels: {n_nonzero}") + + if not config.virtual: + set_programmable_mask(pattern, device, rpi_username, rpi_hostname) + + # preview mask + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = slm.create(device) + s._show_preview(pattern) + plt.savefig("preview.png") + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/hardware/digicam_measure_psfs.py 
b/scripts/hardware/digicam_measure_psfs.py new file mode 100644 index 00000000..901d24cb --- /dev/null +++ b/scripts/hardware/digicam_measure_psfs.py @@ -0,0 +1,60 @@ +import numpy as np +from lensless.hardware.utils import set_mask_sensor_distance +import hydra +import os +from datetime import datetime +from PIL import Image + +SATURATION_THRESHOLD = 0.01 + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + + mask_sensor_distances = np.arange(9) * 0.1 + exposure_time = 5 + + timestamp = datetime.now().strftime("%Y%m%d") + + for i in range(len(mask_sensor_distances)): + + print(f"Mask sensor distance: {mask_sensor_distances[i]}mm") + mask_sensor_distance = mask_sensor_distances[i] + + # set the mask sensor distance + set_mask_sensor_distance(mask_sensor_distance, rpi_username, rpi_hostname) + + good_exposure = False + while not good_exposure: + + # measure PSF + output_folder = f"adafruit_psf_{mask_sensor_distance}mm__{timestamp}" + os.system( + f"python scripts/remote_capture.py -cn capture_bayer output={output_folder} rpi.username={rpi_username} rpi.hostname={rpi_hostname} capture.exp={exposure_time}" + ) + + # check for saturation + OUTPUT_FP = os.path.join(output_folder, "raw_data.png") + # -- load picture to check for saturation + img = np.array(Image.open(OUTPUT_FP)) + ratio = np.sum(img == 4095) / np.prod(img.shape) + print(f"Saturation ratio: {ratio}") + if ratio > SATURATION_THRESHOLD or ratio == 0: + + if ratio == 0: + print("Need to increase exposure time.") + else: + print("Need to decrease exposure time.") + + # enter new exposure time from keyboard + exposure_time = float(input("Enter new exposure time: ")) + + else: + good_exposure = True + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 7d58ed61..4de5dc4a 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -1,7 +1,7 @@ """ To be run on the Raspberry Pi! 
``` -python scripts/collect_dataset_on_device.py +python scripts/measure/collect_dataset_on_device.py ``` Note that the script is configured for the Raspberry Pi HQ camera @@ -173,7 +173,7 @@ def collect_dataset(config): display_image_path = config.display.output_fp rot90 = config.display.rot90 os.system( - f"python scripts/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" ) time.sleep(config.capture.delay) @@ -241,7 +241,7 @@ def collect_dataset(config): display_image_path = config.display.output_fp rot90 = config.display.rot90 os.system( - f"python scripts/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" ) print(f"decreasing screen brightness to {current_screen_brightness}") diff --git a/scripts/measure/remote_capture.py b/scripts/measure/remote_capture.py index 92f2033e..66210a86 100644 --- a/scripts/measure/remote_capture.py +++ b/scripts/measure/remote_capture.py @@ -32,7 +32,9 @@ def liveview(config): rgb = config.capture.rgb gray = config.capture.gray - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname legacy = config.capture.legacy nbits_out = config.capture.nbits_out fn = config.capture.raw_data_fn diff --git a/scripts/measure/remote_display.py b/scripts/measure/remote_display.py index f9ab3ed2..1be931a3 100644 --- a/scripts/measure/remote_display.py +++ b/scripts/measure/remote_display.py @@ -35,12 +35,15 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="demo") def remote_display(config): - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname fp = config.fp shape = np.array(config.display.screen_res) psf = config.display.psf black = config.display.black + white = config.display.white if psf: point_source = np.zeros(tuple(shape) + (3,)) @@ -58,12 +61,18 @@ def remote_display(config): im = Image.fromarray(point_source.astype("uint8"), "RGB") im.save(fp) + elif white: + point_source = np.ones(tuple(shape) + (3,)) * 255 + fp = "tmp_display.png" + im = Image.fromarray(point_source.astype("uint8"), "RGB") + im.save(fp) + """ processing on remote machine, less issues with copying """ # copy picture to Raspberry Pi print("\nCopying over picture...") display(fp=fp, rpi_username=username, rpi_hostname=hostname, **config.display) - if psf or black: + if psf or black or white: os.remove(fp) diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 17a88461..c84d5b92 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -13,8 +13,9 @@ import pathlib as 
plib import matplotlib.pyplot as plt import numpy as np -from lensless.utils.io import load_data +from lensless.utils.io import load_data, load_image from lensless import ADMM +from lensless.utils.plot import plot_image @hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") @@ -41,6 +42,8 @@ def admm(config): shape=config["preprocess"]["shape"], torch=config.torch, torch_device=config.torch_device, + bg_pix=config.preprocess.bg_pix, + normalize=config.preprocess.normalize, ) disp = config["display"]["disp"] @@ -51,22 +54,31 @@ def admm(config): if save: save = os.getcwd() + if save: + if config.torch: + org_data = data.cpu().numpy() + else: + org_data = data + ax = plot_image(org_data, gamma=config["display"]["gamma"]) + ax.set_title("Original measurement") + plt.savefig(plib.Path(save) / "lensless.png") + start_time = time.time() if not config.admm.unrolled: recon = ADMM(psf, **config.admm) else: assert config.torch, "Unrolled ADMM only works with torch" from lensless.recon.unrolled_admm import UnrolledADMM - import train_unrolled + import lensless.recon.utils - pre_process = train_unrolled.create_process_network( + pre_process, _ = lensless.recon.utils.create_process_network( network=config.admm.pre_process_model.network, - depth=config.admm.pre_process_depth.depth, + depth=config.admm.pre_process_model.depth, device=config.torch_device, ) - post_process = train_unrolled.create_process_network( + post_process, _ = lensless.recon.utils.create_process_network( network=config.admm.post_process_model.network, - depth=config.admm.post_process_depth.depth, + depth=config.admm.post_process_model.depth, device=config.torch_device, ) @@ -75,18 +87,28 @@ def admm(config): print("Loading checkpoint from : ", path) assert os.path.exists(path), "Checkpoint does not exist" recon.load_state_dict(torch.load(path, map_location=config.torch_device)) + recon.set_data(data) print(f"Setup time : {time.time() - start_time} s") start_time = time.time() if config.torch: with torch.no_grad(): - res = recon.apply( - disp_iter=disp, - save=save, - gamma=config["display"]["gamma"], - plot=config["display"]["plot"], - ) + if config.admm.unrolled: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + else: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) else: res = recon.apply( disp_iter=disp, @@ -104,7 +126,78 @@ def admm(config): if config["display"]["plot"]: plt.show() if save: + + if config.admm.unrolled: + # Save intermediate results + if res[1] is not None: + pre_processed_image = res[1].cpu().numpy() + ax = plot_image(pre_processed_image, gamma=config["display"]["gamma"]) + ax.set_title("Image after preprocessing") + plt.savefig(plib.Path(save) / "pre_processed.png") + + if res[2] is not None: + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + ax.set_title("Image prior to post-processing") + plt.savefig(plib.Path(save) / "pre_post_process.png") + np.save(plib.Path(save) / "final_reconstruction.npy", img) + + if config.input.original is not None: + original = load_image( + to_absolute_path(config.input.original), + flip=config["preprocess"]["flip"], + red_gain=config["preprocess"]["red_gain"], + blue_gain=config["preprocess"]["blue_gain"], + shape=img.shape[-3:], + ) + ax = plot_image(original, 
gamma=config["display"]["gamma"]) + ax.set_title("Ground truth image") + plt.savefig(plib.Path(save) / "original.png") + + # compute metrics + from torchmetrics.image import lpip, psnr + + lpips_func = lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True) + psnr_funct = psnr.PeakSignalNoiseRatio() + + img_torch = torch.from_numpy(img).squeeze(0) + original_torch = torch.from_numpy(original).unsqueeze(0) + + # channel as first dimension + img_torch = img_torch.movedim(-1, -3) + original_torch = original_torch.movedim(-1, -3) + + # normalize, TODO img max value is 14 which seems strange + img_torch = img_torch / torch.amax(img_torch) + + # compute metrics + lpips = lpips_func(img_torch, original_torch) + psnr = psnr_funct(img_torch, original_torch) + print(f"LPIPS : {lpips}") + print(f"PSNR : {psnr}") + + # If the recon algorithm is unrolled and has a preprocessing step, plot result without preprocessing + if config.admm.unrolled and recon.pre_process is not None: + recon.set_data(data) + recon.pre_process = None + with torch.no_grad(): + res = recon.apply( + disp_iter=disp, + save=False, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + output_intermediate=True, + ) + + img = res[0].cpu().numpy() + np.save(plib.Path(save) / "final_reconstruction_no_preprocessing.npy", img[0]) + ax = plot_image(img, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "final_reconstruction_no_preprocessing.png") + pre_post_process_image = res[2].cpu().numpy() + ax = plot_image(pre_post_process_image, gamma=config["display"]["gamma"]) + plt.savefig(plib.Path(save) / "pre_post_process_no_preprocessing.png") + print(f"Files saved to : {save}") diff --git a/scripts/recon/apgd_pycsou.py b/scripts/recon/apgd_pycsou.py index 878b378f..0bf236d0 100644 --- a/scripts/recon/apgd_pycsou.py +++ b/scripts/recon/apgd_pycsou.py @@ -17,7 +17,7 @@ import time import matplotlib.pyplot as plt from lensless.utils.io import load_data -from lensless import APGD +from lensless.recon.apgd import APGD import os import pathlib as plib @@ -28,7 +28,7 @@ log = logging.getLogger(__name__) -@hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") +@hydra.main(version_base=None, config_path="../../configs", config_name="apgd_l1") def apgd( config, ): diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py new file mode 100644 index 00000000..4c4192c5 --- /dev/null +++ b/scripts/recon/dataset.py @@ -0,0 +1,202 @@ +""" +Apply ADMM reconstruction to folder. 
+ +``` +python scripts/recon/dataset.py +``` + +To run APGD, use the following command: +``` +python scripts/recon/dataset.py algo=apgd +``` + +To just copy resized raw data, use the following command: +``` +python scripts/recon/dataset.py algo=null preprocess.data_dim=[48,64] +``` + +""" + +import hydra +from hydra.utils import to_absolute_path +import os +import time +import numpy as np +from lensless.utils.io import load_psf, load_image, save_image +from lensless import ADMM +import torch +import glob +from tqdm import tqdm +from lensless.recon.apgd import APGD +from joblib import Parallel, delayed + + +@hydra.main(version_base=None, config_path="../../configs", config_name="recon_dataset") +def admm_dataset(config): + + algo = config.algo + + # get raw data file paths + dataset = to_absolute_path(config.input.raw_data) + if not os.path.isdir(dataset): + print(f"No dataset found at {dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the sample CelebA dataset measured with a random Adafruit LCD pattern (1.2 GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/m89D1tFEfktQueS/download" + filename = "celeba_adafruit_random_2mm_20230720_1K.zip" + download_and_extract_archive( + url, os.path.dirname(dataset), filename=filename, remove_finished=True + ) + data_fps = sorted(glob.glob(os.path.join(dataset, "*.png"))) + if config.n_files is not None: + data_fps = data_fps[: config.n_files] + n_files = len(data_fps) + + # load PSF + psf_fp = to_absolute_path(config.input.psf) + flip = config.preprocess.flip + dtype = config.input.dtype + print("\nPSF:") + psf, bg = load_psf( + psf_fp, + verbose=True, + downsample=config.preprocess.downsample, + return_bg=True, + flip=flip, + dtype=dtype, + ) + print(f"Downsampled PSF shape: {psf.shape}") + + data_dim = None + if config.preprocess.data_dim is not None: + data_dim = tuple(config.preprocess.data_dim) + (psf.shape[-1],) + else: + data_dim = psf.shape + + # -- create output folder + output_folder = to_absolute_path(config.output_folder) + if algo == "apgd": + output_folder = output_folder + f"_apgd{config.apgd.max_iter}" + elif algo == "admm": + output_folder = output_folder + f"_admm{config.admm.n_iter}" + else: + output_folder = output_folder + "_raw" + output_folder = output_folder + f"_{data_dim[-3]}x{data_dim[-2]}" + os.makedirs(output_folder, exist_ok=True) + + # -- apply reconstruction + if algo == "apgd": + + start_time = time.time() + + def recover(i): + + # reconstruction object + recon = APGD(psf=psf, **config.apgd) + + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + data = data[0] # first depth + + # apply reconstruction + recon.set_data(data) + img = recon.apply( + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + # -- extract region of interest and save + if config.roi is not None: + roi = config.roi + img = img[roi[0] : roi[2], roi[1] : roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + n_jobs = config.apgd.n_jobs + if n_jobs > 1: + Parallel(n_jobs=n_jobs)(delayed(recover)(i) for i in range(n_files)) + else: + for i in tqdm(range(n_files)): + recover(i) + + else: + + if config.torch: + torch_dtype = torch.float32 + torch_device = 
config.torch_device + psf = torch.from_numpy(psf).type(torch_dtype).to(torch_device) + + # create reconstruction object + recon = None + if config.algo == "admm": + recon = ADMM(psf, **config.admm) + + # loop over files and apply reconstruction + start_time = time.time() + + for i in tqdm(range(n_files)): + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + + if config.torch: + data = torch.from_numpy(data).type(torch_dtype).to(torch_device) + + if recon is not None: + + # set data + recon.set_data(data) + + # apply reconstruction + res = recon.apply( + n_iter=config.admm.n_iter, + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + else: + + # copy resized raw data + res = data + + # save reconstruction as PNG + # -- take first depth + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + # -- extract region of interest + if config.roi is not None: + img = img[config.roi[0] : config.roi[2], config.roi[1] : config.roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + print(f"Processing time : {time.time() - start_time} s") + # time per file + print(f"Time per file : {(time.time() - start_time) / n_files} s") + print("Files saved to: ", output_folder) + + +if __name__ == "__main__": + admm_dataset() diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 883f1819..5cbee7bf 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# """ @@ -12,30 +13,77 @@ python scripts/recon/train_unrolled.py ``` +By default it uses the configuration from the file `configs/train_unrolledADMM.yaml`. 
+ +To train pre- and post-processing networks, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_pre-post-processing +``` + +To fine-tune the DiffuserCam PSF, use the following command: +``` +python scripts/recon/train_unrolled.py -cn fine-tune_PSF +``` + +To train a PSF from scratch with a simulated dataset, use the following command: +``` +python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +``` + """ -import math +import logging import hydra from hydra.utils import get_original_cwd import os import numpy as np import time -import matplotlib.pyplot as plt from lensless import UnrolledFISTA, UnrolledADMM -from waveprop.dataset_util import SimulatedPytorchDataset -from lensless.utils.image import rgb2gray -from lensless.eval.benchmark import benchmark, DiffuserCamTestDataset +from lensless.utils.dataset import ( + DiffuserCamMirflickr, + SimulatedFarFieldDataset, + SimulatedDatasetTrainableMask, +) +from torch.utils.data import Subset +import lensless.hardware.trainable_mask +from lensless.recon.utils import create_process_network +from lensless.utils.image import rgb2gray, is_grayscale +from lensless.utils.simulation import FarFieldSimulator +from lensless.recon.utils import Trainer import torch from torchvision import transforms, datasets -from tqdm import tqdm +from lensless.utils.io import load_psf +from lensless.utils.io import save_image +from lensless.utils.plot import plot_image +import matplotlib.pyplot as plt + +# A logger for this file +log = logging.getLogger(__name__) + + +def simulate_dataset(config): + + if config.torch_device == "cuda" and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" -try: - import json -except ImportError: - print("json package not found, metrics will not be saved") + # prepare PSF + psf_fp = os.path.join(get_original_cwd(), config.files.psf) + psf, _ = load_psf( + psf_fp, + downsample=config.files.downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + if config.files.diffusercam_psf: + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + psf = transform_BRG2RGB(torch.from_numpy(psf)) + # drop depth dimension + psf = psf.to(device) -def simulate_dataset(config, psf): # load dataset transforms_list = [transforms.ToTensor()] data_path = os.path.join(get_original_cwd(), "data") @@ -43,92 +91,126 @@ def simulate_dataset(config, psf): transforms_list.append(transforms.Grayscale()) transform = transforms.Compose(transforms_list) if config.files.dataset == "mnist": - ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "fashion_mnist": - ds = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.FashionMNIST( + root=data_path, train=True, download=True, transform=transform + ) + test_ds = datasets.FashionMNIST( + root=data_path, train=False, download=True, transform=transform + ) elif config.files.dataset == "cifar10": - ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) elif config.files.dataset == "CelebA": - ds = datasets.CelebA(root=data_path, 
split="train", download=True, transform=transform) + root = config.files.celeba_root + data_path = os.path.join(root, "celeba") + assert os.path.isdir( + data_path + ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" + train_ds = datasets.CelebA(root=root, split="train", download=False, transform=transform) + test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) else: raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") # convert PSF - if config.simulation.grayscale: + if config.simulation.grayscale and not is_grayscale(psf): psf = rgb2gray(psf) - if not isinstance(psf, torch.Tensor): - psf = transforms.ToTensor()(psf) - elif psf.shape[-1] == 3: - # Waveprop syntetic dataset expect C H W - psf = psf.permute(2, 0, 1) - - # batch_size = config.files.batch_size - batch_size = config.training.batch_size - n_files = config.files.n_files - device_conv = config.torch_device - target = config.target + + # prepare mask + mask = prep_trainable_mask(config, psf, grayscale=config.simulation.grayscale) # check if gpu is available + device_conv = config.torch_device if device_conv == "cuda" and torch.cuda.is_available(): device_conv = "cuda" else: device_conv = "cpu" + # create simulator + simulator = FarFieldSimulator( + psf=psf, + is_torch=True, + **config.simulation, + ) + # create Pytorch dataset and dataloader + n_files = config.files.n_files if n_files is not None: - ds = torch.utils.data.Subset(ds, np.arange(n_files)) - ds_prop = SimulatedPytorchDataset( - dataset=ds, psf=psf, device_conv=device_conv, target=target, **config.simulation - ) - ds_loader = torch.utils.data.DataLoader( - dataset=ds_prop, batch_size=batch_size, shuffle=True, pin_memory=(psf.device != "cpu") - ) - return ds_loader + train_ds = torch.utils.data.Subset(train_ds, np.arange(n_files)) + test_ds = torch.utils.data.Subset(test_ds, np.arange(n_files)) + if mask is None: + train_ds_prop = SimulatedFarFieldDataset( + dataset=train_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedFarFieldDataset( + dataset=test_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + else: + train_ds_prop = SimulatedDatasetTrainableMask( + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + test_ds_prop = SimulatedDatasetTrainableMask( + dataset=test_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + ) + return train_ds_prop, test_ds_prop, mask -def create_process_network(network, depth, device="cpu"): - if network == "DruNet": - from lensless.recon.utils import load_drunet - process = load_drunet( - os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True - ).to(device) - process_name = "DruNet" - elif network == "UnetRes": - from lensless.recon.drunet.network_unet import UNetRes - - n_channels = 3 - process = UNetRes( - in_nc=n_channels + 1, - out_nc=n_channels, - nc=[64, 128, 256, 512], - nb=depth, - act_mode="R", - downsample_mode="strideconv", - upsample_mode="convtranspose", - ).to(device) - process_name = "UnetRes_d" + str(depth) - else: - process = None - process_name = None +def 
prep_trainable_mask(config, psf, grayscale=False): + mask = None + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) + + if config.trainable_mask.initial_value == "random": + initial_mask = torch.rand_like(psf) + elif config.trainable_mask.initial_value == "psf": + initial_mask = psf.clone() + else: + raise ValueError( + f"Initial PSF value {config.trainable_mask.initial_value} not supported" + ) + + if config.trainable_mask.grayscale and not is_grayscale(initial_mask): + initial_mask = rgb2gray(initial_mask) + + mask = mask_class( + initial_mask, optimizer="Adam", lr=config.trainable_mask.mask_lr, grayscale=grayscale + ) - return (process, process_name) + return mask -def measure_gradient(model): - # return the L2 norm of the gradient - total_norm = 0.0 - for p in model.parameters(): - param_norm = p.grad.detach().data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm**0.5 - return total_norm +@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") +def train_unrolled(config): + disp = config.display.disp + if disp < 0: + disp = None + + save = config.save + if save: + save = os.getcwd() -@hydra.main(version_base=None, config_path="../../configs", config_name="unrolled_recon") -def train_unrolled( - config, -): if config.torch_device == "cuda" and torch.cuda.is_available(): print("Using GPU for training.") device = "cuda" @@ -136,35 +218,55 @@ def train_unrolled( print("Using CPU for training.") device = "cpu" - # torch.autograd.set_detect_anomaly(True) - - # if using a portrait dataset rotate the PSF - flip = config.files.dataset in ["CelebA"] - - # benchmarking dataset: - path = os.path.join(get_original_cwd(), "data") - benchmark_dataset = DiffuserCamTestDataset( - data_dir=path, downsample=config.simulation.downsample - ) - - psf = benchmark_dataset.psf.to(device) - background = benchmark_dataset.background + # load dataset and create dataloader + train_set = None + test_set = None + psf = None + if "DiffuserCam" in config.files.dataset: + + original_path = os.path.join(get_original_cwd(), config.files.dataset) + psf_path = os.path.join(get_original_cwd(), config.files.psf) + dataset = DiffuserCamMirflickr( + dataset_dir=original_path, + psf_path=psf_path, + downsample=config.files.downsample, + ) + dataset.psf = dataset.psf.to(device) + # train-test split as in https://waller-lab.github.io/LenslessLearning/dataset.html + # first 1000 files for test, the rest for training + train_indices = dataset.allowed_idx[dataset.allowed_idx > 1000] + test_indices = dataset.allowed_idx[dataset.allowed_idx <= 1000] + if config.files.n_files is not None: + train_indices = train_indices[: config.files.n_files] + test_indices = test_indices[: config.files.n_files] + + train_set = Subset(dataset, train_indices) + test_set = Subset(dataset, test_indices) + + # -- if learning mask + mask = prep_trainable_mask(config, dataset.psf) + if mask is not None: + # plot initial PSF + psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] 
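`prep_trainable_mask` resolves the mask class by name from `lensless.hardware.trainable_mask` and hands it the initial PSF together with its own optimizer settings. A sketch of that instantiation, assuming `TrainablePSF` is one of the available classes and using a random initialization as with `initial_value: random`:

```python
# Sketch of how a trainable mask is instantiated from the config: the class is
# looked up by name, then given the initial PSF and its optimizer settings.
# TrainablePSF and the learning rate are assumed values, not taken from a config.
import torch
import lensless.hardware.trainable_mask

mask_class = getattr(lensless.hardware.trainable_mask, "TrainablePSF")
initial_psf = torch.rand(1, 120, 160, 3)       # random init, as with initial_value: random
mask = mask_class(initial_psf, optimizer="Adam", lr=1e-3, grayscale=False)
psf = mask.get_psf()                           # PSF handed to the simulator / reconstruction
```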
+ if config.trainable_mask.grayscale: + psf_np = psf_np[:, :, -1] + + save_image(psf_np, os.path.join(save, "psf_initial.png")) + plot_image(psf_np, gamma=config.display.gamma) + plt.savefig(os.path.join(save, "psf_initial_plot.png")) + + psf = dataset.psf - # convert psf from BGR to RGB - if config.files.dataset in ["DiffuserCam"]: - psf = psf[..., [2, 1, 0]] + else: - # if using a portrait dataset rotate the PSF - if flip: - psf = torch.rot90(psf, dims=[0, 1]) + train_set, test_set, mask = simulate_dataset(config) + psf = train_set.psf - disp = config.display.disp - if disp < 0: - disp = None + assert train_set is not None + assert psf is not None - save = config.save - if save: - save = os.getcwd() + print("Train test size : ", len(train_set)) + print("Test test size : ", len(test_set)) start_time = time.time() @@ -180,6 +282,7 @@ def train_unrolled( config.reconstruction.post_process.depth, device=device, ) + # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( @@ -191,7 +294,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_fista.n_iter elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( psf, @@ -203,7 +305,6 @@ def train_unrolled( pre_process=pre_process, post_process=post_process, ).to(device) - n_iter = config.reconstruction.unrolled_admm.n_iter else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") @@ -215,196 +316,33 @@ def train_unrolled( algorithm_name += "_" + post_process_name # print number of parameters - print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters") - # transform from BGR to RGB - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - - # load dataset and create dataloader - if config.files.dataset == "DiffuserCam": - # Use a ParallelDataset - from lensless.eval.benchmark import ParallelDataset - - data_path = os.path.join(get_original_cwd(), "data", "DiffuserCam") - dataset = ParallelDataset( - root_dir=data_path, - n_files=config.files.n_files, - background=background, - psf=psf, - lensless_fn="diffuser_images", - lensed_fn="ground_truth_lensed", - downsample=config.simulation.downsample, - transform_lensless=transform_BRG2RGB, - transform_lensed=transform_BRG2RGB, - ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config.training.batch_size, - shuffle=True, - pin_memory=(device != "cpu"), - ) - else: - # Use a simulated dataset - data_loader = simulate_dataset(config, psf) + n_param = sum(p.numel() for p in recon.parameters()) + if mask is not None: + n_param += sum(p.numel() for p in mask.parameters()) + log.info(f"Training model with {n_param} parameters") print(f"Setup time : {time.time() - start_time} s") + print(f"PSF shape : {psf.shape}") + trainer = Trainer( + recon=recon, + train_dataset=train_set, + test_dataset=test_set, + mask=mask, + batch_size=config.training.batch_size, + loss=config.loss, + lpips=config.lpips, + l1_mask=config.trainable_mask.L1_strength, + optimizer=config.optimizer.type, + optimizer_lr=config.optimizer.lr, + slow_start=config.training.slow_start, + skip_NAN=config.training.skip_NAN, + algorithm_name=algorithm_name, + metric_for_best_model=config.training.metric_for_best_model, + save_every=config.training.save_every, + gamma=config.display.gamma, + ) - start_time = time.time() - - # loss - if config.loss == "l2": - Loss = torch.nn.MSELoss() - elif config.loss == 
"l1": - Loss = torch.nn.L1Loss() - else: - raise ValueError(f"Unsuported loss : {config.loss}") - - # Lpips loss - if config.lpips: - try: - import lpips - - loss_lpips = lpips.LPIPS(net="vgg").to(device) - except ImportError: - return ImportError( - "lpips package is need for LPIPS loss. Install using : pip install lpips" - ) - - # optimizer - if config.optimizer.type == "Adam": - # the parameters of the base model and non torch.Module process must be added separatly - parameters = [{"params": recon.parameters()}] - optimizer = torch.optim.Adam(parameters, lr=config.optimizer.lr) - else: - raise ValueError(f"Unsuported optimizer : {config.optimizer.type}") - # Scheduler - if config.training.slow_start: - - def learning_rate_function(epoch): - if epoch == 0: - return config.training.slow_start - elif epoch == 1: - return math.sqrt(config.training.slow_start) - else: - return 1 - - else: - - def learning_rate_function(epoch): - return 1 - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=learning_rate_function) - - metrics = { - "LOSS": [], - "MSE": [], - "MAE": [], - "LPIPS_Vgg": [], - "LPIPS_Alex": [], - "PSNR": [], - "SSIM": [], - "ReconstructionError": [], - "n_iter": n_iter, - "algorithm": algorithm_name, - } - - # Backward hook that detect NAN in the gradient and print the layer weights - if not config.training.skip_NAN: - - def detect_nan(grad): - if torch.isnan(grad).any(): - print(grad, flush=True) - for name, param in recon.named_parameters(): - if param.requires_grad: - print(name, param) - raise ValueError("Gradient is NaN") - return grad - - for param in recon.parameters(): - if param.requires_grad: - param.register_hook(detect_nan) - if param.requires_grad: - param.register_hook(detect_nan) - - # Training loop - for epoch in range(config.training.epoch): - print(f"Epoch {epoch} with learning rate {scheduler.get_last_lr()}") - mean_loss = 0.0 - i = 1.0 - pbar = tqdm(data_loader) - for X, y in pbar: - # send to device - X = X.to(device) - y = y.to(device) - if X.shape[3] == 3: - X = X - y = y - - y_pred = recon.batch_call(X.to(device)) - # normalizing each output - eps = 1e-12 - y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps - y_pred = y_pred / y_pred_max - - # normalizing y - y = y.to(device) - y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps - y = y / y_max - - if i % disp == 1 and config.display.plot: - img_pred = y_pred[0, 0].cpu().detach().numpy() - img_truth = y[0, 0].cpu().detach().numpy() - - plt.imshow(img_pred) - plt.savefig(f"y_pred_{i-1}.png") - plt.imshow(img_truth) - plt.savefig(f"y_{i-1}.png") - - optimizer.zero_grad(set_to_none=True) - # convert to CHW for loss and remove depth - y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) - y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) - - loss_v = Loss(y_pred, y) - if config.lpips: - # value for LPIPS needs to be in range [-1, 1] - loss_v = loss_v + config.lpips * torch.mean(loss_lpips(2 * y_pred - 1, 2 * y - 1)) - loss_v.backward() - torch.nn.utils.clip_grad_norm_(recon.parameters(), 1.0) - - # if any gradient is NaN, skip training step - is_NAN = False - for param in recon.parameters(): - if torch.isnan(param.grad).any(): - is_NAN = True - break - if is_NAN: - print("NAN detected in gradiant, skipping training step") - i += 1 - continue - optimizer.step() - - mean_loss += (loss_v.item() - mean_loss) * (1 / i) - pbar.set_description(f"loss : {mean_loss}") - i += 1 - - # benchmarking - current_metrics = benchmark(recon, benchmark_dataset, batchsize=10) 
-        # update metrics with current metrics
-        metrics["LOSS"].append(mean_loss)
-        for key in current_metrics:
-            metrics[key].append(current_metrics[key])
-
-        # Update learning rate
-        scheduler.step()
-
-    print(f"Train time : {time.time() - start_time} s")
-
-    # save dictionary metrics to file with json
-    with open(os.path.join(save, "metrics.json"), "w") as f:
-        json.dump(metrics, f)
-
-    # save pytorch model recon
-    torch.save(recon.state_dict(), "recon.pt")
+    trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)


 if __name__ == "__main__":
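The hand-written optimization loop deleted above now lives behind a single ``trainer.train(...)`` call. For orientation, below is a minimal, self-contained PyTorch sketch of the core step that loop performed (per-sample normalization, loss, gradient clipping, NaN skipping); the names are hypothetical and this is not the ``Trainer`` API.

```
import torch


def train_step(model, optimizer, loss_fn, X, y, eps=1e-12):
    """One step mirroring the removed loop: normalize prediction and target
    by their per-sample maxima, compute the loss, clip gradients, and skip
    the update if any gradient is NaN."""
    y_pred = model(X)
    y_pred = y_pred / (torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps)
    y = y / (torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps)

    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(y_pred, y)  # e.g. torch.nn.MSELoss() or torch.nn.L1Loss()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # skip the update if any gradient is NaN (as with the skip_NAN option)
    if any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
        return None
    optimizer.step()
    return loss.item()
```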
diff --git a/scripts/sim/dataset.py b/scripts/sim/dataset.py
index 2c08ba71..263d01f2 100644
--- a/scripts/sim/dataset.py
+++ b/scripts/sim/dataset.py
@@ -32,7 +32,7 @@ def simulate(config):
     if not os.path.isdir(dataset):
         print(f"No dataset found at {dataset}")
         try:
-            from torchvision.datasets.utils import download_and_extract_archive, download_url
+            from torchvision.datasets.utils import download_and_extract_archive
         except ImportError:
             exit()
         msg = "Do you want to download the sample CelebA dataset (764KB)?"
diff --git a/scripts/sim/digicam_psf.py b/scripts/sim/digicam_psf.py
new file mode 100644
index 00000000..d0e0636b
--- /dev/null
+++ b/scripts/sim/digicam_psf.py
@@ -0,0 +1,154 @@
+import numpy as np
+import os
+import time
+import hydra
+import torch
+from hydra.utils import to_absolute_path
+import matplotlib.pyplot as plt
+from slm_controller import slm
+from lensless.utils.io import save_image, get_dtype, load_psf
+from lensless.utils.plot import plot_image
+from lensless.hardware.sensor import VirtualSensor
+from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
+from waveprop.devices import slm_dict
+from PIL import Image
+
+
+@hydra.main(version_base=None, config_path="../../configs", config_name="sim_digicam_psf")
+def digicam_psf(config):
+
+    output_folder = os.getcwd()
+
+    fp = to_absolute_path(config.digicam.pattern)
+    bn = os.path.basename(fp).split(".")[0]
+
+    # digicam config
+    ap_center = np.array(config.digicam.ap_center)
+    ap_shape = np.array(config.digicam.ap_shape)
+    rotate_angle = config.digicam.rotate
+    slm_param = slm_dict[config.digicam.slm]
+    sensor = VirtualSensor.from_name(config.digicam.sensor)
+
+    # simulation parameters
+    scene2mask = config.sim.scene2mask
+    mask2sensor = config.sim.mask2sensor
+
+    torch_device = config.torch_device
+    dtype = get_dtype(config.dtype, config.use_torch)
+
+    """
+    Load pattern
+    """
+    pattern = np.load(fp)
+
+    # -- apply aperture
+    aperture = np.zeros(pattern.shape, dtype=np.uint8)
+    top_left = np.array(ap_center) - np.array(ap_shape) // 2
+    bottom_right = top_left + np.array(ap_shape)
+    aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1
+    pattern = pattern * aperture
+
+    # -- extract aperture region
+    idx_1 = ap_center[0] - ap_shape[0] // 2
+    idx_2 = ap_center[1] - ap_shape[1] // 2
+
+    pattern_sub = pattern[
+        :,
+        idx_1 : idx_1 + ap_shape[0],
+        idx_2 : idx_2 + ap_shape[1],
+    ]
+    print("Controllable region shape: ", pattern_sub.shape)
+    print("Total number of pixels: ", np.prod(pattern_sub.shape))
+
+    # -- plot full
+    s = slm.create(config.digicam.slm)
+    s.set_preview(True)
+    s.imshow(pattern)
+    plt.savefig(os.path.join(output_folder, "pattern.png"))
+
+    # -- plot sub pattern
+    plt.imshow(pattern_sub.transpose(1, 2, 0))
+    plt.savefig(os.path.join(output_folder, "pattern_sub.png"))
+
+    """
+    Simulate PSF
+    """
+    start_time = time.time()
+    slm_vals = pattern_sub / 255.0
+
+    if config.digicam.slm == "adafruit":
+        # flatten color channel along rows
+        slm_vals = slm_vals.reshape((-1, slm_vals.shape[-1]), order="F")
+
+    if config.use_torch:
+        slm_vals = torch.from_numpy(slm_vals).to(device=torch_device, dtype=dtype)
+    else:
+        slm_vals = slm_vals.astype(dtype)
+
+    mask = get_programmable_mask(
+        vals=slm_vals,
+        sensor=sensor,
+        slm_param=slm_param,
+        rotate=rotate_angle,
+        flipud=config.sim.flipud,
+    )
+
+    # -- plot mask
+    if config.use_torch:
+        mask_np = mask.cpu().detach().numpy()
+    else:
+        mask_np = mask.copy()
+    mask_np = np.transpose(mask_np, (1, 2, 0))
+    plt.imshow(mask_np)
+    plt.savefig(os.path.join(output_folder, "mask.png"))
+
+    # -- propagate to sensor
+    psf_in = get_intensity_psf(
+        mask=mask,
+        sensor=sensor,
+        waveprop=config.sim.waveprop,
+        scene2mask=scene2mask,
+        mask2sensor=mask2sensor,
+    )
+
+    # -- plot PSF
+    if config.use_torch:
+        psf_in_np = psf_in.cpu().detach().numpy()
+    else:
+        psf_in_np = psf_in.copy()
+    psf_in_np = np.transpose(psf_in_np, (1, 2, 0))
+
+    # plot
+    psf_meas = None
+    if config.digicam.psf is not None:
+        fp_psf = to_absolute_path(config.digicam.psf)
+        if os.path.exists(fp_psf):
+            psf_meas = load_psf(fp_psf)
+        else:
+            print("Could not load PSF image from: ", fp_psf)
+
+    fp = os.path.join(output_folder, "psf_plot.png")
+    if psf_meas is not None:
+        _, ax = plt.subplots(1, 2)
+        ax[0].imshow(psf_in_np)
+        ax[0].set_title("Simulated")
+        plot_image(psf_meas, gamma=config.digicam.gamma, normalize=True, ax=ax[1])
+        # ax[1].imshow(psf_meas)
+        ax[1].set_title("Measured")
+        plt.savefig(fp)
+    else:
+        plt.imshow(psf_in_np)
+        plt.savefig(fp)
+
+    # save PSF as png
+    fp = os.path.join(output_folder, f"{bn}_SIM_psf.png")
+    save_image(psf_in_np, fp)
+
+    proc_time = time.time() - start_time
+    print(f"\nProcessing time: {proc_time:.2f} seconds")
+
+    print(f"\nFiles saved to : {output_folder}")
+
+
+if __name__ == "__main__":
+    digicam_psf()
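Before any wave propagation, the new ``scripts/sim/digicam_psf.py`` essentially masks the full SLM pattern with a rectangular aperture and slices out the controllable region. A standalone NumPy sketch of just that step, with made-up pattern and aperture sizes, is:

```
import numpy as np

# hypothetical (color, rows, cols) pattern and aperture settings
pattern = np.random.randint(0, 255, (3, 128, 160), dtype=np.uint8)
ap_center = np.array([64, 80])
ap_shape = np.array([32, 40])

# zero out everything outside the aperture
aperture = np.zeros(pattern.shape, dtype=np.uint8)
top_left = ap_center - ap_shape // 2
bottom_right = top_left + ap_shape
aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1
pattern = pattern * aperture

# slice out the controllable sub-region and rescale to [0, 1]
pattern_sub = pattern[
    :,
    top_left[0] : top_left[0] + ap_shape[0],
    top_left[1] : top_left[1] + ap_shape[1],
]
slm_vals = pattern_sub / 255.0
print(slm_vals.shape)  # (3, 32, 40)
```

The resulting values are what get handed to the mask/PSF simulation, either as a NumPy array or as a torch tensor depending on ``use_torch``.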
diff --git a/scripts/sim/mask_single_file.py b/scripts/sim/mask_single_file.py
index e8a741b5..8513e75c 100644
--- a/scripts/sim/mask_single_file.py
+++ b/scripts/sim/mask_single_file.py
@@ -19,6 +19,11 @@ python scripts/sim/mask_single_file.py mask.type=MURA mask.n_bits=99 simulation.flatcam=True recon.algo=tikhonov
 ```

+Using Torch
+```
+python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=True recon.algo=tikhonov use_torch=True
+```
+
 Simulate FlatCam with PSF simulation and Tikhonov reconstuction:
 ```
 python scripts/sim/mask_single_file.py mask.type=MLS simulation.flatcam=False recon.algo=tikhonov
@@ -56,6 +61,7 @@ import os

 from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture
 from lensless.recon.tikhonov import CodedApertureReconstruction
+import torch


 @hydra.main(version_base=None, config_path="../../configs", config_name="mask_sim_single")
@@ -107,6 +113,9 @@ def simulate(config):

     # 2) simulate measurement
     image = load_image(fp, verbose=True) / 255
+    if config.use_torch:
+        image = image.transpose(2, 0, 1)
+        image = torch.from_numpy(image).float()

     flatcam_sim = config.simulation.flatcam
     if flatcam_sim and mask_type.upper() not in ["MURA", "MLS"]:
@@ -116,17 +125,29 @@
         flatcam_sim = False

     # use far field simulator to get correct object plane sizing
+    psf = mask.psf
+    if config.use_torch:
+        psf = psf.transpose(2, 0, 1)
+        psf = torch.from_numpy(psf).float()
+
     simulator = FarFieldSimulator(
-        psf=mask.psf,
+        psf=psf,
         object_height=object_height,
         scene2mask=scene2mask,
         mask2sensor=mask2sensor,
         sensor=sensor,
         snr_db=snr_db,
         max_val=max_val,
+        is_torch=config.use_torch,
     )
     image_plane, object_plane = simulator.propagate(image, return_object_plane=True)

+    # channels as last dimension
+    if config.use_torch:
+        image_plane = image_plane.permute(1, 2, 0)
+        object_plane = object_plane.permute(1, 2, 0)
+        image = image.permute(1, 2, 0)
+
     if image_format == "grayscale":
         image_plane = rgb2gray(image_plane)
         object_plane = rgb2gray(object_plane)
@@ -178,6 +199,12 @@ def simulate(config):
     else:
         raise ValueError(f"Reconstruction algorithm {config.recon.algo} not recognized.")

+    # back to numpy for evaluation and plotting
+    if config.use_torch:
+        recovered = recovered.numpy()
+        object_plane = object_plane.numpy()
+        image_plane = image_plane.numpy()
+
     # 4) evaluate
     if image_format == "grayscale":
         object_plane = object_plane[:, :, 0]
@@ -218,7 +245,7 @@
     ax[4].set_title("Reconstruction")

     for a in ax:
-        a.set_xticks([]), a.set_yticks([])
+        a.set_axis_off()

     plt.tight_layout()
     plt.savefig("result.png")
diff --git a/setup.py b/setup.py
index 20c07f7a..392fc7fe 100644
--- a/setup.py
+++ b/setup.py
@@ -6,15 +6,16 @@
 exec(f.read())
 assert __version__ is not None

-with open("README.rst", "r", encoding="utf-8") as fh:
-    long_description = fh.read()
+# with open("README.rst", "r", encoding="utf-8") as fh:
+#     long_description = fh.read()
+long_description = "See the documentation at https://lensless.readthedocs.io/en/latest/"

 setuptools.setup(
     name="lensless",
     version=__version__,
     author="Eric Bezzam",
     author_email="ebezzam@gmail.com",
-    description="Package to control and image with a lensless camera running on a Raspberry Pi.",
+    description="All-in-one package for lensless imaging: design, simulation, measurement, reconstruction.",
     long_description=long_description,
     long_description_content_type="text/x-rst",
     url="https://github.com/LCAV/LenslessPiCam",
@@ -32,6 +33,7 @@
         "matplotlib>=3.4.2",
         "rawpy>=0.16.0",
         "paramiko>=3.2.0",
+        "hydra-core",
     ],
     extra_requires={"dev": ["pudb", "black"]},
 )
diff --git a/test/test_io.py b/test/test_io.py
index 16823e1f..5c2f8884 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -1,5 +1,4 @@
-from lensless.utils.io import load_data
-import numpy as np
+from lensless.utils.io import load_data, rgb2gray

 psf_fp = "data/psf/tape_rgb.png"
 data_fp = "data/raw_data/thumbs_up_rgb.png"
@@ -26,4 +25,31 @@ def test_load_data():
         assert data.dtype == dtype, dtype


-test_load_data()
+def test_rgb2gray():
+    for is_torch in [True, False]:
+        psf, data = load_data(
+            psf_fp=psf_fp,
+            data_fp=data_fp,
+            downsample=downsample,
+            plot=False,
+            dtype="float32",
+            torch=is_torch,
+        )
+        data = data[0]  # drop first depth dimension
+
+        # try with 4D
+        psf_gray = rgb2gray(psf, keepchanneldim=False)
+        assert len(psf_gray.shape) == 3
+        psf_gray = rgb2gray(psf, keepchanneldim=True)
+        assert len(psf_gray.shape) == 4
+
+        # try with 3D
+        data_gray = rgb2gray(data, keepchanneldim=False)
+        assert len(data_gray.shape) == 2
+        data_gray = rgb2gray(data, keepchanneldim=True)
+        assert len(data_gray.shape) == 3
+
+
+if __name__ == "__main__":
+    test_load_data()
+    test_rgb2gray()
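The new ``test_rgb2gray`` above exercises ``rgb2gray`` on both 3D and 4D inputs. A quick usage sketch with a dummy array (output shapes inferred from the assertions in the test, not measured):

```
import numpy as np

from lensless.utils.io import rgb2gray

# dummy (height, width, color) image; (depth, height, width, color) arrays also work
img = np.random.rand(32, 32, 3).astype(np.float32)

gray = rgb2gray(img, keepchanneldim=False)
print(gray.shape)  # 2D, e.g. (32, 32)

gray = rgb2gray(img, keepchanneldim=True)
print(gray.shape)  # 3D, color dimension kept as a singleton
```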
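Relatedly, the torch path added to ``scripts/sim/mask_single_file.py`` above boils down to a channels-first round trip around the simulator: channels-last NumPy in, channels-first tensor for simulation, channels-last NumPy back out for evaluation and plotting. A minimal sketch of that conversion with a dummy image (not the script itself):

```
import numpy as np
import torch

# dummy (height, width, color) image, as loaded from disk
image = np.random.rand(64, 64, 3).astype(np.float32)

# to a channels-first tensor for the torch-based simulator
image_t = torch.from_numpy(image.transpose(2, 0, 1)).float()

# ... simulation would operate on image_t here ...

# back to channels-last NumPy for evaluation and plotting
image_out = image_t.permute(1, 2, 0).numpy()
assert image_out.shape == (64, 64, 3)
```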