Skip to content

Commit

Permalink
Add option for adding noise.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jan 16, 2024
1 parent 0fad9ef commit af4942f
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ files:
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
test_size: 0.15

input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB
vertical_shift: null
horizontal_shift: null
crop: null
Expand Down Expand Up @@ -126,7 +127,7 @@ training:
skip_NAN: True
clip_grad: 1.0

crop_preloss: True # crop region for computing loss
crop_preloss: False # crop region for computing loss, files.crop should be set

optimizer:
type: Adam
Expand Down
10 changes: 10 additions & 0 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
flip=False,
transform_lensless=None,
transform_lensed=None,
input_snr=None,
**kwargs,
):
"""
Expand All @@ -72,11 +73,14 @@ def __init__(
Transform to apply to the lensless images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
transform_lensed : PyTorch Transform or None, optional
Transform to apply to the lensed images, by default ``None``. Note that this transform is applied on HWC images (different from torchvision).
input_snr : float, optional
If not ``None``, Poisson noise is added to the lensless images to match the given SNR.
"""
if isinstance(indices, int):
indices = range(indices)
self.indices = indices
self.background = background
self.input_snr = input_snr
self.downsample = downsample
self.flip = flip
self.transform_lensless = transform_lensless
Expand Down Expand Up @@ -147,6 +151,12 @@ def __getitem__(self, idx):
if self.background is not None:
lensless = lensless - self.background

# add noise
if self.input_snr is not None:
from waveprop.noise import add_shot_noise

lensless = add_shot_noise(lensless, self.input_snr)

# flip image x and y if needed
if self.flip:
lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
Expand Down
2 changes: 1 addition & 1 deletion mask_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sympy>=1.11.1
perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016
waveprop>=0.0.9
waveprop>=0.0.10
slm_controller @ git+https://github.com/ebezzam/slm-controller.git
2 changes: 1 addition & 1 deletion recon_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ lpips==0.1.4
pylops==1.18.0
scikit-image>=0.19.0rc0
click>=8.0.1
waveprop>=0.0.9 # for simulation
waveprop>=0.0.10 # for simulation

# Library for learning algorithm
torch >= 2.0.0
Expand Down
8 changes: 6 additions & 2 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def train_unrolled(config):
dataset_dir=original_path,
psf_path=psf_path,
downsample=config.files.downsample,
input_snr=config.files.input_snr,
)
dataset.psf = dataset.psf.to(device)
# train-test split as in https://waller-lab.github.io/LenslessLearning/dataset.html
Expand Down Expand Up @@ -404,6 +405,7 @@ def train_unrolled(config):
horizontal_shift=config.files.horizontal_shift,
simulation_config=config.simulation,
crop=config.files.crop,
input_snr=config.files.input_snr,
)
crop = dataset.crop
dataset.psf = dataset.psf.to(device)
Expand Down Expand Up @@ -486,6 +488,7 @@ def train_unrolled(config):

# -- plot lensed and res on top of each other
if config.training.crop_preloss:
assert crop is not None

res_np = res_np[
crop["vertical"][0] : crop["vertical"][1],
Expand All @@ -511,15 +514,16 @@ def train_unrolled(config):

start_time = time.time()

# Load pre process model
# Load pre-process model
pre_process, pre_process_name = create_process_network(
config.reconstruction.pre_process.network,
config.reconstruction.pre_process.depth,
nc=config.reconstruction.pre_process.nc,
device=device,
)
pre_proc_delay = config.reconstruction.pre_process.delay
# Load post process model

# Load post-process model
post_process, post_process_name = create_process_network(
config.reconstruction.post_process.network,
config.reconstruction.post_process.depth,
Expand Down

0 comments on commit af4942f

Please sign in to comment.