diff --git a/configs/train_tapecam_measured_background.yaml b/configs/train_tapecam_measured_background.yaml new file mode 100644 index 00000000..b76b485f --- /dev/null +++ b/configs/train_tapecam_measured_background.yaml @@ -0,0 +1,19 @@ +# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape +defaults: + - train_mirflickr_tape + - _self_ + +wandb_project: +device_ids: + + +# Dataset +files: + dataset: Lensless/TapeCam-Mirflickr-Ambient + background_snr_range: [0,0] + image_res: [507, 380] + +alignment: + # when there is no downsampling + top_left: [45, 95] # height, width + height: 250 diff --git a/configs/train_tapecam_simulated_background.yaml b/configs/train_tapecam_simulated_background.yaml index 9335d626..c415e6cc 100644 --- a/configs/train_tapecam_simulated_background.yaml +++ b/configs/train_tapecam_simulated_background.yaml @@ -9,6 +9,5 @@ device_ids: # Dataset files: - dataset: Lensless/TapeCam-Mirflickr-Ambient - background_fp: + background_fp: "" background_snr_range: [0,0] \ No newline at end of file diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index ade69966..05dc4ebe 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1276,7 +1276,6 @@ def __init__( flip_lensed=False, downsample=1, downsample_lensed=1, - downsample_background=1, display_res=None, sensor="rpi_hq", slm="adafruit", @@ -1310,8 +1309,6 @@ def __init__( If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False. downsample : float, optional Downsample factor of the lensless images, by default 1. - downsample : float, optional - Downsample factor of the background images, by default 1. downsample_lensed : float, optional Downsample factor of the lensed images, by default 1. display_res : tuple, optional @@ -1362,7 +1359,6 @@ def __init__( data_0 = self.dataset[0] self.downsample_lensless = downsample self.downsample_lensed = downsample_lensed - self.downsample_background = downsample_background lensless = np.array(data_0["lensless"]) if "ambient" in data_0.keys(): self.measured_bg = True @@ -1593,7 +1589,7 @@ def _get_images_pair(self, idx): background_np = ( resize( background_np, - factor=1 / self.downsample_background, + factor=1 / self.downsample, interpolation=cv2.INTER_NEAREST, ) if not None @@ -1638,32 +1634,18 @@ def _get_images_pair(self, idx): lensed = resize( lensed_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST ) - background = ( - resize(background_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST) - if not None - else None - ) elif self.downsample_lensed != 1.0: lensed = resize( lensed_np, factor=1 / self.downsample_lensed, interpolation=cv2.INTER_NEAREST, ) - background = ( - resize( - background_np, - factor=1 / self.downsample_lensed, - interpolation=cv2.INTER_NEAREST, - ) - if not None - else None - ) return lensless, lensed, background if background is not None else None def __getitem__(self, idx): lensless, lensed, background = self._get_images_pair(idx) - if not self.simulate_lensless: # TODO apply transformation to bg as well? + if not self.simulate_lensless: if self.rotate: lensless = torch.rot90(lensless, dims=(-3, -2), k=2) if self.flipud: @@ -1732,7 +1714,6 @@ def __getitem__(self, idx): # If measured background available in the dataset return it elif self.measured_bg: return_items.append(background) - # TODO push data to gpu in the training loop and hvae a flag that for the simple subtraction does that without pushing the bg to the gpu return return_items def extract_roi( diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 0aa65db7..7eb21862 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -395,8 +395,8 @@ def train_learned(config): rotate_angle, shift, ) - # save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png") - save_image(lensed, f"lensed_{_idx}.png") + save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png") + # save_image(lensed, f"lensed_{_idx}.png") if test_set.bg_sim is not None: # Reconstruct and plot background subtracted image reconstruct_save( @@ -660,14 +660,14 @@ def reconstruct_save( ): recon = ADMM(psf_recon) - # recon.set_data(lensless.to(psf_recon.device)) - recon.set_data(torch.from_numpy(lensless).to(psf_recon.device)) + recon.set_data(lensless.to(psf_recon.device)) + # recon.set_data(torch.from_numpy(lensless).to(psf_recon.device)) res = recon.apply(disp_iter=None, plot=False, n_iter=10) res_np = res[0].cpu().numpy() res_np = res_np / res_np.max() lensed_np = lensed[0] # .cpu().numpy() - lensless_np = lensless # [0]#.cpu().numpy() + lensless_np = lensless.cpu().numpy() # [0]#.cpu().numpy() save_image(lensless_np, f"lensless_raw_{_idx}.png") # -- plot lensed and res on top of each other