diff --git a/configs/train_mirflickr_tape_ambient.yaml b/configs/train_mirflickr_tape_ambient.yaml index 458d055f..1495d9e1 100644 --- a/configs/train_mirflickr_tape_ambient.yaml +++ b/configs/train_mirflickr_tape_ambient.yaml @@ -11,7 +11,14 @@ files: dataset: Lensless/TapeCam-Mirflickr-Ambient image_res: [600, 600] +reconstruction: + direct_background_subtraction: True + alignment: # when there is no downsampling top_left: [85, 185] # height, width height: 178 + +optimizer: + type: AdamW + cosine_decay_warmup: True diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 238bfaa0..1c9b6469 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -90,6 +90,9 @@ reconstruction: init_pre: True # if `init_processors`, set pre-procesor is available init_post: True # if `init_processors`, set post-procesor is available + # background subtraction (if dataset has corresponding background images) + direct_background_subtraction: False + # Hyperparameters for each method unrolled_fista: # for unrolled_fista # Number of iterations diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 41d43ab4..4b9ecc4f 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -121,8 +121,11 @@ def benchmark( flip_lr = None flip_ud = None + background = None lensless = batch[0].to(device) lensed = batch[1].to(device) + if dataset.measured_bg: + background = batch[-1].to(device) if dataset.multimask or dataset.random_flip: psfs = batch[2] psfs = psfs.to(device) @@ -146,11 +149,12 @@ def benchmark( plot=False, save=False, output_intermediate=unrolled_output_factor or pre_process_aux, + background=background, **kwargs, ) else: - prediction = model.forward(lensless, psfs, **kwargs) + prediction = model.forward(lensless, psfs, background=background, **kwargs) if unrolled_output_factor or pre_process_aux: pre_process_out = prediction[2] diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index f106633d..2e764631 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -58,6 +58,7 @@ def __init__( legacy_denoiser=False, compensation=None, compensation_residual=True, + direct_background_subtraction=False, **kwargs, ): """ @@ -104,6 +105,7 @@ def __init__( self.skip_unrolled = skip_unrolled self.skip_pre = skip_pre self.skip_post = skip_post + self.direct_background_subtraction = direct_background_subtraction self.return_intermediate = return_intermediate self.compensation_branch = compensation if compensation is not None: @@ -216,7 +218,7 @@ def unfreeze_post_process(self): for param in self.post_process_model.parameters(): param.requires_grad = True - def forward(self, batch, psfs=None): + def forward(self, batch, psfs=None, background=None): """ Method for performing iterative reconstruction on a batch of images. This implementation is a properly vectorized implementation of FISTA. @@ -237,6 +239,12 @@ def forward(self, batch, psfs=None): assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)" batch_size = batch.shape[0] + if self.direct_background_subtraction: + assert ( + background is not None + ), "If direct_background_subtraction is True, background must be defined." + self._data = self._data - background + if psfs is not None: # assert same shape assert psfs.shape == batch.shape, "psfs must have the same shape as batch" diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 9cfb9c53..c013c22c 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -879,18 +879,18 @@ def train_epoch(self, data_loader): # get batch flip_lr = None flip_ud = None - if self.train_random_flip: - X, y, psfs, flip_lr, flip_ud = batch - psfs = psfs.to(self.device) - elif self.train_multimask: - X, y, psfs = batch - psfs = psfs.to(self.device) + background = None + X = batch[0].to(self.device) + y = batch[1].to(self.device) + if self.background: + background = batch[-1].to(self.device) + if self.train_random_flip or self.train_multimask: + psfs = batch[2].to(self.device) else: - if self.background: - X, y, background = batch - else: - X, y = batch psfs = None + if self.train_random_flip: + flip_lr = batch[3] + flip_ud = batch[4] random_rotate = False if self.random_rotate: @@ -904,17 +904,13 @@ def train_epoch(self, data_loader): else: psfs = rotate_HWC(psfs, random_rotate) - # send to device - X = X.to(self.device) - y = y.to(self.device) - # update psf according to mask if self.use_mask: self.recon._set_psf(self.mask.get_psf().to(self.device)) # forward pass # torch.autograd.set_detect_anomaly(True) # for debugging - y_pred = self.recon.forward(batch=X.unsqueeze(1), psfs=psfs) + y_pred = self.recon.forward(batch=X, psfs=psfs, background=background) if self.unrolled_output_factor or self.pre_proc_aux: y_pred, camera_inv_out, pre_proc_out = y_pred[0], y_pred[1], y_pred[2] diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5832115d..bcf14817 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1330,9 +1330,9 @@ def __init__( simulation_config : dict, optional Simulation parameters for PSF if using a mask pattern. bg_snr_range : list, optional - List [low, high] of range of possible SNRs for which to add the background. Used in conjunction with 'bg' + List [low, high] of range of possible SNRs for which to add the background. Used in conjunction with 'bg'. bg_fp : string, optional - File path of background to add to the data for simulating a measurement in ambient light + File path of background to add to the data for simulating a measurement in ambient light. """ diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 3d847cd7..09547d47 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -492,6 +492,9 @@ def train_learned(config): if name1 in dict_params2_post: dict_params2_post[name1].data.copy_(param1.data) + if config.reconstruction.direct_background_subtraction: + assert test_set.measured_bg and train_set.measured_bg + # create reconstruction algorithm if config.reconstruction.init is not None: assert config.reconstruction.init_processors is None @@ -526,6 +529,7 @@ def train_learned(config): ), compensation=config.reconstruction.compensation, compensation_residual=config.reconstruction.compensation_residual, + direct_background_subtraction=config.reconstruction.direct_background_subtraction, ) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( @@ -543,6 +547,7 @@ def train_learned(config): ), compensation=config.reconstruction.compensation, compensation_residual=config.reconstruction.compensation_residual, + direct_background_subtraction=config.reconstruction.direct_background_subtraction, ) elif config.reconstruction.method == "trainable_inv": assert config.trainable_mask.mask_type == "TrainablePSF" @@ -554,6 +559,7 @@ def train_learned(config): return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False ), + direct_background_subtraction=config.reconstruction.direct_background_subtraction, ) elif config.reconstruction.method == "multi_wiener": @@ -563,6 +569,8 @@ def train_learned(config): else: psf_channels = 3 + assert config.reconstruction.direct_background_subtraction is False, "Not supported" + recon = MultiWiener( in_channels=3, out_channels=3,