-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add utility for simulating dataset with mask/psf.
- Loading branch information
Showing
3 changed files
with
224 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,16 +6,18 @@ | |
# Eric BEZZAM [[email protected]] | ||
# ############################################################################# | ||
|
||
from hydra.utils import get_original_cwd | ||
import numpy as np | ||
import glob | ||
import os | ||
import torch | ||
from abc import abstractmethod | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms | ||
from torch.utils.data import Dataset, Subset | ||
from torchvision import datasets, transforms | ||
from lensless.hardware.trainable_mask import prep_trainable_mask | ||
from lensless.utils.simulation import FarFieldSimulator | ||
from lensless.utils.io import load_image, load_psf | ||
from lensless.utils.image import resize | ||
from lensless.utils.image import is_grayscale, resize, rgb2gray | ||
import re | ||
from lensless.hardware.utils import capture | ||
from lensless.hardware.utils import display | ||
|
@@ -951,3 +953,220 @@ def __getitem__(self, index): | |
|
||
# return simulated images (replace simulated with measured) | ||
return img, lensed | ||
|
||
|
||
def simulate_dataset(config, generator=None): | ||
""" | ||
Prepare datasets for training and testing. | ||
Parameters | ||
---------- | ||
config : omegaconf.DictConfig | ||
Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function. | ||
generator : torch.Generator, optional | ||
Random number generator, by default ``None``. | ||
""" | ||
|
||
if "cuda" in config.torch_device and torch.cuda.is_available(): | ||
device = config.torch_device | ||
else: | ||
device = "cpu" | ||
|
||
# -- prepare PSF | ||
psf = None | ||
if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf": | ||
psf_fp = os.path.join(get_original_cwd(), config.files.psf) | ||
psf, _ = load_psf( | ||
psf_fp, | ||
downsample=config.files.downsample, | ||
return_float=True, | ||
return_bg=True, | ||
bg_pix=(0, 15), | ||
) | ||
if config.files.diffusercam_psf: | ||
transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) | ||
psf = transform_BRG2RGB(torch.from_numpy(psf)) | ||
|
||
# drop depth dimension | ||
psf = psf.to(device) | ||
|
||
else: | ||
# training mask / PSF | ||
mask = prep_trainable_mask(config, psf) | ||
psf = mask.get_psf().to(device) | ||
|
||
# -- load dataset | ||
pre_transform = None | ||
transforms_list = [transforms.ToTensor()] | ||
data_path = os.path.join(get_original_cwd(), "data") | ||
if config.simulation.grayscale: | ||
transforms_list.append(transforms.Grayscale()) | ||
|
||
if config.files.dataset == "mnist": | ||
transform = transforms.Compose(transforms_list) | ||
train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) | ||
test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) | ||
|
||
elif config.files.dataset == "fashion_mnist": | ||
transform = transforms.Compose(transforms_list) | ||
train_ds = datasets.FashionMNIST( | ||
root=data_path, train=True, download=True, transform=transform | ||
) | ||
test_ds = datasets.FashionMNIST( | ||
root=data_path, train=False, download=True, transform=transform | ||
) | ||
elif config.files.dataset == "cifar10": | ||
transform = transforms.Compose(transforms_list) | ||
train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) | ||
test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) | ||
|
||
elif config.files.dataset == "CelebA": | ||
root = config.files.celeba_root | ||
data_path = os.path.join(root, "celeba") | ||
assert os.path.isdir( | ||
data_path | ||
), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" | ||
transform = transforms.Compose(transforms_list) | ||
if config.files.n_files is None: | ||
train_ds = datasets.CelebA( | ||
root=root, split="train", download=False, transform=transform | ||
) | ||
test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) | ||
else: | ||
ds = datasets.CelebA(root=root, split="all", download=False, transform=transform) | ||
|
||
ds = Subset(ds, np.arange(config.files.n_files)) | ||
|
||
train_size = int((1 - config.files.test_size) * len(ds)) | ||
test_size = len(ds) - train_size | ||
train_ds, test_ds = torch.utils.data.random_split( | ||
ds, [train_size, test_size], generator=generator | ||
) | ||
else: | ||
raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") | ||
|
||
if config.files.dataset != "CelebA": | ||
if config.files.n_files is not None: | ||
train_size = int((1 - config.files.test_size) * config.files.n_files) | ||
test_size = config.files.n_files - train_size | ||
train_ds = Subset(train_ds, np.arange(train_size)) | ||
test_ds = Subset(test_ds, np.arange(test_size)) | ||
|
||
# convert PSF | ||
if config.simulation.grayscale and not is_grayscale(psf): | ||
psf = rgb2gray(psf) | ||
|
||
# check if gpu is available | ||
device_conv = config.torch_device | ||
if device_conv == "cuda" and torch.cuda.is_available(): | ||
device_conv = "cuda" | ||
else: | ||
device_conv = "cpu" | ||
|
||
# create simulator | ||
simulator = FarFieldSimulator( | ||
psf=psf, | ||
is_torch=True, | ||
**config.simulation, | ||
) | ||
|
||
# create Pytorch dataset and dataloader | ||
crop = config.files.crop.copy() if config.files.crop is not None else None | ||
if mask is None: | ||
train_ds_prop = SimulatedFarFieldDataset( | ||
dataset=train_ds, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
test_ds_prop = SimulatedFarFieldDataset( | ||
dataset=test_ds, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
else: | ||
if config.measure is not None: | ||
|
||
train_ds_prop = HITLDatasetTrainableMask( | ||
rpi_username=config.measure.rpi_username, | ||
rpi_hostname=config.measure.rpi_hostname, | ||
celeba_root=config.files.celeba_root, | ||
display_config=config.measure.display, | ||
capture_config=config.measure.capture, | ||
mask_center=config.trainable_mask.ap_center, | ||
dataset=train_ds, | ||
mask=mask, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
|
||
test_ds_prop = HITLDatasetTrainableMask( | ||
rpi_username=config.measure.rpi_username, | ||
rpi_hostname=config.measure.rpi_hostname, | ||
celeba_root=config.files.celeba_root, | ||
display_config=config.measure.display, | ||
capture_config=config.measure.capture, | ||
mask_center=config.trainable_mask.ap_center, | ||
dataset=test_ds, | ||
mask=mask, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
|
||
else: | ||
|
||
train_ds_prop = SimulatedDatasetTrainableMask( | ||
dataset=train_ds, | ||
mask=mask, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
test_ds_prop = SimulatedDatasetTrainableMask( | ||
dataset=test_ds, | ||
mask=mask, | ||
simulator=simulator, | ||
dataset_is_CHW=True, | ||
device_conv=device_conv, | ||
flip=config.simulation.flip, | ||
vertical_shift=config.files.vertical_shift, | ||
horizontal_shift=config.files.horizontal_shift, | ||
crop=crop, | ||
downsample=config.files.downsample, | ||
pre_transform=pre_transform, | ||
) | ||
|
||
return train_ds_prop, test_ds_prop, mask |
Oops, something went wrong.