Adding plotting of reconstructions.
ebezzam committed Jan 12, 2024
1 parent 3a362b3 commit b4d41f0
Showing 7 changed files with 644 additions and 128 deletions.
89 changes: 88 additions & 1 deletion lensless/hardware/trainable_mask.py
@@ -7,11 +7,15 @@
# #############################################################################

import abc
import numpy as np
import os
import torch
from lensless.utils.image import is_grayscale
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict
from lensless.hardware.slm import full2subpattern
from lensless.utils.image import rgb2gray


class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
@@ -129,7 +133,7 @@ def __init__(
mask2sensor=None,
downsample=None,
min_val=0,
**kwargs
**kwargs,
):
"""
Parameters
@@ -230,3 +234,86 @@ def project(self)
self.color_filter.data = self.color_filter / self.color_filter.sum(
dim=[1, 2]
).unsqueeze(-1).unsqueeze(-1)


"""
Utility functions to help prepare trainable masks.
"""

mask_type_to_class = {
"TrainablePSF": TrainablePSF,
"AdafruitLCD": AdafruitLCD,
}


def prep_trainable_mask(config, psf=None, downsample=None):
    """
    Prepare a trainable mask object from the training configuration.

    ``psf`` is used for the ``"psf"`` and ``"random"`` initializations;
    ``downsample`` defaults to ``config["files"]["downsample"]``.
    """

mask = None
color_filter = None
downsample = config["files"]["downsample"] if downsample is None else downsample
if config["trainable_mask"]["mask_type"] is not None:
mask_class = mask_type_to_class[config["trainable_mask"]["mask_type"]]

if config["trainable_mask"]["initial_value"] == "random":
if psf is not None:
initial_mask = torch.rand_like(psf)
else:
sensor = VirtualSensor.from_name(
config["simulation"]["sensor"], downsample=downsample
)
resolution = sensor.resolution
initial_mask = torch.rand((1, *resolution, 3))

elif config["trainable_mask"]["initial_value"] == "psf":
initial_mask = psf.clone()

# if file ending with "npy"
elif config["trainable_mask"]["initial_value"].endswith("npy"):
# from hydra.utils import get_original_cwd
# pattern = np.load(
# os.path.join(get_original_cwd(), config["trainable_mask"]["initial_value"])
# )

pattern = np.load(config["trainable_mask"]["initial_value"])

initial_mask = full2subpattern(
pattern=pattern,
shape=config["trainable_mask"]["ap_shape"],
center=config["trainable_mask"]["ap_center"],
slm=config["trainable_mask"]["slm"],
)
initial_mask = torch.from_numpy(initial_mask.astype(np.float32))

# prepare color filter if needed
            # ``slm_dict`` is already imported at module level
            from waveprop.devices import SLMParam as SLMParam_wp

slm_param = slm_dict[config["trainable_mask"]["slm"]]
if (
config["trainable_mask"]["train_color_filter"]
and SLMParam_wp.COLOR_FILTER in slm_param.keys()
):
color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)

# add small random values
color_filter = color_filter + 0.1 * torch.rand_like(color_filter)

else:
raise ValueError(
f"Initial PSF value {config['trainable_mask']['initial_value']} not supported"
)

# convert to grayscale if needed
if config["trainable_mask"]["grayscale"] and not is_grayscale(initial_mask):
initial_mask = rgb2gray(initial_mask)

mask = mask_class(
initial_mask,
optimizer="Adam",
downsample=downsample,
color_filter=color_filter,
**config["trainable_mask"],
)

return mask
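
A minimal usage sketch of prep_trainable_mask (not part of the commit): the config dict below is hypothetical, with key names mirroring those the function reads (files.downsample, simulation.sensor, trainable_mask.mask_type / initial_value / grayscale); real training configs in the repository carry more options.

# Hypothetical config fragment; keys mirror what prep_trainable_mask accesses.
config = {
    "files": {"downsample": 2},
    "simulation": {"sensor": "rpi_hq"},
    "trainable_mask": {
        "mask_type": "TrainablePSF",
        "initial_value": "random",  # random initialization on the virtual sensor grid
        "grayscale": False,
    },
}

mask = prep_trainable_mask(config)  # no measured PSF, so a random mask is created
if mask is not None:
    psf = mask.get_psf()  # PSF passed to the reconstruction algorithm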
110 changes: 110 additions & 0 deletions lensless/recon/model_dict.py
@@ -7,7 +7,13 @@
"""

import os
import numpy as np
import yaml
import torch
from lensless.recon.utils import create_process_network
from lensless.recon.unrolled_admm import UnrolledADMM
from huggingface_hub import snapshot_download
from lensless.hardware.trainable_mask import prep_trainable_mask


model_dir_path = os.path.join(os.path.dirname(__file__), "..", "..", "models")
@@ -61,3 +67,107 @@ def download_model(camera, dataset, model):
snapshot_download(repo_id=repo_id, local_dir=model_dir)

return model_dir


def load_model(model_path, psf, device):
    """
    Load the best model checkpoint from a model directory.

    Parameters
    ----------
    model_path : str
        Path to the model directory.
    psf : :py:class:`~torch.Tensor`
        PSF tensor.
    device : str
        Device to load the model on.

    Returns
    -------
    recon : :py:class:`~lensless.recon.unrolled_admm.UnrolledADMM`
        Reconstruction model with its best checkpoint loaded.
    """

# load config
config_path = os.path.join(model_path, ".hydra", "config.yaml")
with open(config_path, "r") as stream:
config = yaml.safe_load(stream)

# TODO : quick fix
if config["trainable_mask"]["initial_value"].endswith("npy"):
config["trainable_mask"][
"initial_value"
] = "/home/bezzam/LenslessPiCam/adafruit_random_pattern_20231004_174047.npy"

# check if trainable mask
downsample = (
config["files"]["downsample"] * 4
) # measured files are 4x downsampled (TODO, maybe celeba only?)
mask = prep_trainable_mask(config, psf, downsample=downsample)
if mask is not None:
# if config["trainable_mask"]["mask_type"] is not None:
# load best mask setting and update PSF

if config["trainable_mask"]["mask_type"] == "AdafruitLCD":
# -- load best values
mask_vals = np.load(os.path.join(model_path, "mask_epochBEST.npy"))
cf_path = os.path.join(model_path, "mask_color_filter_epochBEST.npy")
if os.path.exists(cf_path):
cf = np.load(cf_path)
else:
cf = None

# -- set values and get new PSF
with torch.no_grad():
mask._mask = torch.nn.Parameter(torch.tensor(mask_vals, device=device))
if cf is not None:
mask.color_filter = torch.nn.Parameter(torch.tensor(cf, device=device))
psf = mask.get_psf().to(device)

else:

raise NotImplementedError

# load best model
model_checkpoint = os.path.join(model_path, "recon_epochBEST")
model_state_dict = torch.load(model_checkpoint, map_location=device)

pre_process = None
post_process = None

if "skip_unrolled" not in config["reconstruction"].keys():
config["reconstruction"]["skip_unrolled"] = False

if config["reconstruction"]["pre_process"]["network"] is not None:

pre_process, _ = create_process_network(
network=config["reconstruction"]["pre_process"]["network"],
depth=config["reconstruction"]["pre_process"]["depth"],
nc=config["reconstruction"]["pre_process"]["nc"]
if "nc" in config["reconstruction"]["pre_process"].keys()
else None,
device=device,
)

if config["reconstruction"]["post_process"]["network"] is not None:

post_process, _ = create_process_network(
network=config["reconstruction"]["post_process"]["network"],
depth=config["reconstruction"]["post_process"]["depth"],
nc=config["reconstruction"]["post_process"]["nc"]
if "nc" in config["reconstruction"]["post_process"].keys()
else None,
device=device,
)

if config["reconstruction"]["method"] == "unrolled_admm":
recon = UnrolledADMM(
psf,
pre_process=pre_process,
post_process=post_process,
n_iter=config["reconstruction"]["unrolled_admm"]["n_iter"],
skip_unrolled=config["reconstruction"]["skip_unrolled"],
)

recon.load_state_dict(model_state_dict)
else:
raise ValueError(
f"Reconstruction method {config['reconstruction']['method']} not supported."
)

return recon
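
A sketch of how download_model and load_model could be chained (not part of the commit). The camera/dataset/model identifiers are placeholders, the PSF and a lensless measurement are assumed to be loaded elsewhere as 4D tensors in the package's (depth, height, width, channel) convention, and the set_data/apply calls follow the package's ReconstructionAlgorithm interface.

import torch

# Placeholder identifiers; valid (camera, dataset, model) combinations depend on
# what has been published to the project's Hugging Face repositories.
model_dir = download_model(camera="diffusercam", dataset="mirflickr", model="unrolled_admm")

device = "cuda" if torch.cuda.is_available() else "cpu"
recon = load_model(model_dir, psf, device)  # psf: measured PSF tensor, loaded separately

# Reconstruct a single lensless measurement (``lensless`` loaded separately).
with torch.no_grad():
    recon.set_data(lensless.to(device))
    estimate = recon.apply(plot=False)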
