diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 018143b3..5a646e3f 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -18,7 +18,7 @@ Added
 - ``lensless.hardware.mask.MultiLensArray`` class for simulating multi-lens arrays.
 - ``lensless.hardware.trainable_mask.TrainableCodedAperture`` class for training a coded aperture mask pattern.
 - Support for other optimizers in ``lensless.utils.Trainer.set_optimizer``.
-
+- ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF.
 
 Changed
 ~~~~~
diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index 1c0cd9b8..772f718b 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -6,16 +6,18 @@
 # Eric BEZZAM [ebezzam@gmail.com]
 # #############################################################################
 
+from hydra.utils import get_original_cwd
 import numpy as np
 import glob
 import os
 import torch
 from abc import abstractmethod
-from torch.utils.data import Dataset
-from torchvision import transforms
+from torch.utils.data import Dataset, Subset
+from torchvision import datasets, transforms
+from lensless.hardware.trainable_mask import prep_trainable_mask
 from lensless.utils.simulation import FarFieldSimulator
 from lensless.utils.io import load_image, load_psf
-from lensless.utils.image import resize
+from lensless.utils.image import is_grayscale, resize, rgb2gray
 import re
 from lensless.hardware.utils import capture
 from lensless.hardware.utils import display
@@ -951,3 +953,221 @@ def __getitem__(self, index):
 
         # return simulated images (replace simulated with measured)
         return img, lensed
