Add support for simulated multimask.
ebezzam committed Jul 15, 2024
1 parent f0cc3af commit c524d2b
Showing 2 changed files with 160 additions and 18 deletions.
176 changes: 158 additions & 18 deletions lensless/utils/dataset.py
@@ -18,7 +18,7 @@
from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf, save_image
from lensless.utils.image import is_grayscale, resize, rgb2gray, rotate_HWC
from lensless.utils.image import is_grayscale, resize, rgb2gray
import re
from lensless.hardware.utils import capture
from lensless.hardware.utils import display
@@ -30,6 +30,7 @@
from lensless.hardware.sensor import sensor_dict, SensorParam
from scipy.ndimage import rotate
import warnings
from PIL import Image


def convert(text):
@@ -1031,6 +1032,11 @@ def __init__(
cache_dir=None,
single_channel_psf=False,
flipud=False,
display_res=None,
alignment=None,
sensor="rpi_hq",
slm="adafruit",
simulation_config=dict(),
**kwargs,
):
"""
@@ -1058,27 +1064,89 @@ def __init__(
# download PSF from huggingface
# TODO : assuming psf is not None
self.multimask = False
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf_fp,
shape=self.lensless_shape,
return_float=True,
return_bg=True,
flip_ud=flipud,
bg_pix=(0, 15),
single_psf=single_channel_psf,
)
self.psf = torch.from_numpy(psf)
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)
self.convolver = None
if psf is not None:
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf_fp,
shape=self.lensless_shape,
return_float=True,
return_bg=True,
flip_ud=flipud,
bg_pix=(0, 15),
single_psf=single_channel_psf,
)
self.psf = torch.from_numpy(psf)
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)

# create convolver object
self.convolver = RealFFTConvolve2D(self.psf)

elif "mask_label" in data_0:
self.multimask = True
mask_labels = []
for i in range(len(self.dataset)):
mask_labels.append(self.dataset[i]["mask_label"])
mask_labels = list(set(mask_labels))

# simulate all PSFs
self.psf = dict()
for label in mask_labels:
mask_fp = hf_hub_download(
repo_id=huggingface_repo,
filename=f"masks/mask_{label}.npy",
repo_type="dataset",
)
mask_vals = np.load(mask_fp)

if psf is None:
sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION]
downsample_fact = min(sensor_res / lensless.shape[:2])
else:
downsample_fact = 1

mask = AdafruitLCD(
initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
sensor=sensor,
slm=slm,
downsample=downsample_fact,
flipud=rotate or flipud, # TODO separate commands?
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
deadspace=simulation_config.get("deadspace", True),
)
self.psf[label] = mask.get_psf().detach()

assert (
self.psf[label].shape[-3:-1] == lensless.shape[:2]
), f"PSF shape should match lensless shape: PSF {self.psf[label].shape[-3:-1]} vs lensless {lensless.shape[:2]}"

# create convolver object
self.convolver = RealFFTConvolve2D(self.psf[label])
assert self.convolver is not None

# TODO create convolver object
self.convolver = RealFFTConvolve2D(self.psf)
self.crop = None
self.random_flip = None
self.flipud = flipud

self.display_res = display_res
self.alignment = None
self.cropped_lensed_shape = None
if alignment is not None:
self.alignment = dict(alignment.copy())
self.alignment["top_left"] = (
int(self.alignment["top_left"][0] / downsample),
int(self.alignment["top_left"][1] / downsample),
)
self.alignment["height"] = int(self.alignment["height"] / downsample)

original_aspect_ratio = display_res[1] / display_res[0]
self.alignment["width"] = int(self.alignment["height"] * original_aspect_ratio)
self.cropped_lensed_shape = (self.alignment["height"], self.alignment["width"], 3)

super(HFSimulated, self).__init__(**kwargs)
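
The per-mask PSF simulation in the loop above can be exercised on its own. A minimal sketch, assuming a locally saved masks/mask_0.npy with SLM values; the file name, downsample factor, and flip flag are illustrative, while the keyword arguments mirror the AdafruitLCD call in this commit:

import numpy as np
import torch

from lensless.hardware.trainable_mask import AdafruitLCD

# hypothetical mask pattern; in the dataset these are downloaded from the HF repo
mask_vals = np.load("masks/mask_0.npy")

mask = AdafruitLCD(
    initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
    sensor="rpi_hq",
    slm="adafruit",
    downsample=8,        # assumed factor; the commit derives it from the sensor resolution
    flipud=False,
    use_waveprop=False,  # defaults mirrored from simulation_config.get(...) above
    scene2mask=None,
    mask2sensor=None,
    deadspace=True,
)
psf = mask.get_psf().detach()  # PSF whose spatial dims must match the lensless measurements
print(psf.shape)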

def __len__(self):
@@ -1099,21 +1167,93 @@ def _get_images_pair(self, idx):
lensed_np = lensed_np.astype(np.float32) / 65535

# resize if necessary
if (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
if self.cropped_lensed_shape is not None:
cropped_lensed_np = resize(
lensed_np, shape=self.cropped_lensed_shape, interpolation=cv2.INTER_NEAREST
)
lensed_np = np.zeros(tuple(self.lensless_shape) + (3,), dtype=np.float32)
lensed_np[
self.alignment["top_left"][0] : self.alignment["top_left"][0]
+ self.alignment["height"],
self.alignment["top_left"][1] : self.alignment["top_left"][1]
+ self.alignment["width"],
] = cropped_lensed_np

elif (self.lensless_shape != np.array(lensed_np.shape[:2])).any():

lensed_np = resize(
lensed_np, shape=self.lensless_shape, interpolation=cv2.INTER_NEAREST
)
lensed = torch.from_numpy(lensed_np)

# simulate lensless with convolution
lensed = lensed.unsqueeze(0) # add batch dimension

if self.multimask:
mask_label = self.dataset[idx]["mask_label"]
self.convolver.set_psf(self.psf[mask_label])
lensless = self.convolver.convolve(lensed)
if lensless.max() > 1:
print("CLIPPING!")
lensless /= lensless.max()

if self.cropped_lensed_shape:
return lensless, torch.from_numpy(cropped_lensed_np)
else:
return lensless, lensed

def __getitem__(self, idx):
lensless, lensed = super().__getitem__(idx)
if self.multimask:
mask_label = self.dataset[idx]["mask_label"]
return lensless, lensed, self.psf[mask_label]
return lensless, lensed

def extract_roi(self, reconstruction, axis=(1, 2), **kwargs):
"""
Extract region of interest from lensless and lensed images.
"""
assert self.alignment is not None, "Alignment parameters should be provided."

n_dim = len(reconstruction.shape)
assert max(axis) < n_dim, "Axis should be within the dimensions of the reconstruction."

# add batch dimension
if n_dim == 3:
if isinstance(reconstruction, torch.Tensor):
reconstruction = reconstruction.unsqueeze(0)
else:
reconstruction = reconstruction[np.newaxis]
# increment axis
axis = (axis[0] + 1, axis[1] + 1)

# extract
top_left = self.alignment["top_left"]
height = self.alignment["height"]
width = self.alignment["width"]

# extract according to axis
index = [slice(None)] * len(reconstruction.shape)  # length must cover any batch dim added above
index[axis[0]] = slice(top_left[0], top_left[0] + height)
index[axis[1]] = slice(top_left[1], top_left[1] + width)
reconstruction = reconstruction[tuple(index)]

# rotate if necessary
angle = self.alignment.get("angle", 0)
if isinstance(reconstruction, torch.Tensor) and angle:
reconstruction = F.rotate(reconstruction, angle, expand=False)
elif angle:
reconstruction = rotate(reconstruction, angle, axes=axis, reshape=False)

# remove batch dimension
if n_dim == 3:
if isinstance(reconstruction, torch.Tensor):
reconstruction = reconstruction.squeeze(0)
else:
reconstruction = reconstruction[0]

return reconstruction
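
Downstream, extract_roi undoes the padding applied in _get_images_pair so reconstructions can be compared against the displayed image. A hedged usage sketch, assuming ds is an HFSimulated instance built with alignment parameters and multimask data, and using a random tensor as a stand-in for an actual reconstruction:

import torch

lensless, lensed, psf = ds[0]  # per-sample simulated PSF is returned when ds.multimask is True

# random stand-in on the full lensless grid, shaped (batch, H, W, C)
reconstruction = torch.rand((1,) + tuple(ds.lensless_shape) + (3,))
roi = ds.extract_roi(reconstruction, axis=(1, 2))
# roi -> (1, alignment["height"], alignment["width"], 3), rotated if alignment contains an "angle"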


class HFDataset(DualDataset):
def __init__(
2 changes: 2 additions & 0 deletions scripts/recon/train_learning_based.py
@@ -223,6 +223,8 @@ def train_learned(config):
cache_dir=config.files.cache_dir,
single_channel_psf=config.files.single_channel_psf,
flipud=config.files.flipud,
display_res=config.files.image_res,
alignment=config.alignment,
)

else:
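
For completeness, a sketch of how the same constructor arguments could be passed directly in Python rather than through the Hydra config above. The repository name, split, and all numeric values are illustrative assumptions, not values from this commit:

from lensless.utils.dataset import HFSimulated

train_set = HFSimulated(
    huggingface_repo="user/my-multimask-dataset",  # hypothetical repo whose entries carry a "mask_label"
    split="train",                                 # assumed kwarg, as in the config-driven call above
    psf=None,                                      # no measured PSF -> PSFs are simulated per mask label
    downsample=2,
    display_res=(1200, 1920),                      # illustrative; passed as config.files.image_res in the script
    alignment={"top_left": (80, 100), "height": 200},  # full-resolution pixels, rescaled internally by downsample
    sensor="rpi_hq",
    slm="adafruit",
    simulation_config={"use_waveprop": False},
)
lensless, lensed, psf = train_set[0]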
