Now supports a dataset that also has background images
StefanPetersTM committed Aug 9, 2024
1 parent 7dd98ce commit df1f8bd
Showing 5 changed files with 82 additions and 21 deletions.
3 changes: 2 additions & 1 deletion configs/train_tapecam_simulated_background.yaml
@@ -9,5 +9,6 @@ device_ids:

# Dataset
files:
background_fp: ""
dataset: Lensless/TapeCam-Mirflickr-Ambient
background_fp:
background_snr_range: [0,0]
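
The new config keys point the training script at a Hugging Face dataset that ships measured background frames; the dataset code later in this commit looks for an "ambient" column in each record. A minimal sketch of inspecting such a dataset (the split name is an assumption for illustration; the column names come from the code in this commit):

from datasets import load_dataset

# Load the ambient-background dataset referenced in the config above.
# The split name here is an assumption, not taken from this diff.
ds = load_dataset("Lensless/TapeCam-Mirflickr-Ambient", split="test")
sample = ds[0]
# Records are expected to expose "lensless", "lensed" and "ambient" images,
# matching the keys accessed in lensless/utils/dataset.py below.
print(sample.keys())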
1 change: 1 addition & 0 deletions lensless/eval/benchmark.py
@@ -150,6 +150,7 @@ def benchmark(
)

else:
lensless = lensless.unsqueeze(1)
prediction = model.forward(lensless, psfs, **kwargs)

if unrolled_output_factor or pre_process_aux:
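The added unsqueeze(1) gives the lensless batch an explicit depth axis before it is passed to the model, assuming the loader yields (batch, height, width, channel) tensors and the reconstruction models expect (batch, depth, height, width, channel). A tiny sketch of the shape change (the sizes are made up):

import torch

lensless = torch.rand(4, 380, 507, 3)  # hypothetical (batch, H, W, C) measurement batch
lensless = lensless.unsqueeze(1)       # insert a singleton depth axis at dim 1
print(lensless.shape)                  # torch.Size([4, 1, 380, 507, 3])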
11 changes: 9 additions & 2 deletions lensless/recon/utils.py
@@ -613,6 +613,10 @@ def __init__(
self.train_random_flip = train_dataset.random_flip
self.random_rotate = random_rotate
self.random_shift = random_shift
if hasattr(train_dataset, "measured_bg"):
self.background = train_dataset.measured_bg
else:
self.background = False
if self.random_shift:
raise NotImplementedError("Random shift not implemented yet.")

@@ -882,7 +886,10 @@ def train_epoch(self, data_loader):
X, y, psfs = batch
psfs = psfs.to(self.device)
else:
X, y = batch
if self.background:
X, y, background = batch
else:
X, y = batch
psfs = None

random_rotate = False
@@ -907,7 +914,7 @@

# forward pass
# torch.autograd.set_detect_anomaly(True) # for debugging
y_pred = self.recon.forward(batch=X, psfs=psfs)
y_pred = self.recon.forward(batch=X.unsqueeze(1), psfs=psfs)
if self.unrolled_output_factor or self.pre_proc_aux:
y_pred, camera_inv_out, pre_proc_out = y_pred[0], y_pred[1], y_pred[2]

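The Trainer discovers background support by duck typing: any training dataset that exposes a truthy measured_bg attribute is assumed to yield (lensless, lensed, background) batches, which train_epoch then unpacks. A toy sketch of that contract (the dataset class and shapes below are hypothetical; only the attribute name and the unpacking mirror the code above):

import torch
from torch.utils.data import Dataset, DataLoader

class ToyAmbientDataset(Dataset):
    measured_bg = True  # the flag the Trainer checks with hasattr(...)

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        lensless = torch.rand(380, 507, 3)
        lensed = torch.rand(200, 300, 3)
        background = torch.rand(380, 507, 3)
        return lensless, lensed, background

loader = DataLoader(ToyAmbientDataset(), batch_size=4)
for batch in loader:
    if getattr(loader.dataset, "measured_bg", False):
        X, y, background = batch
    else:
        X, y = batch
    print(X.shape, y.shape, background.shape)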
78 changes: 64 additions & 14 deletions lensless/utils/dataset.py
@@ -349,7 +349,7 @@ def __init__(

if len(self.files) == 0:
raise FileNotFoundError(
f"No files found in {self.measured_dir} with extension {self.measurement_ext }"
f"No files found in {self.measured_dir} with extension {self.measurement_ext}"
)

# check that corresponding files exist
@@ -1276,6 +1276,7 @@ def __init__(
flip_lensed=False,
downsample=1,
downsample_lensed=1,
downsample_background=1,
display_res=None,
sensor="rpi_hq",
slm="adafruit",
@@ -1309,6 +1310,8 @@ def __init__(
If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False.
downsample : float, optional
Downsample factor of the lensless images, by default 1.
downsample_background : float, optional
Downsample factor of the background images, by default 1.
downsample_lensed : float, optional
Downsample factor of the lensed images, by default 1.
display_res : tuple, optional
@@ -1359,7 +1362,12 @@ def __init__(
data_0 = self.dataset[0]
self.downsample_lensless = downsample
self.downsample_lensed = downsample_lensed
self.downsample_background = downsample_background
lensless = np.array(data_0["lensless"])
if "ambient" in data_0.keys():
self.measured_bg = True
else:
self.measured_bg = False

if self.downsample_lensless != 1.0:
lensless = resize(lensless, factor=1 / self.downsample_lensless)
@@ -1541,10 +1549,10 @@ def __len__(self):

def _get_images_pair(self, idx):

# load image
# load images
lensless_np = np.array(self.dataset[idx]["lensless"])
lensed_np = np.array(self.dataset[idx]["lensed"])

background_np = np.array(self.dataset[idx]["ambient"]) if self.measured_bg else None
if self.force_rgb:
if len(lensless_np.shape) == 2:
warnings.warn(f"Converting lensless[{idx}] to RGB")
@@ -1560,23 +1568,41 @@ def _get_images_pair(self, idx):
elif len(lensed_np.shape) == 3:
pass

if background_np is not None and len(background_np.shape) == 2:
warnings.warn(f"Converting background[{idx}] to RGB")
background_np = np.stack([background_np] * 3, axis=2)

# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
lensed_np = lensed_np.astype(np.float32) / 255
background_np = background_np.astype(np.float32) / 255 if background_np is not None else None
else:
# 16 bit
lensless_np = lensless_np.astype(np.float32) / 65535
lensed_np = lensed_np.astype(np.float32) / 65535
background_np = background_np.astype(np.float32) / 65535 if background_np is not None else None

# downsample if necessary
if self.downsample_lensless != 1.0:
lensless_np = resize(
lensless_np, factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST
)
background_np = (
resize(
background_np,
factor=1 / self.downsample_background,
interpolation=cv2.INTER_NEAREST,
)
if background_np is not None
else None
)

lensless = lensless_np
lensed = lensed_np
background = background_np

if self.simulator is not None:
# convert to torch
@@ -1599,22 +1625,45 @@
shape=(self.alignment["height"], self.alignment["width"], 3),
interpolation=cv2.INTER_NEAREST,
)
background = (
resize(
background_np,
shape=(self.alignment["height"], self.alignment["width"], 3),
interpolation=cv2.INTER_NEAREST,
)
if background_np is not None
else None
)
elif self.display_res is not None:
lensed = resize(
lensed_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST
)
background = (
resize(background_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST)
if background_np is not None
else None
)
elif self.downsample_lensed != 1.0:
lensed = resize(
lensed_np,
factor=1 / self.downsample_lensed,
interpolation=cv2.INTER_NEAREST,
)
background = (
resize(
background_np,
factor=1 / self.downsample_lensed,
interpolation=cv2.INTER_NEAREST,
)
if background_np is not None
else None
)

return lensless, lensed
return lensless, lensed, background

def __getitem__(self, idx):
lensless, lensed = super().__getitem__(idx)
if not self.simulate_lensless:
lensless, lensed, background = self._get_images_pair(idx)
if not self.simulate_lensless: # TODO apply transformation to bg as well?
if self.rotate:
lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
if self.flipud:
@@ -1644,10 +1693,12 @@ def __getitem__(self, idx):
lensless = torch.flip(lensless, dims=(-2,))
lensed = torch.flip(lensed, dims=(-2,))
psf_aug = torch.flip(psf_aug, dims=(-2,))
background = torch.flip(background, dims=(-2,)) if background is not None else None
if flip_ud:
lensless = torch.flip(lensless, dims=(-3,))
lensed = torch.flip(lensed, dims=(-3,))
psf_aug = torch.flip(psf_aug, dims=(-3,))
background = torch.flip(background, dims=(-3,)) if background is not None else None

return_items = [lensless, lensed]
if self.multimask:
@@ -1675,14 +1726,13 @@ def __getitem__(self, idx):
# Add background noise to the target image
image_with_bg = lensless + scaled_bg

return image_with_bg, lensed, scaled_bg
else:
return lensless, lensed

# add simulated background to get image_with_bg and scaled_bg
return_items[0] = image_with_bg
return_items[0].append(scaled_bg)

# Add simulated background to get image_with_bg and scaled_bg
return_items[0] = image_with_bg
return_items.append(scaled_bg)
# If a measured background is available in the dataset, return it
elif self.measured_bg:
return_items.append(background)
# TODO: push data to the GPU in the training loop and have a flag so that the simple subtraction does that without pushing the bg to the GPU
return return_items

def extract_roi(
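Note that __getitem__ now returns a list whose length depends on the dataset flags: [lensless, lensed], plus the PSF when multimask is set, plus either the simulated scaled_bg or the measured background. A rough helper sketch of the possible layouts (this function is illustrative and not part of the repository):

def unpack_item(items, multimask=False, bg_sim=None, measured_bg=False):
    # items is the list produced by the __getitem__ above
    lensless, lensed = items[0], items[1]
    psf = items[2] if multimask else None
    background = None
    if bg_sim is not None or measured_bg:
        background = items[-1]  # simulated scaled_bg or measured ambient frame
    return lensless, lensed, psf, background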
10 changes: 6 additions & 4 deletions scripts/recon/train_learning_based.py
@@ -395,7 +395,8 @@ def train_learned(config):
rotate_angle,
shift,
)
save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
# save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
save_image(lensed, f"lensed_{_idx}.png")
if test_set.bg_sim is not None:
# Reconstruct and plot background subtracted image
reconstruct_save(
@@ -659,13 +660,14 @@
):
recon = ADMM(psf_recon)

recon.set_data(lensless.to(psf_recon.device))
# recon.set_data(lensless.to(psf_recon.device))
recon.set_data(torch.from_numpy(lensless).to(psf_recon.device))
res = recon.apply(disp_iter=None, plot=False, n_iter=10)
res_np = res[0].cpu().numpy()
res_np = res_np / res_np.max()
lensed_np = lensed[0].cpu().numpy()
lensed_np = lensed[0] # .cpu().numpy()

lensless_np = lensless[0].cpu().numpy()
lensless_np = lensless # [0]#.cpu().numpy()
save_image(lensless_np, f"lensless_raw_{_idx}.png")

# -- plot lensed and res on top of each other
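The tweaks to reconstruct_save assume lensless now arrives as a NumPy array, so it is converted with torch.from_numpy before being handed to the ADMM reconstruction. A condensed sketch of that path using the same calls as above (the PSF and measurement below are random placeholders):

import numpy as np
import torch
from lensless import ADMM

psf_recon = torch.rand(1, 380, 507, 3)  # placeholder PSF; a real one would come from the dataset
lensless = np.random.rand(1, 380, 507, 3).astype(np.float32)  # placeholder raw measurement

recon = ADMM(psf_recon)
recon.set_data(torch.from_numpy(lensless).to(psf_recon.device))
res = recon.apply(disp_iter=None, plot=False, n_iter=10)
res_np = res[0].cpu().numpy()
res_np = res_np / res_np.max()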
