Skip to content

Commit

Permalink
Add noise, fix return for extract roi.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jul 16, 2024
1 parent 5263765 commit 937a59c
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lensless.hardware.sensor import sensor_dict, SensorParam
from scipy.ndimage import rotate
import warnings
from waveprop.noise import add_shot_noise
from PIL import Image


Expand Down Expand Up @@ -169,8 +170,6 @@ def __getitem__(self, idx):

# add noise
if self.input_snr is not None:
from waveprop.noise import add_shot_noise

lensless = add_shot_noise(lensless, self.input_snr)

# flip image x and y if needed
Expand Down Expand Up @@ -1037,6 +1036,7 @@ def __init__(
sensor="rpi_hq",
slm="adafruit",
simulation_config=dict(),
snr_db=40,
**kwargs,
):
"""
Expand Down Expand Up @@ -1131,6 +1131,7 @@ def __init__(
self.crop = None
self.random_flip = None
self.flipud = flipud
self.snr_db = snr_db

self.display_res = display_res
self.alignment = None
Expand Down Expand Up @@ -1193,6 +1194,11 @@ def _get_images_pair(self, idx):
mask_label = self.dataset[idx]["mask_label"]
self.convolver.set_psf(self.psf[mask_label])
lensless = self.convolver.convolve(lensed)

# add noise
if self.snr_db is not None:
lensless = add_shot_noise(lensless, self.snr_db)

if lensless.max() > 1:
print("CLIPPING!")
lensless /= lensless.max()
Expand Down Expand Up @@ -1252,6 +1258,8 @@ def extract_roi(self, reconstruction, lensed=None, axis=(1, 2), **kwargs):
else:
reconstruction = reconstruction[0]

if lensed is None:
return reconstruction
return reconstruction, lensed


Expand Down

0 comments on commit 937a59c

Please sign in to comment.