+
+
+def simulate_dataset(config, generator=None):
+    """
+    Prepare datasets for training and testing.
+
+    Parameters
+    ----------
+    config : omegaconf.DictConfig
+        Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function.
+    generator : torch.Generator, optional
+        Random number generator, by default ``None``.
+    """
+
+    if "cuda" in config.torch_device and torch.cuda.is_available():
+        device = config.torch_device
+    else:
+        device = "cpu"
+
+    # -- prepare PSF
+    psf = None
+    mask = None  # stays None when the PSF is loaded from file (no trainable mask)
+    if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf":
+        psf_fp = os.path.join(get_original_cwd(), config.files.psf)
+        psf, _ = load_psf(
+            psf_fp,
+            downsample=config.files.downsample,
+            return_float=True,
+            return_bg=True,
+            bg_pix=(0, 15),
+        )
+        if config.files.diffusercam_psf:
+            transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])
+            psf = transform_BRG2RGB(torch.from_numpy(psf))
+
+        # drop depth dimension
+        psf = psf.to(device)
+
+    else:
+        # training mask / PSF
+        mask = prep_trainable_mask(config, psf)
+        psf = mask.get_psf().to(device)
+
+    # -- load dataset
+    pre_transform = None
+    transforms_list = [transforms.ToTensor()]
+    data_path = os.path.join(get_original_cwd(), "data")
+    if config.simulation.grayscale:
+        transforms_list.append(transforms.Grayscale())
+
+    if config.files.dataset == "mnist":
+        transform = transforms.Compose(transforms_list)
+        train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
+        test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)
+
+    elif config.files.dataset == "fashion_mnist":
+        transform = transforms.Compose(transforms_list)
+        train_ds = datasets.FashionMNIST(
+            root=data_path, train=True, download=True, transform=transform
+        )
+        test_ds = datasets.FashionMNIST(
+            root=data_path, train=False, download=True, transform=transform
+        )
+    elif config.files.dataset == "cifar10":
+        transform = transforms.Compose(transforms_list)
+        train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
+        test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)
+
+    elif config.files.dataset == "CelebA":
+        root = config.files.celeba_root
+        data_path = os.path.join(root, "celeba")
+        assert os.path.isdir(
data_path + ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" + transform = transforms.Compose(transforms_list) + if config.files.n_files is None: + train_ds = datasets.CelebA( + root=root, split="train", download=False, transform=transform + ) + test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) + else: + ds = datasets.CelebA(root=root, split="all", download=False, transform=transform) + + ds = Subset(ds, np.arange(config.files.n_files)) + + train_size = int((1 - config.files.test_size) * len(ds)) + test_size = len(ds) - train_size + train_ds, test_ds = torch.utils.data.random_split( + ds, [train_size, test_size], generator=generator + ) + else: + raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") + + if config.files.dataset != "CelebA": + if config.files.n_files is not None: + train_size = int((1 - config.files.test_size) * config.files.n_files) + test_size = config.files.n_files - train_size + train_ds = Subset(train_ds, np.arange(train_size)) + test_ds = Subset(test_ds, np.arange(test_size)) + + # convert PSF + if config.simulation.grayscale and not is_grayscale(psf): + psf = rgb2gray(psf) + + # check if gpu is available + device_conv = config.torch_device + if device_conv == "cuda" and torch.cuda.is_available(): + device_conv = "cuda" + else: + device_conv = "cpu" + + # create simulator + simulator = FarFieldSimulator( + psf=psf, + is_torch=True, + **config.simulation, + ) + + # create Pytorch dataset and dataloader + crop = config.files.crop.copy() if config.files.crop is not None else None + if mask is None: + train_ds_prop = SimulatedFarFieldDataset( + dataset=train_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + 
horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + test_ds_prop = SimulatedFarFieldDataset( + dataset=test_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + else: + if config.measure is not None: + + train_ds_prop = HITLDatasetTrainableMask( + rpi_username=config.measure.rpi_username, + rpi_hostname=config.measure.rpi_hostname, + celeba_root=config.files.celeba_root, + display_config=config.measure.display, + capture_config=config.measure.capture, + mask_center=config.trainable_mask.ap_center, + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + test_ds_prop = HITLDatasetTrainableMask( + rpi_username=config.measure.rpi_username, + rpi_hostname=config.measure.rpi_hostname, + celeba_root=config.files.celeba_root, + display_config=config.measure.display, + capture_config=config.measure.capture, + mask_center=config.trainable_mask.ap_center, + dataset=test_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + else: + + train_ds_prop = SimulatedDatasetTrainableMask( + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + 
vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + test_ds_prop = SimulatedDatasetTrainableMask( + dataset=test_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + return train_ds_prop, test_ds_prop, mask diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index bd0f7371..4ad8493e 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -42,19 +42,13 @@ from lensless import UnrolledFISTA, UnrolledADMM, TrainableInversion from lensless.utils.dataset import ( DiffuserCamMirflickr, - SimulatedFarFieldDataset, - SimulatedDatasetTrainableMask, DigiCamCelebA, - HITLDatasetTrainableMask, ) from torch.utils.data import Subset from lensless.recon.utils import create_process_network -from lensless.utils.image import rgb2gray, is_grayscale -from lensless.utils.simulation import FarFieldSimulator +from lensless.utils.dataset import simulate_dataset from lensless.recon.utils import Trainer import torch -from torchvision import transforms, datasets -from lensless.utils.io import load_psf from lensless.utils.io import save_image from lensless.utils.plot import plot_image from lensless import ADMM @@ -64,215 +58,6 @@ log = logging.getLogger(__name__) -def simulate_dataset(config, generator=None): - - if "cuda" in config.torch_device and torch.cuda.is_available(): - log.info("Using GPU for training.") - device = config.torch_device - else: - log.info("Using CPU for training.") - device = "cpu" - - # -- prepare PSF - psf = None - if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf": - psf_fp = 
os.path.join(get_original_cwd(), config.files.psf) - psf, _ = load_psf( - psf_fp, - downsample=config.files.downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - if config.files.diffusercam_psf: - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - psf = transform_BRG2RGB(torch.from_numpy(psf)) - - # drop depth dimension - psf = psf.to(device) - - else: - # training mask / PSF - mask = prep_trainable_mask(config, psf) - psf = mask.get_psf().to(device) - - # -- load dataset - pre_transform = None - transforms_list = [transforms.ToTensor()] - data_path = os.path.join(get_original_cwd(), "data") - if config.simulation.grayscale: - transforms_list.append(transforms.Grayscale()) - - if config.files.dataset == "mnist": - transform = transforms.Compose(transforms_list) - train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) - test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) - - elif config.files.dataset == "fashion_mnist": - transform = transforms.Compose(transforms_list) - train_ds = datasets.FashionMNIST( - root=data_path, train=True, download=True, transform=transform - ) - test_ds = datasets.FashionMNIST( - root=data_path, train=False, download=True, transform=transform - ) - elif config.files.dataset == "cifar10": - transform = transforms.Compose(transforms_list) - train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) - test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) - - elif config.files.dataset == "CelebA": - root = config.files.celeba_root - data_path = os.path.join(root, "celeba") - assert os.path.isdir( - data_path - ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. 
Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" - transform = transforms.Compose(transforms_list) - if config.files.n_files is None: - train_ds = datasets.CelebA( - root=root, split="train", download=False, transform=transform - ) - test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) - else: - ds = datasets.CelebA(root=root, split="all", download=False, transform=transform) - - ds = Subset(ds, np.arange(config.files.n_files)) - - train_size = int((1 - config.files.test_size) * len(ds)) - test_size = len(ds) - train_size - train_ds, test_ds = torch.utils.data.random_split( - ds, [train_size, test_size], generator=generator - ) - else: - raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") - - if config.files.dataset != "CelebA": - if config.files.n_files is not None: - train_size = int((1 - config.files.test_size) * config.files.n_files) - test_size = config.files.n_files - train_size - train_ds = Subset(train_ds, np.arange(train_size)) - test_ds = Subset(test_ds, np.arange(test_size)) - - # convert PSF - if config.simulation.grayscale and not is_grayscale(psf): - psf = rgb2gray(psf) - - # check if gpu is available - device_conv = config.torch_device - if device_conv == "cuda" and torch.cuda.is_available(): - device_conv = "cuda" - else: - device_conv = "cpu" - - # create simulator - simulator = FarFieldSimulator( - psf=psf, - is_torch=True, - **config.simulation, - ) - - # create Pytorch dataset and dataloader - crop = config.files.crop.copy() if config.files.crop is not None else None - if mask is None: - train_ds_prop = SimulatedFarFieldDataset( - dataset=train_ds, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - test_ds_prop = 
SimulatedFarFieldDataset( - dataset=test_ds, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - else: - if config.measure is not None: - - train_ds_prop = HITLDatasetTrainableMask( - rpi_username=config.measure.rpi_username, - rpi_hostname=config.measure.rpi_hostname, - celeba_root=config.files.celeba_root, - display_config=config.measure.display, - capture_config=config.measure.capture, - mask_center=config.trainable_mask.ap_center, - dataset=train_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - test_ds_prop = HITLDatasetTrainableMask( - rpi_username=config.measure.rpi_username, - rpi_hostname=config.measure.rpi_hostname, - celeba_root=config.files.celeba_root, - display_config=config.measure.display, - capture_config=config.measure.capture, - mask_center=config.trainable_mask.ap_center, - dataset=test_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - else: - - train_ds_prop = SimulatedDatasetTrainableMask( - dataset=train_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - 
test_ds_prop = SimulatedDatasetTrainableMask( - dataset=test_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - return train_ds_prop, test_ds_prop, mask - - @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config):