diff --git a/configs/train_digicam_celeba.yaml b/configs/train_digicam_celeba.yaml index 973d13f5..b2724dc9 100644 --- a/configs/train_digicam_celeba.yaml +++ b/configs/train_digicam_celeba.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_celeba defaults: - train_unrolledADMM - _self_ @@ -13,6 +13,7 @@ files: huggingface_psf: "psf_simulated.png" huggingface_dataset: True split_seed: 0 + test_size: 0.15 downsample: 2 rotate: True # if measurement is upside-down save_psf: False @@ -34,14 +35,14 @@ alignment: random_vflip: False random_hflip: False quantize: False - # shifting when there is no files.downsample + # shifting when there is no files to downsample vertical_shift: -117 horizontal_shift: -25 training: batch_size: 4 epoch: 25 - eval_batch_size: 4 + eval_batch_size: 16 crop_preloss: True reconstruction: diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index 4ce73215..e05dda06 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -1,15 +1,23 @@ # python scripts/recon/train_learning_based.py -cn train_digicam_multimask defaults: - - train_digicam_singlemask + - train_unrolledADMM - _self_ torch_device: 'cuda:0' device_ids: [0, 1, 2, 3] eval_disp_idx: [1, 2, 4, 5, 9] + # Dataset files: dataset: bezzam/DigiCam-Mirflickr-MultiMask-25K + huggingface_dataset: True + huggingface_psf: null + downsample: 1 + # TODO: these parameters should be in the dataset? + image_res: [900, 1200] # used during measurement + rotate: True # if measurement is upside-down + save_psf: False extra_eval: singlemask: @@ -19,3 +27,34 @@ files: alignment: topright: [80, 100] # height, width height: 200 + +# TODO: these parameters should be in the dataset? +alignment: + # when there is no downsampling + topright: [80, 100] # height, width + height: 200 + +training: + batch_size: 4 + epoch: 25 + eval_batch_size: 4 + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + # Hyperparameters + mu1: 1e-4 + mu2: 1e-4 + mu3: 1e-4 + tau: 2e-4 + pre_process: + network : UnetRes # UnetRes or DruNet or null + depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet + nc: [32,64,116,128] + post_process: + network : UnetRes # UnetRes or DruNet or null + depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet + nc: [32,64,116,128] + diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index f284385d..932d68a8 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -11,6 +11,7 @@ eval_disp_idx: [1, 2, 4, 5, 9] files: dataset: bezzam/DigiCam-Mirflickr-SingleMask-25K huggingface_dataset: True + huggingface_psf: null downsample: 1 # TODO: these parameters should be in the dataset? image_res: [900, 1200] # used during measurement diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index e3e70cd7..47fba326 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -10,11 +10,13 @@ start_delay: null # Dataset files: + # -- using local dataset # dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" # celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html # psf: data/psf/diffusercam_psf.tiff # diffusercam_psf: True + # -- using huggingface dataset dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM huggingface_dataset: True huggingface_psf: psf.tiff diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index ad21defb..0a8c503b 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -19,6 +19,26 @@ or measured). :special-members: __init__, __len__ +Measured dataset objects +------------------------ + +.. autoclass:: lensless.utils.dataset.HFDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ + + Simulated dataset objects ------------------------- @@ -43,19 +63,3 @@ mask / PSF. .. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask :members: :special-members: __init__ - - -Measured dataset objects ------------------------- - -.. autoclass:: lensless.utils.dataset.MeasuredDataset - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset - :members: - :special-members: __init__ diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5729484c..b758337a 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1016,25 +1016,59 @@ def _get_images_pair(self, idx): return lensless, lensed -class DigiCam(DualDataset): +class HFDataset(DualDataset): def __init__( self, huggingface_repo, split, + n_files=None, psf=None, - display_res=None, - sensor="rpi_hq", - slm="adafruit", rotate=False, # just the lensless image downsample=1, downsample_lensed=1, + display_res=None, + sensor="rpi_hq", + slm="adafruit", alignment=None, - save_psf=False, - simulation_config=None, return_mask_label=False, - n_files=None, + save_psf=False, **kwargs, ): + """ + Wrapper for lensless datasets on Hugging Face. + + Parameters + ---------- + huggingface_repo : str + Hugging Face repository ID. + split : str or :py:class:`torch.utils.data.Dataset` + Split of the dataset to use: 'train', 'test', or 'all'. If a Dataset object is given, it is used directly. + n_files : int, optional + Number of files to load from the dataset, by default None, namely all. + psf : str, optional + File name of the PSF at the repository. If None, it is assumed that there is a mask pattern from which the PSF is simulated, by default None. + rotate : bool, optional + If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False. + downsample : float, optional + Downsample factor of the lensless images, by default 1. + downsample_lensed : float, optional + Downsample factor of the lensed images, by default 1. + display_res : tuple, optional + Resolution of images when displayed on screen during measurement. + sensor : str, optional + If `psf` not provided, the sensor to use for the PSF simulation, by default "rpi_hq". + slm : str, optional + If `psf` not provided, the SLM to use for the PSF simulation, by default "adafruit". + alignment : dict, optional + Alignment parameters between lensless and lensed data. + If "topright", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match. + If "crop" is provided, the region-of-interest is extracted from the simulated lensed image, namely a ``simulation`` configuration should be provided within ``alignment``. + return_mask_label : bool, optional + If multimask dataset, return the mask label (True) or the corresponding PSF (False). + save_psf : bool, optional + If multimask dataset, save the simulated PSFs. + + """ if isinstance(split, str): if n_files is not None: @@ -1080,6 +1114,7 @@ def __init__( # preparing ground-truth as simulated measurement of original elif "crop" in alignment: + assert "simulation" in alignment, "Simulation config should be provided" self.crop = dict(alignment["crop"].copy()) self.crop["vertical"][0] = int(self.crop["vertical"][0] / downsample) self.crop["vertical"][1] = int(self.crop["vertical"][1] / downsample) @@ -1170,7 +1205,7 @@ def __init__( if "horizontal_shift" in simulation_config: self.horizontal_shift = int(simulation_config["horizontal_shift"] / downsample) - super(DigiCam, self).__init__(**kwargs) + super(HFDataset, self).__init__(**kwargs) def __len__(self): return len(self.dataset) @@ -1196,7 +1231,6 @@ def _get_images_pair(self, idx): lensless_np, factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST ) - lensless = lensless_np lensed = lensed_np if self.simulator is not None: @@ -1226,7 +1260,7 @@ def _get_images_pair(self, idx): elif self.downsample_lensed != 1.0: lensed = resize( lensed_np, - factor=self.downsample_lensed, + factor=1 / self.downsample_lensed, interpolation=cv2.INTER_NEAREST, ) diff --git a/scripts/data/authenticate.py b/scripts/data/authenticate.py index 14f1d97b..9f71819c 100644 --- a/scripts/data/authenticate.py +++ b/scripts/data/authenticate.py @@ -29,7 +29,7 @@ """ -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import torch from lensless import ADMM from lensless.utils.image import rgb2gray @@ -67,14 +67,14 @@ def authen(config): # load multimask dataset if split == "all": - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=huggingface_repo, split="train", rotate=rotate, downsample=downsample, return_mask_label=True, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=huggingface_repo, split="test", rotate=rotate, @@ -114,7 +114,7 @@ def authen(config): file_idx += list(np.arange(n_train_psf) + i * n_train_psf + test_files_offet) else: - all_set = DigiCam( + all_set = HFDataset( huggingface_repo=huggingface_repo, split=split, rotate=rotate, diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 1e45971d..ece0bcfa 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -26,7 +26,7 @@ from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent -from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, DigiCam +from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image import torch @@ -85,7 +85,7 @@ def benchmark_recon(config): dataset, [train_size, test_size], generator=generator ) elif dataset == "DigiCamHF": - benchmark_dataset = DigiCam( + benchmark_dataset = HFDataset( huggingface_repo=config.huggingface.repo, split="test", display_res=config.huggingface.image_res, diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py index e14f4ecd..906508db 100644 --- a/scripts/recon/dataset.py +++ b/scripts/recon/dataset.py @@ -35,7 +35,7 @@ from tqdm import tqdm from joblib import Parallel, delayed import numpy as np -from lensless.utils.dataset import DiffuserCamMirflickrHF, DigiCam +from lensless.utils.dataset import DiffuserCamMirflickrHF, HFDataset from lensless.eval.metric import psnr, lpips from lensless.utils.image import resize @@ -47,7 +47,7 @@ def recon_dataset(config): if config.dataset == "diffusercam": dataset = DiffuserCamMirflickrHF(split=config.split, downsample=config.downsample) else: - dataset = DigiCam( + dataset = HFDataset( huggingface_repo=config.dataset, split=config.split, downsample=config.downsample, diff --git a/scripts/recon/digicam_mirflickr.py b/scripts/recon/digicam_mirflickr.py index 88a6a036..60411fd0 100644 --- a/scripts/recon/digicam_mirflickr.py +++ b/scripts/recon/digicam_mirflickr.py @@ -3,7 +3,7 @@ import torch from lensless import ADMM from lensless.utils.plot import plot_image -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import os from lensless.utils.io import save_image import time @@ -35,7 +35,7 @@ def apply_pretrained(config): model_config = yaml.safe_load(stream) # load dataset - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=model_config["files"]["dataset"], psf=model_config["files"]["huggingface_psf"] if "huggingface_psf" in model_config["files"] diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 3f99049c..9ad7a016 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -40,7 +40,7 @@ from lensless.utils.dataset import ( DiffuserCamMirflickr, DigiCamCelebA, - DigiCam, + HFDataset, MyDataParallel, simulate_dataset, ) @@ -57,7 +57,7 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") -def train_unrolled(config): +def train_learned(config): if config.wandb_project is not None: # start a new wandb run to track this script @@ -189,8 +189,13 @@ def train_unrolled(config): generator = torch.Generator().manual_seed(seed) # - combine train and test into single dataset - train_dataset = load_dataset(config.files.dataset, split="train") - test_dataset = load_dataset(config.files.dataset, split="test") + train_split = "train" + test_split = "test" + if config.files.n_files is not None: + train_split = f"train[:{config.files.n_files}]" + test_split = f"test[:{config.files.n_files}]" + train_dataset = load_dataset(config.files.dataset, split=train_split) + test_dataset = load_dataset(config.files.dataset, split=test_split) dataset = concatenate_datasets([test_dataset, train_dataset]) # - split into train and test @@ -200,7 +205,7 @@ def train_unrolled(config): dataset, [train_size, test_size], generator=generator ) - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_train, @@ -212,7 +217,7 @@ def train_unrolled(config): save_psf=config.files.save_psf, n_files=config.files.n_files, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_test, @@ -226,7 +231,10 @@ def train_unrolled(config): ) if train_set.multimask: # get first PSF for initialization - first_psf_key = list(train_set.psf.keys())[device_ids[0]] + if device_ids is not None: + first_psf_key = list(train_set.psf.keys())[device_ids[0]] + else: + first_psf_key = list(train_set.psf.keys())[0] psf = train_set.psf[first_psf_key].to(device) else: psf = train_set.psf.to(device) @@ -265,7 +273,7 @@ def train_unrolled(config): extra_eval_sets = dict() for eval_set in config.files.extra_eval: - extra_eval_sets[eval_set] = DigiCam( + extra_eval_sets[eval_set] = HFDataset( split="test", downsample=config.files.downsample, # needs to be same size n_files=config.files.n_files, @@ -492,4 +500,4 @@ def train_unrolled(config): if __name__ == "__main__": - train_unrolled() + train_learned()