Skip to content

Commit

Permalink
Add utility for simulating dataset with mask/psf.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Feb 23, 2024
1 parent db7d90f commit 7949764
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 220 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Added
- ``lensless.hardware.mask.MultiLensArray`` class for simulating multi-lens arrays.
- ``lensless.hardware.trainable_mask.TrainableCodedAperture`` class for training a coded aperture mask pattern.
- Support for other optimizers in ``lensless.utils.Trainer.set_optimizer``.

- ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF.

Changed
~~~~~
Expand Down
225 changes: 222 additions & 3 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
# Eric BEZZAM [[email protected]]
# #############################################################################

from hydra.utils import get_original_cwd
import numpy as np
import glob
import os
import torch
from abc import abstractmethod
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms
from lensless.hardware.trainable_mask import prep_trainable_mask
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf
from lensless.utils.image import resize
from lensless.utils.image import is_grayscale, resize, rgb2gray
import re
from lensless.hardware.utils import capture
from lensless.hardware.utils import display
Expand Down Expand Up @@ -951,3 +953,220 @@ def __getitem__(self, index):

# return simulated images (replace simulated with measured)
return img, lensed


def simulate_dataset(config, generator=None):
"""
Prepare datasets for training and testing.
Parameters
----------
config : omegaconf.DictConfig
Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function.
generator : torch.Generator, optional
Random number generator, by default ``None``.
"""

if "cuda" in config.torch_device and torch.cuda.is_available():
device = config.torch_device
else:
device = "cpu"

# -- prepare PSF
psf = None
if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf":
psf_fp = os.path.join(get_original_cwd(), config.files.psf)
psf, _ = load_psf(
psf_fp,
downsample=config.files.downsample,
return_float=True,
return_bg=True,
bg_pix=(0, 15),
)
if config.files.diffusercam_psf:
transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])
psf = transform_BRG2RGB(torch.from_numpy(psf))

# drop depth dimension
psf = psf.to(device)

else:
# training mask / PSF
mask = prep_trainable_mask(config, psf)
psf = mask.get_psf().to(device)

# -- load dataset
pre_transform = None
transforms_list = [transforms.ToTensor()]
data_path = os.path.join(get_original_cwd(), "data")
if config.simulation.grayscale:
transforms_list.append(transforms.Grayscale())

if config.files.dataset == "mnist":
transform = transforms.Compose(transforms_list)
train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

elif config.files.dataset == "fashion_mnist":
transform = transforms.Compose(transforms_list)
train_ds = datasets.FashionMNIST(
root=data_path, train=True, download=True, transform=transform
)
test_ds = datasets.FashionMNIST(
root=data_path, train=False, download=True, transform=transform
)
elif config.files.dataset == "cifar10":
transform = transforms.Compose(transforms_list)
train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)

elif config.files.dataset == "CelebA":
root = config.files.celeba_root
data_path = os.path.join(root, "celeba")
assert os.path.isdir(
data_path
), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"
transform = transforms.Compose(transforms_list)
if config.files.n_files is None:
train_ds = datasets.CelebA(
root=root, split="train", download=False, transform=transform
)
test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform)
else:
ds = datasets.CelebA(root=root, split="all", download=False, transform=transform)

ds = Subset(ds, np.arange(config.files.n_files))

train_size = int((1 - config.files.test_size) * len(ds))
test_size = len(ds) - train_size
train_ds, test_ds = torch.utils.data.random_split(
ds, [train_size, test_size], generator=generator
)
else:
raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.")

if config.files.dataset != "CelebA":
if config.files.n_files is not None:
train_size = int((1 - config.files.test_size) * config.files.n_files)
test_size = config.files.n_files - train_size
train_ds = Subset(train_ds, np.arange(train_size))
test_ds = Subset(test_ds, np.arange(test_size))

# convert PSF
if config.simulation.grayscale and not is_grayscale(psf):
psf = rgb2gray(psf)

# check if gpu is available
device_conv = config.torch_device
if device_conv == "cuda" and torch.cuda.is_available():
device_conv = "cuda"
else:
device_conv = "cpu"

# create simulator
simulator = FarFieldSimulator(
psf=psf,
is_torch=True,
**config.simulation,
)

# create Pytorch dataset and dataloader
crop = config.files.crop.copy() if config.files.crop is not None else None
if mask is None:
train_ds_prop = SimulatedFarFieldDataset(
dataset=train_ds,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)
test_ds_prop = SimulatedFarFieldDataset(
dataset=test_ds,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)
else:
if config.measure is not None:

train_ds_prop = HITLDatasetTrainableMask(
rpi_username=config.measure.rpi_username,
rpi_hostname=config.measure.rpi_hostname,
celeba_root=config.files.celeba_root,
display_config=config.measure.display,
capture_config=config.measure.capture,
mask_center=config.trainable_mask.ap_center,
dataset=train_ds,
mask=mask,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)

test_ds_prop = HITLDatasetTrainableMask(
rpi_username=config.measure.rpi_username,
rpi_hostname=config.measure.rpi_hostname,
celeba_root=config.files.celeba_root,
display_config=config.measure.display,
capture_config=config.measure.capture,
mask_center=config.trainable_mask.ap_center,
dataset=test_ds,
mask=mask,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)

else:

train_ds_prop = SimulatedDatasetTrainableMask(
dataset=train_ds,
mask=mask,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)
test_ds_prop = SimulatedDatasetTrainableMask(
dataset=test_ds,
mask=mask,
simulator=simulator,
dataset_is_CHW=True,
device_conv=device_conv,
flip=config.simulation.flip,
vertical_shift=config.files.vertical_shift,
horizontal_shift=config.files.horizontal_shift,
crop=crop,
downsample=config.files.downsample,
pre_transform=pre_transform,
)

return train_ds_prop, test_ds_prop, mask
Loading

0 comments on commit 7949764

Please sign in to comment.