Skip to content

Commit

Permalink
-Fixed PR
Browse files Browse the repository at this point in the history
-TODO: find alignement factor
  • Loading branch information
StefanPetersTM committed Aug 9, 2024
1 parent df1f8bd commit 1bc6a30
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 28 deletions.
19 changes: 19 additions & 0 deletions configs/train_tapecam_measured_background.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape
defaults:
- train_mirflickr_tape
- _self_

wandb_project:
device_ids:


# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient
background_snr_range: [0,0]
image_res: [507, 380]

alignment:
# when there is no downsampling
top_left: [45, 95] # height, width
height: 250
3 changes: 1 addition & 2 deletions configs/train_tapecam_simulated_background.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ device_ids:

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient
background_fp:
background_fp: ""
background_snr_range: [0,0]
23 changes: 2 additions & 21 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,6 @@ def __init__(
flip_lensed=False,
downsample=1,
downsample_lensed=1,
downsample_background=1,
display_res=None,
sensor="rpi_hq",
slm="adafruit",
Expand Down Expand Up @@ -1310,8 +1309,6 @@ def __init__(
If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False.
downsample : float, optional
Downsample factor of the lensless images, by default 1.
downsample : float, optional
Downsample factor of the background images, by default 1.
downsample_lensed : float, optional
Downsample factor of the lensed images, by default 1.
display_res : tuple, optional
Expand Down Expand Up @@ -1362,7 +1359,6 @@ def __init__(
data_0 = self.dataset[0]
self.downsample_lensless = downsample
self.downsample_lensed = downsample_lensed
self.downsample_background = downsample_background
lensless = np.array(data_0["lensless"])
if "ambient" in data_0.keys():
self.measured_bg = True
Expand Down Expand Up @@ -1593,7 +1589,7 @@ def _get_images_pair(self, idx):
background_np = (
resize(
background_np,
factor=1 / self.downsample_background,
factor=1 / self.downsample,
interpolation=cv2.INTER_NEAREST,
)
if not None
Expand Down Expand Up @@ -1638,32 +1634,18 @@ def _get_images_pair(self, idx):
lensed = resize(
lensed_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST
)
background = (
resize(background_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST)
if not None
else None
)
elif self.downsample_lensed != 1.0:
lensed = resize(
lensed_np,
factor=1 / self.downsample_lensed,
interpolation=cv2.INTER_NEAREST,
)
background = (
resize(
background_np,
factor=1 / self.downsample_lensed,
interpolation=cv2.INTER_NEAREST,
)
if not None
else None
)

return lensless, lensed, background if background is not None else None

def __getitem__(self, idx):
lensless, lensed, background = self._get_images_pair(idx)
if not self.simulate_lensless: # TODO apply transformation to bg as well?
if not self.simulate_lensless:
if self.rotate:
lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
if self.flipud:
Expand Down Expand Up @@ -1732,7 +1714,6 @@ def __getitem__(self, idx):
# If measured background available in the dataset return it
elif self.measured_bg:
return_items.append(background)
# TODO push data to gpu in the training loop and hvae a flag that for the simple subtraction does that without pushing the bg to the gpu
return return_items

def extract_roi(
Expand Down
10 changes: 5 additions & 5 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def train_learned(config):
rotate_angle,
shift,
)
# save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
save_image(lensed, f"lensed_{_idx}.png")
save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
# save_image(lensed, f"lensed_{_idx}.png")
if test_set.bg_sim is not None:
# Reconstruct and plot background subtracted image
reconstruct_save(
Expand Down Expand Up @@ -660,14 +660,14 @@ def reconstruct_save(
):
recon = ADMM(psf_recon)

# recon.set_data(lensless.to(psf_recon.device))
recon.set_data(torch.from_numpy(lensless).to(psf_recon.device))
recon.set_data(lensless.to(psf_recon.device))
# recon.set_data(torch.from_numpy(lensless).to(psf_recon.device))
res = recon.apply(disp_iter=None, plot=False, n_iter=10)
res_np = res[0].cpu().numpy()
res_np = res_np / res_np.max()
lensed_np = lensed[0] # .cpu().numpy()

lensless_np = lensless # [0]#.cpu().numpy()
lensless_np = lensless.cpu().numpy() # [0]#.cpu().numpy()
save_image(lensless_np, f"lensless_raw_{_idx}.png")

# -- plot lensed and res on top of each other
Expand Down

0 comments on commit 1bc6a30

Please sign in to comment.