From 937a59ccb43e95d3bef5c81d4a70c4fb7809f45d Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 16 Jul 2024 23:47:09 +0000 Subject: [PATCH] Add noise, fix return for extract roi. --- lensless/utils/dataset.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 9d78a895..76c6e6fa 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -30,6 +30,7 @@ from lensless.hardware.sensor import sensor_dict, SensorParam from scipy.ndimage import rotate import warnings +from waveprop.noise import add_shot_noise from PIL import Image @@ -169,8 +170,6 @@ def __getitem__(self, idx): # add noise if self.input_snr is not None: - from waveprop.noise import add_shot_noise - lensless = add_shot_noise(lensless, self.input_snr) # flip image x and y if needed @@ -1037,6 +1036,7 @@ def __init__( sensor="rpi_hq", slm="adafruit", simulation_config=dict(), + snr_db=40, **kwargs, ): """ @@ -1131,6 +1131,7 @@ def __init__( self.crop = None self.random_flip = None self.flipud = flipud + self.snr_db = snr_db self.display_res = display_res self.alignment = None @@ -1193,6 +1194,11 @@ def _get_images_pair(self, idx): mask_label = self.dataset[idx]["mask_label"] self.convolver.set_psf(self.psf[mask_label]) lensless = self.convolver.convolve(lensed) + + # add noise + if self.snr_db is not None: + lensless = add_shot_noise(lensless, self.snr_db) + if lensless.max() > 1: print("CLIPPING!") lensless /= lensless.max() @@ -1252,6 +1258,8 @@ def extract_roi(self, reconstruction, lensed=None, axis=(1, 2), **kwargs): else: reconstruction = reconstruction[0] + if lensed is None: + return reconstruction return reconstruction, lensed