Skip to content

Commit

Permalink
Add scripts and notebook to demonstrate results for DigiCam HQ paper.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jan 10, 2024
1 parent 07f533f commit 3a362b3
Show file tree
Hide file tree
Showing 5 changed files with 906 additions and 4 deletions.
11 changes: 11 additions & 0 deletions configs/recon_digicam_celeba.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# python scripts/recon/digicam_celeba.py
# Hydra config for reconstructing a CelebA test image measured with DigiCam,
# using a pre-trained model downloaded from Hugging Face.
defaults:
  - defaults_recon
  - _self_


# Third-level key of "lensless/recon/model_dict.py" (under digicam/celeba_26k),
# i.e. which trained reconstruction model to download and apply.
model: pre4M_unrolled_admm10_post4M # see "lensless/recon/model_dict.py" (digicam/celeba_26k)
# NOTE(review): hard-coded second GPU — confirm this exists on the target machine.
device: cuda:1
n_trials: 1 # more if you want to get average inference time
idx: 4 # index from test set to reconstruct
save: True # whether to save reconstruction outputs to disk
63 changes: 63 additions & 0 deletions lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
First key is camera, second key is training data, third key is model name.
Download link corresponds to output folder from training
script, which contains the model checkpoint and config file,
and other intermediate files. Models are stored on Hugging Face.
"""

import os
from huggingface_hub import snapshot_download


# Local cache directory for downloaded model snapshots:
# the "models/" folder at the repository root, resolved relative to this file.
model_dir_path = os.path.join(os.path.dirname(__file__), "..", "..", "models")

# Registry of trained models hosted on Hugging Face.
# First key is camera, second key is training data, third key is model name;
# each value is the Hugging Face repository ID passed to ``snapshot_download``.
model_dict = {
    "digicam": {
        "celeba_26k": {
            "unrolled_admm10": "bezzam/digicam-celeba-unrolled-admm10",
            "unrolled_admm10_ft_psf": "bezzam/digicam-celeba-unrolled-admm10-ft-psf",
            "unet8M": "bezzam/digicam-celeba-unet8M",
            "unrolled_admm10_post8M": "bezzam/digicam-celeba-unrolled-admm10-post8M",
            "unrolled_admm10_ft_psf_post8M": "bezzam/digicam-celeba-unrolled-admm10-ft-psf-post8M",
            "pre8M_unrolled_admm10": "bezzam/digicam-celeba-pre8M-unrolled-admm10",
            "pre4M_unrolled_admm10_post4M": "bezzam/digicam-celeba-pre4M-unrolled-admm10-post4M",
            "pre4M_unrolled_admm10_post4M_OLD": "bezzam/digicam-celeba-pre4M-unrolled-admm10-post4M_OLD",
            "pre4M_unrolled_admm10_ft_psf_post4M": "bezzam/digicam-celeba-pre4M-unrolled-admm10-ft-psf-post4M",
            # baseline benchmarks which don't have model file but use ADMM
            "admm_measured_psf": "bezzam/digicam-celeba-admm-measured-psf",
            "admm_simulated_psf": "bezzam/digicam-celeba-admm-simulated-psf",
        }
    }
}


def download_model(camera, dataset, model):
    """
    Download a trained model from Hugging Face (if not already cached).

    The repository ID is looked up in ``model_dict`` and the snapshot is
    stored under ``models/<camera>/<dataset>/<model>``. If that directory
    already exists, it is reused without re-downloading.

    Parameters
    ----------
    camera : str
        Camera used for the measurements (first key of ``model_dict``).
    dataset : str
        Dataset used for training (second key of ``model_dict``).
    model : str
        Name of model (third key of ``model_dict``).

    Returns
    -------
    str
        Path to the local directory containing the model checkpoint,
        config file, and other training outputs.

    Raises
    ------
    ValueError
        If ``camera``, ``dataset``, or ``model`` is not in ``model_dict``.
    """
    if camera not in model_dict:
        raise ValueError(f"Camera {camera} not found in model_dict.")

    if dataset not in model_dict[camera]:
        raise ValueError(f"Dataset {dataset} not found in model_dict.")

    if model not in model_dict[camera][dataset]:
        raise ValueError(f"Model {model} not found in model_dict.")

    repo_id = model_dict[camera][dataset][model]
    model_dir = os.path.join(model_dir_path, camera, dataset, model)

    # Only hit the network when no local snapshot exists yet.
    if not os.path.exists(model_dir):
        snapshot_download(repo_id=repo_id, local_dir=model_dir)

    return model_dir
12 changes: 8 additions & 4 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import glob
import os
import torch
import warnings
from abc import abstractmethod
from torch.utils.data import Dataset
from torchvision import transforms
Expand Down Expand Up @@ -496,10 +497,13 @@ def __init__(

# create simulator
simulation_config["output_dim"] = tuple(self.psf.shape[-3:-1])
simulator = FarFieldSimulator(
is_torch=True,
**simulation_config,
)
# -- ignore warning about no PSF
with warnings.catch_warnings():
warnings.simplefilter("ignore")
simulator = FarFieldSimulator(
is_torch=True,
**simulation_config,
)

super().__init__(
measured_dir=data_dir,
Expand Down
496 changes: 496 additions & 0 deletions notebooks/digicam_hq.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 3a362b3

Please sign in to comment.