Skip to content

Commit

Permalink
Better API for DiffuserCam MirFlickr (with Hugging Face) (#125)
Browse files Browse the repository at this point in the history
* Add DiffuserCam Mirflickr HF wrapper.

* Update script for reconstructing dataset.

* Fix device ids check.

* Fix for non-multimask.

* FIx path for DRUNet.

* Adjust defaults of Mirflickr.

* Fix diffusercam upload.

* Formatting.

* Normalization fix.

* Fallback in case NPY corrupted.

* Remove try-except.

* Move notebook to Google Colab.

* Update CHANGELOG.
  • Loading branch information
ebezzam authored Apr 17, 2024
1 parent 56d26bd commit 718cadd
Show file tree
Hide file tree
Showing 18 changed files with 224 additions and 2,360 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Added
- DigiCam dataset which interfaces with Hugging Face.
- Scripts for authentication.
- DigiCam support for Telegram demo.
- DiffuserCamMirflickr Hugging Face API.
- Fallback for normalization if data not in 8bit range (``lensless.utils.io.save_image``).

Changed
~~~~~
Expand All @@ -35,6 +37,9 @@ Bugfix
~~~~~

- ``lensless.hardware.trainable_mask.AdafruitLCD`` input handling.
- Local path for DRUNet download.
- APGD input handling (float32).
- Multimask handling.

1.0.6 - (2024-02-21)
--------------------
Expand Down
32 changes: 0 additions & 32 deletions configs/apply_admm_single_mirflickr.yaml

This file was deleted.

23 changes: 0 additions & 23 deletions configs/evaluate_mirflickr_admm.yaml

This file was deleted.

43 changes: 24 additions & 19 deletions configs/recon_celeba_digicam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,35 @@ defaults:
- recon_dataset
- _self_

torch: True
torch_device: 'cuda:0'

repo_id: "bezzam/DigiCam-CelebA-10K"
split: "test" # "train", "test", "all"
psf_fn: "psf_measured.png" # in repo root
n_files: 25 # null for all files
dataset: bezzam/DigiCam-CelebA-10K
psf_fn: psf_measured.png # "psf_simulated.png" or "psf_measured.png "
split: test # "train", "test", "all"
downsample: 2
rotate: True # if measurement is upside-down

preprocess:
flip_ud: True
flip_lr: True
downsample: 6
alignment:
# cropping when there is no downsampling
crop:
vertical: [0, 525]
horizontal: [265, 695]

# to have different data shape than PSF
data_dim: null
# data_dim: [48, 64] # down 64
# data_dim: [506, 676] # down 6
# for prepping ground truth data
simulation:
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
object_height: 0.33 # [m]
sensor: "rpi_hq"
snr_db: null
downsample: null
random_vflip: False
random_hflip: False
quantize: False
# shifting when there is no files.downsample
vertical_shift: -117
horizontal_shift: -25

algo: admm # "admm", "apgd", "null" to just copy over (resized) raw data
admm:
n_iter: 10

# extraction region of interest
# roi: null # top, left, bottom, right
# roi: [10, 300, 560, 705] # down 4
roi: [10, 190, 377, 490] # down 6
# roi: [5, 150, 280, 352] # down 8
16 changes: 5 additions & 11 deletions configs/recon_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,13 @@ defaults:
torch: True
torch_device: 'cuda:0'

repo_id: "bezzam/DiffuserCam-Lensless-Mirflickr-Dataset"
split: "test" # "train", "test", "all"
psf_fn: "psf.png" # in repo root
output_folder: null # autocreate name if not spe
dataset: diffusercam
split: test # "train", "test", "all"
downsample: 2
data_dim: null
output_folder: null # autocreate name if not specified
n_files: 25 # null for all files

preprocess:
flip_ud: True
flip_lr: False
downsample: 6
# to have different data shape than PSF
data_dim: null

algo: admm # "admm", "apgd", "null" to just copy over (resized) raw data
admm:
n_iter: 100
Expand Down
10 changes: 7 additions & 3 deletions configs/upload_diffusercam_huggingface.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ hydra:
chdir: True # change to output folder

repo_id: "bezzam/DiffuserCam-Lensless-Mirflickr-Dataset"
dir_diffuser: "/scratch/bezzam/DiffuserCam_mirflickr/dataset/diffuser_images"
dir_lensed: "/scratch/bezzam/DiffuserCam_mirflickr/dataset/ground_truth_lensed"
psf_fp: "/home/bezzam/LenslessPiCam/data/psf/diffusercam_psf.tiff"
normalize: False
# repo_id: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
# normalize: True

dir_diffuser: "DiffuserCam_mirflickr/dataset/diffuser_images"
dir_lensed: "DiffuserCam_mirflickr/dataset/ground_truth_lensed"
psf_fp: "DiffuserCam_mirflickr/psf.tiff"
hf_token: null
file_ext: ".npy"
n_files: null
Expand Down
7 changes: 4 additions & 3 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def benchmark(
idx = 0
with torch.no_grad():
for batch in tqdm(dataloader):
if dataset.multimask:
lensless, lensed, psfs = batch
psfs = psfs.to(device)
if hasattr(dataset, "multimask"):
if dataset.multimask:
lensless, lensed, psfs = batch
psfs = psfs.to(device)
else:
lensless, lensed = batch
psfs = None
Expand Down
2 changes: 1 addition & 1 deletion lensless/recon/apgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self,
psf,
max_iter=500,
dtype=np.float32,
dtype="float32",
diff_penalty=None,
prox_penalty=APGDPriors.NONNEG,
acceleration=True,
Expand Down
2 changes: 1 addition & 1 deletion lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def load_model(model_path, psf, device="cpu", legacy_denoiser=False, verbose=Tru
psf_learned = torch.nn.Parameter(psf_learned)
recon._set_psf(psf_learned)

if config["device_ids"] is not None:
if "device_ids" in config.keys() and config["device_ids"] is not None:
model_state_dict = remove_data_parallel(model_state_dict)

# # return model_state_dict
Expand Down
19 changes: 11 additions & 8 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
import matplotlib.pyplot as plt
import time
from hydra.utils import get_original_cwd
import os
import torch
from lensless.eval.benchmark import benchmark
Expand Down Expand Up @@ -55,7 +54,7 @@ def load_drunet(model_path=None, n_channels=3, requires_grad=False):

# default to yes if no input is given
valid = input("%s (Y/n) " % msg).lower() != "n"
output_path = os.path.join(get_original_cwd(), "models")
output_path = os.path.join(this_file_path, "..", "..", "models")
if valid:
url = "https://drive.switch.ch/index.php/s/jTdeMHom025RFRQ/download"
filename = "drunet_color.pth"
Expand Down Expand Up @@ -430,6 +429,9 @@ def __init__(
self.lpips = lpips
self.skip_NAN = skip_NAN
self.eval_batch_size = eval_batch_size
self.train_multimask = False
if hasattr(train_dataset, "multimask"):
self.train_multimask = train_dataset.multimask

# check if Subset and if simulating dataset
self.simulated_dataset_trainable_mask = False
Expand Down Expand Up @@ -599,7 +601,7 @@ def train_epoch(self, data_loader):
for batch in pbar:

# get batch
if self.train_dataset.multimask:
if self.train_multimask:
X, y, psfs = batch
psfs = psfs.to(self.device)
else:
Expand Down Expand Up @@ -835,10 +837,11 @@ def evaluate(self, mean_loss, epoch, disp=None):
os.mkdir(output_dir)
output_dir = os.path.join(output_dir, str(epoch) + f"_{eval_set}")

if not self.extra_eval_sets[eval_set].multimask:
# need to set correct PSF for evaluation
# TODO cleaner way to set PSF?
self.recon._set_psf(self.extra_eval_sets[eval_set].psf.to(self.device))
if hasattr(self.extra_eval_sets[eval_set], "multimask"):
if not self.extra_eval_sets[eval_set].multimask:
# need to set correct PSF for evaluation
# TODO cleaner way to set PSF?
self.recon._set_psf(self.extra_eval_sets[eval_set].psf.to(self.device))

# benchmarking
extra_metrics = benchmark(
Expand All @@ -860,7 +863,7 @@ def evaluate(self, mean_loss, epoch, disp=None):

# set back PSF to original in case changed
# TODO: cleaner way?
if not self.train_dataset.multimask:
if not self.train_multimask:
self.recon._set_psf(self.train_dataset.psf.to(self.device))

return eval_loss
Expand Down
84 changes: 84 additions & 0 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __getitem__(self, idx):

if self.background is not None:
lensless = lensless - self.background
lensless = torch.clamp(lensless, min=0)

# add noise
if self.input_snr is not None:
Expand Down Expand Up @@ -959,6 +960,64 @@ def __getitem__(self, index):
return img, lensed


class DiffuserCamMirflickrHF(DualDataset):
def __init__(
self,
split,
repo_id="bezzam/DiffuserCam-Lensless-Mirflickr-Dataset",
psf="psf.tiff",
downsample=2,
flip_ud=True,
**kwargs,
):
"""
Parameters
----------
split : str
Split of the dataset to use: 'train', 'test', or 'all'.
downsample : int, optional
Downsample factor of the PSF, which is 4x the resolution of the images, by default 6 for resolution of (180, 320).
flip_ud : bool, optional
If True, data is flipped up-down, by default ``True``. Otherwise data is upside-down.
"""

# fixed parameters
dtype = "float32"

# get dataset
self.dataset = load_dataset(repo_id, split=split)

# get PSF
psf_fp = hf_hub_download(repo_id=repo_id, filename=psf, repo_type="dataset")
psf, bg = load_psf(
psf_fp,
verbose=False,
downsample=downsample * 4,
return_bg=True,
flip_ud=flip_ud,
dtype=dtype,
bg_pix=(0, 15),
)
self.psf = torch.from_numpy(psf)

super(DiffuserCamMirflickrHF, self).__init__(
flip_ud=flip_ud, downsample=downsample, background=bg, **kwargs
)

def __len__(self):
return len(self.dataset)

def _get_images_pair(self, idx):
lensless = np.array(self.dataset[idx]["lensless"])
lensed = np.array(self.dataset[idx]["lensed"])

# normalize
lensless = lensless.astype(np.float32) / 255
lensed = lensed.astype(np.float32) / 255

return lensless, lensed


class DigiCam(DualDataset):
def __init__(
self,
Expand Down Expand Up @@ -1175,6 +1234,31 @@ def __getitem__(self, idx):
else:
return lensless, lensed

def extract_roi(self, reconstruction, lensed=None):
assert len(reconstruction.shape) == 4, "Reconstruction should have shape [B, H, W, C]"
if lensed is not None:
assert len(lensed.shape) == 4, "Lensed should have shape [B, H, W, C]"

if self.alignment is not None:
top_right = self.alignment["topright"]
height = self.alignment["height"]
width = self.alignment["width"]
reconstruction = reconstruction[
:, top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width
]
elif self.crop is not None:
vertical = self.crop["vertical"]
horizontal = self.crop["horizontal"]
reconstruction = reconstruction[
:, vertical[0] : vertical[1], horizontal[0] : horizontal[1]
]
if lensed is not None:
lensed = lensed[:, vertical[0] : vertical[1], horizontal[0] : horizontal[1]]
if lensed is not None:
return reconstruction, lensed
else:
return reconstruction


def simulate_dataset(config, generator=None):
"""
Expand Down
10 changes: 10 additions & 0 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,16 @@ def save_image(img, fp, max_val=255, normalize=True):
if normalize:
img_tmp -= img_tmp.min()
img_tmp /= img_tmp.max()
else:
normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalize = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalize = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
img_tmp *= max_val
Expand Down
1 change: 1 addition & 0 deletions notebooks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ The following notebooks can be run from Google Colab:

- [DigiCam: Single-Shot Lensless Sensing with a Low-Cost Programmable Mask](https://colab.research.google.com/drive/1t59uyZMMyCUYVHGXdqdlNlDlb--FL_3P#scrollTo=t9o50zTf3oUg)
- [Aligning a reconstruction with the screen displayed image](https://colab.research.google.com/drive/1c6kUbiB5JO1vro0-IMd-YDDP1g7NFXv3#scrollTo=MtN7GWCIrBKr)
- [A Modular and Robust Physics-Based Approach for Lensless Image Reconstruction](https://colab.research.google.com/drive/1Wgt6ZMRZVuctLHaXxk7PEyPaBaUPvU33)
- [Towards Scalable and Secure Lensless Imaging with a Programmable Mask](https://colab.research.google.com/drive/1YGfs9p4T4NefX8GemVWwtrw4aX8zH1qu#scrollTo=tipedTe4vGwD)
Loading

0 comments on commit 718cadd

Please sign in to comment.