Skip to content

Commit

Permalink
Add flag for direct background subtraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Aug 12, 2024
1 parent 542fe9a commit 31637ea
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 19 deletions.
7 changes: 7 additions & 0 deletions configs/train_mirflickr_tape_ambient.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ files:
dataset: Lensless/TapeCam-Mirflickr-Ambient
image_res: [600, 600]

reconstruction:
direct_background_subtraction: True

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
3 changes: 3 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ reconstruction:
init_pre: True # if `init_processors`, set pre-procesor is available
init_post: True # if `init_processors`, set post-procesor is available

# background subtraction (if dataset has corresponding background images)
direct_background_subtraction: False

# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
# Number of iterations
Expand Down
6 changes: 5 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ def benchmark(

flip_lr = None
flip_ud = None
background = None
lensless = batch[0].to(device)
lensed = batch[1].to(device)
if dataset.measured_bg:
background = batch[-1].to(device)
if dataset.multimask or dataset.random_flip:
psfs = batch[2]
psfs = psfs.to(device)
Expand All @@ -146,11 +149,12 @@ def benchmark(
plot=False,
save=False,
output_intermediate=unrolled_output_factor or pre_process_aux,
background=background,
**kwargs,
)

else:
prediction = model.forward(lensless, psfs, **kwargs)
prediction = model.forward(lensless, psfs, background=background, **kwargs)

if unrolled_output_factor or pre_process_aux:
pre_process_out = prediction[2]
Expand Down
10 changes: 9 additions & 1 deletion lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
legacy_denoiser=False,
compensation=None,
compensation_residual=True,
direct_background_subtraction=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -104,6 +105,7 @@ def __init__(
self.skip_unrolled = skip_unrolled
self.skip_pre = skip_pre
self.skip_post = skip_post
self.direct_background_subtraction = direct_background_subtraction
self.return_intermediate = return_intermediate
self.compensation_branch = compensation
if compensation is not None:
Expand Down Expand Up @@ -216,7 +218,7 @@ def unfreeze_post_process(self):
for param in self.post_process_model.parameters():
param.requires_grad = True

def forward(self, batch, psfs=None):
def forward(self, batch, psfs=None, background=None):
"""
Method for performing iterative reconstruction on a batch of images.
This implementation is a properly vectorized implementation of FISTA.
Expand All @@ -237,6 +239,12 @@ def forward(self, batch, psfs=None):
assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)"
batch_size = batch.shape[0]

if self.direct_background_subtraction:
assert (
background is not None
), "If direct_background_subtraction is True, background must be defined."
self._data = self._data - background

if psfs is not None:
# assert same shape
assert psfs.shape == batch.shape, "psfs must have the same shape as batch"
Expand Down
26 changes: 11 additions & 15 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,18 +879,18 @@ def train_epoch(self, data_loader):
# get batch
flip_lr = None
flip_ud = None
if self.train_random_flip:
X, y, psfs, flip_lr, flip_ud = batch
psfs = psfs.to(self.device)
elif self.train_multimask:
X, y, psfs = batch
psfs = psfs.to(self.device)
background = None
X = batch[0].to(self.device)
y = batch[1].to(self.device)
if self.background:
background = batch[-1].to(self.device)
if self.train_random_flip or self.train_multimask:
psfs = batch[2].to(self.device)
else:
if self.background:
X, y, background = batch
else:
X, y = batch
psfs = None
if self.train_random_flip:
flip_lr = batch[3]
flip_ud = batch[4]

random_rotate = False
if self.random_rotate:
Expand All @@ -904,17 +904,13 @@ def train_epoch(self, data_loader):
else:
psfs = rotate_HWC(psfs, random_rotate)

# send to device
X = X.to(self.device)
y = y.to(self.device)

# update psf according to mask
if self.use_mask:
self.recon._set_psf(self.mask.get_psf().to(self.device))

# forward pass
# torch.autograd.set_detect_anomaly(True) # for debugging
y_pred = self.recon.forward(batch=X.unsqueeze(1), psfs=psfs)
y_pred = self.recon.forward(batch=X, psfs=psfs, background=background)
if self.unrolled_output_factor or self.pre_proc_aux:
y_pred, camera_inv_out, pre_proc_out = y_pred[0], y_pred[1], y_pred[2]

Expand Down
4 changes: 2 additions & 2 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,9 +1330,9 @@ def __init__(
simulation_config : dict, optional
Simulation parameters for PSF if using a mask pattern.
bg_snr_range : list, optional
List [low, high] of range of possible SNRs for which to add the background. Used in conjunction with 'bg'
List [low, high] of range of possible SNRs for which to add the background. Used in conjunction with 'bg'.
bg_fp : string, optional
File path of background to add to the data for simulating a measurement in ambient light
File path of background to add to the data for simulating a measurement in ambient light.
"""

Expand Down
8 changes: 8 additions & 0 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ def train_learned(config):
if name1 in dict_params2_post:
dict_params2_post[name1].data.copy_(param1.data)

if config.reconstruction.direct_background_subtraction:
assert test_set.measured_bg and train_set.measured_bg

# create reconstruction algorithm
if config.reconstruction.init is not None:
assert config.reconstruction.init_processors is None
Expand Down Expand Up @@ -526,6 +529,7 @@ def train_learned(config):
),
compensation=config.reconstruction.compensation,
compensation_residual=config.reconstruction.compensation_residual,
direct_background_subtraction=config.reconstruction.direct_background_subtraction,
)
elif config.reconstruction.method == "unrolled_admm":
recon = UnrolledADMM(
Expand All @@ -543,6 +547,7 @@ def train_learned(config):
),
compensation=config.reconstruction.compensation,
compensation_residual=config.reconstruction.compensation_residual,
direct_background_subtraction=config.reconstruction.direct_background_subtraction,
)
elif config.reconstruction.method == "trainable_inv":
assert config.trainable_mask.mask_type == "TrainablePSF"
Expand All @@ -554,6 +559,7 @@ def train_learned(config):
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
),
direct_background_subtraction=config.reconstruction.direct_background_subtraction,
)
elif config.reconstruction.method == "multi_wiener":

Expand All @@ -563,6 +569,8 @@ def train_learned(config):
else:
psf_channels = 3

assert config.reconstruction.direct_background_subtraction is False, "Not supported"

recon = MultiWiener(
in_channels=3,
out_channels=3,
Expand Down

0 comments on commit 31637ea

Please sign in to comment.