diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index e96451d9..4e741789 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -18,7 +18,7 @@
 from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
 from lensless.utils.simulation import FarFieldSimulator
 from lensless.utils.io import load_image, load_psf, save_image
-from lensless.utils.image import is_grayscale, resize, rgb2gray, rotate_HWC
+from lensless.utils.image import is_grayscale, resize, rgb2gray
 import re
 from lensless.hardware.utils import capture
 from lensless.hardware.utils import display
@@ -30,6 +30,7 @@
 from lensless.hardware.sensor import sensor_dict, SensorParam
 from scipy.ndimage import rotate
 import warnings
+from PIL import Image
 
 
 def convert(text):
@@ -1031,6 +1032,11 @@ def __init__(
         cache_dir=None,
         single_channel_psf=False,
         flipud=False,
+        display_res=None,
+        alignment=None,
+        sensor="rpi_hq",
+        slm="adafruit",
+        simulation_config=dict(),
         **kwargs,
     ):
         """
@@ -1058,27 +1064,93 @@ def __init__(
 
         # download PSF from huggingface
         # TODO : assuming psf is not None
         self.multimask = False
-        psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
-        psf, _ = load_psf(
-            psf_fp,
-            shape=self.lensless_shape,
-            return_float=True,
-            return_bg=True,
-            flip_ud=flipud,
-            bg_pix=(0, 15),
-            single_psf=single_channel_psf,
-        )
-        self.psf = torch.from_numpy(psf)
-        if single_channel_psf:
-            # replicate across three channels
-            self.psf = self.psf.repeat(1, 1, 1, 3)
+        self.convolver = None
+        if psf is not None:
+            psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
+            psf, _ = load_psf(
+                psf_fp,
+                shape=self.lensless_shape,
+                return_float=True,
+                return_bg=True,
+                flip_ud=flipud,
+                bg_pix=(0, 15),
+                single_psf=single_channel_psf,
+            )
+            self.psf = torch.from_numpy(psf)
+            if single_channel_psf:
+                # replicate across three channels
+                self.psf = self.psf.repeat(1, 1, 1, 3)
+
+            # create convolver object
+            self.convolver = RealFFTConvolve2D(self.psf)
+
+        elif "mask_label" in data_0:
+            self.multimask = True
+            mask_labels = []
+            for i in range(len(self.dataset)):
+                mask_labels.append(self.dataset[i]["mask_label"])
+            mask_labels = list(set(mask_labels))
+
+            # simulate all PSFs
+            self.psf = dict()
+            for label in mask_labels:
+                mask_fp = hf_hub_download(
+                    repo_id=huggingface_repo,
+                    filename=f"masks/mask_{label}.npy",
+                    repo_type="dataset",
+                )
+                mask_vals = np.load(mask_fp)
+
+                if psf is None:
+                    sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION]
+                    downsample_fact = min(sensor_res / lensless.shape[:2])
+                else:
+                    downsample_fact = 1
+
+                mask = AdafruitLCD(
+                    initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
+                    sensor=sensor,
+                    slm=slm,
+                    downsample=downsample_fact,
+                    flipud=rotate or flipud,  # TODO: separate flip and rotate handling?
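+                    # simulation options below are forwarded to AdafruitLCD;
+                    # defaults are used for any key missing from
+                    # simulation_config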
+                    use_waveprop=simulation_config.get("use_waveprop", False),
+                    scene2mask=simulation_config.get("scene2mask", None),
+                    mask2sensor=simulation_config.get("mask2sensor", None),
+                    deadspace=simulation_config.get("deadspace", True),
+                )
+                self.psf[label] = mask.get_psf().detach()
+
+                assert (
+                    self.psf[label].shape[-3:-1] == lensless.shape[:2]
+                ), f"PSF shape should match lensless shape: PSF {self.psf[label].shape[-3:-1]} vs lensless {lensless.shape[:2]}"
+
+            # create convolver object
+            self.convolver = RealFFTConvolve2D(self.psf[label])
+        assert self.convolver is not None
 
-        # TODO create convolver object
-        self.convolver = RealFFTConvolve2D(self.psf)
         self.crop = None
         self.random_flip = None
         self.flipud = flipud
+        self.display_res = display_res
+        self.alignment = None
+        self.cropped_lensed_shape = None
+        if alignment is not None:
+            assert display_res is not None, "display_res must be provided with alignment"
+            self.alignment = dict(alignment.copy())
+            self.alignment["top_left"] = (
+                int(self.alignment["top_left"][0] / downsample),
+                int(self.alignment["top_left"][1] / downsample),
+            )
+            self.alignment["height"] = int(self.alignment["height"] / downsample)
+
+            original_aspect_ratio = display_res[1] / display_res[0]
+            self.alignment["width"] = int(self.alignment["height"] * original_aspect_ratio)
+            self.cropped_lensed_shape = (self.alignment["height"], self.alignment["width"], 3)
+
         super(HFSimulated, self).__init__(**kwargs)
 
     def __len__(self):
@@ -1099,7 +1171,19 @@ def _get_images_pair(self, idx):
         lensed_np = lensed_np.astype(np.float32) / 65535
 
         # resize if necessary
-        if (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
+        if self.cropped_lensed_shape is not None:
+            cropped_lensed_np = resize(
+                lensed_np, shape=self.cropped_lensed_shape, interpolation=cv2.INTER_NEAREST
+            )
+            lensed_np = np.zeros(tuple(self.lensless_shape) + (3,), dtype=np.float32)
+            lensed_np[
+                self.alignment["top_left"][0] : self.alignment["top_left"][0]
+                + self.alignment["height"],
+                self.alignment["top_left"][1] : self.alignment["top_left"][1]
+                + self.alignment["width"],
+            ] = cropped_lensed_np
+
+        elif (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
             lensed_np = resize(
                 lensed_np, shape=self.lensless_shape, interpolation=cv2.INTER_NEAREST
             )
@@ -1107,13 +1191,76 @@ def _get_images_pair(self, idx):
 
         # simulate lensless with convolution
         lensed = lensed.unsqueeze(0)  # add batch dimension
+
+        if self.multimask:
+            mask_label = self.dataset[idx]["mask_label"]
+            self.convolver.set_psf(self.psf[mask_label])
         lensless = self.convolver.convolve(lensed)
         if lensless.max() > 1:
             print("CLIPPING!")
             lensless /= lensless.max()
 
+        if self.cropped_lensed_shape is not None:
+            return lensless, torch.from_numpy(cropped_lensed_np)
+        else:
+            return lensless, lensed
+
+    def __getitem__(self, idx):
+        lensless, lensed = super().__getitem__(idx)
+        if self.multimask:
+            mask_label = self.dataset[idx]["mask_label"]
+            return lensless, lensed, self.psf[mask_label]
         return lensless, lensed
 
+    def extract_roi(self, reconstruction, axis=(1, 2), **kwargs):
+        """
+        Extract the region of interest from a reconstruction using the alignment parameters.
+        """
+        assert self.alignment is not None, "Alignment parameters should be provided."
+
+        n_dim = len(reconstruction.shape)
+        assert max(axis) < n_dim, "Axis should be within the dimensions of the reconstruction."
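+
+        # NOTE: a batch dimension is temporarily added to 3-dim inputs so the
+        # same slice-based cropping also works for batched 4-dim
+        # reconstructions; it is removed again before returning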
+
+        # add batch dimension
+        if n_dim == 3:
+            if isinstance(reconstruction, torch.Tensor):
+                reconstruction = reconstruction.unsqueeze(0)
+            else:
+                reconstruction = reconstruction[np.newaxis]
+            # increment axis
+            axis = (axis[0] + 1, axis[1] + 1)
+
+        # alignment parameters
+        top_left = self.alignment["top_left"]
+        height = self.alignment["height"]
+        width = self.alignment["width"]
+
+        # extract according to axis
+        index = [slice(None)] * len(reconstruction.shape)
+        index[axis[0]] = slice(top_left[0], top_left[0] + height)
+        index[axis[1]] = slice(top_left[1], top_left[1] + width)
+        reconstruction = reconstruction[tuple(index)]
+
+        # rotate if necessary
+        angle = self.alignment.get("angle", 0)
+        if isinstance(reconstruction, torch.Tensor) and angle:
+            reconstruction = F.rotate(reconstruction, angle, expand=False)
+        elif angle:
+            reconstruction = rotate(reconstruction, angle, axes=axis, reshape=False)
+
+        # remove batch dimension
+        if n_dim == 3:
+            if isinstance(reconstruction, torch.Tensor):
+                reconstruction = reconstruction.squeeze(0)
+            else:
+                reconstruction = reconstruction[0]
+
+        return reconstruction
 
 
 class HFDataset(DualDataset):
     def __init__(
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 982ac013..3de4ff73 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -223,6 +223,8 @@ def train_learned(config):
             cache_dir=config.files.cache_dir,
             single_channel_psf=config.files.single_channel_psf,
             flipud=config.files.flipud,
+            display_res=config.files.image_res,
+            alignment=config.alignment,
         )
     else:
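
For reference, the axis-based cropping in extract_roi above reduces to
building a slice tuple over whichever axes hold height and width. A minimal
standalone sketch (NumPy only; the alignment values and array shape are
hypothetical, chosen for illustration):

    import numpy as np

    # hypothetical alignment parameters: top-left corner, crop height/width
    top_left, height, width = (80, 100), 200, 350
    recon = np.zeros((1, 480, 640, 3))  # (batch, height, width, channels)
    axis = (1, 2)  # which axes hold height and width

    index = [slice(None)] * recon.ndim
    index[axis[0]] = slice(top_left[0], top_left[0] + height)
    index[axis[1]] = slice(top_left[1], top_left[1] + width)
    roi = recon[tuple(index)]  # -> shape (1, 200, 350, 3)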