From af4942fd59071a0c8d37edffd881906f7905358e Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 16 Jan 2024 17:05:45 +0100 Subject: [PATCH] Add option for adding noise. --- configs/train_unrolledADMM.yaml | 3 ++- lensless/utils/dataset.py | 10 ++++++++++ mask_requirements.txt | 2 +- recon_requirements.txt | 2 +- scripts/recon/train_unrolled.py | 8 ++++++-- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 3ead25e4..883f4089 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -17,6 +17,7 @@ files: downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution test_size: 0.15 + input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB vertical_shift: null horizontal_shift: null crop: null @@ -126,7 +127,7 @@ training: skip_NAN: True clip_grad: 1.0 - crop_preloss: True # crop region for computing loss + crop_preloss: False # crop region for computing loss, files.crop should be set optimizer: type: Adam diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 6cd20cd0..7fef2e17 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -49,6 +49,7 @@ def __init__( flip=False, transform_lensless=None, transform_lensed=None, + input_snr=None, **kwargs, ): """ @@ -72,11 +73,14 @@ def __init__( Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). transform_lensed : PyTorch Transform or None, optional Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision). + input_snr : float, optional + If not ``None``, Poisson noise is added to the lensless images to match the given SNR. """ if isinstance(indices, int): indices = range(indices) self.indices = indices self.background = background + self.input_snr = input_snr self.downsample = downsample self.flip = flip self.transform_lensless = transform_lensless @@ -147,6 +151,12 @@ def __getitem__(self, idx): if self.background is not None: lensless = lensless - self.background + # add noise + if self.input_snr is not None: + from waveprop.noise import add_shot_noise + + lensless = add_shot_noise(lensless, self.input_snr) + # flip image x and y if needed if self.flip: lensless = torch.rot90(lensless, dims=(-3, -2), k=2) diff --git a/mask_requirements.txt b/mask_requirements.txt index 548d6559..43958461 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,4 +1,4 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.9 +waveprop>=0.0.10 slm_controller @ git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file diff --git a/recon_requirements.txt b/recon_requirements.txt index 29fafbf0..0a2ff942 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -3,7 +3,7 @@ lpips==0.1.4 pylops==1.18.0 scikit-image>=0.19.0rc0 click>=8.0.1 -waveprop>=0.0.9 # for simulation +waveprop>=0.0.10 # for simulation # Library for learning algorithm torch >= 2.0.0 diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 8b4b2867..8f49e19b 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -366,6 +366,7 @@ def train_unrolled(config): dataset_dir=original_path, psf_path=psf_path, downsample=config.files.downsample, + input_snr=config.files.input_snr, ) dataset.psf = dataset.psf.to(device) # train-test split as in https://waller-lab.github.io/LenslessLearning/dataset.html @@ -404,6 +405,7 @@ def train_unrolled(config): horizontal_shift=config.files.horizontal_shift, simulation_config=config.simulation, crop=config.files.crop, + input_snr=config.files.input_snr, ) crop = dataset.crop dataset.psf = dataset.psf.to(device) @@ -486,6 +488,7 @@ def train_unrolled(config): # -- plot lensed and res on top of each other if config.training.crop_preloss: + assert crop is not None res_np = res_np[ crop["vertical"][0] : crop["vertical"][1], @@ -511,7 +514,7 @@ def train_unrolled(config): start_time = time.time() - # Load pre process model + # Load pre-process model pre_process, pre_process_name = create_process_network( config.reconstruction.pre_process.network, config.reconstruction.pre_process.depth, @@ -519,7 +522,8 @@ def train_unrolled(config): device=device, ) pre_proc_delay = config.reconstruction.pre_process.delay - # Load post process model + + # Load post-process model post_process, post_process_name = create_process_network( config.reconstruction.post_process.network, config.reconstruction.post_process.depth,