Skip to content

Commit

Permalink
Clean up fine-tuning PSF.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 19, 2023
1 parent 7eb5db7 commit 91923f9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 57 deletions.
3 changes: 1 addition & 2 deletions configs/fine-tune_PSF.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ defaults:
#Trainable Mask
trainable_mask:
mask_type: TrainablePSF #Null or "TrainablePSF"
initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray"
initial_value: psf
mask_lr: 1e-3
L1_strength: 1.0 #False or float
use_mask_in_dataset : False # Work only with simulated dataset

#Training
training:
epoch: 50
save_every: 5

display:
Expand Down
4 changes: 3 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ reconstruction:
#Trainable Mask
trainable_mask:
mask_type: Null #Null or "TrainablePSF"
initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray"
# "random" (with shape of config.files.psf) or "psf" (using config.files.psf)
initial_value: psf
grayscale: False
mask_lr: 1e-3
L1_strength: 1.0 #False or float
use_mask_in_dataset : True # Work only with simulated dataset
Expand Down
9 changes: 7 additions & 2 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ==================
# Authors :
# Yohann PERRON [[email protected]]
# Eric BEZZAM [[email protected]]
# #############################################################################

import abc
Expand Down Expand Up @@ -76,12 +77,16 @@ class TrainablePSF(TrainableMask):
def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, is_rgb=True, **kwargs):
super().__init__(initial_mask, optimizer, lr, **kwargs)
self._is_rgb = is_rgb
if is_rgb:
assert initial_mask.shape[-1] == 3, "RGB mask should have 3 channels"
else:
assert initial_mask.shape[-1] == 1, "Monochrome mask should have 1 channel"

def get_psf(self):
if self._is_rgb:
return self._mask.expand(-1, -1, -1, 3)
else:
return self._mask
else:
return self._mask.expand(-1, -1, -1, 3)

def project(self):
self._mask.data = torch.clamp(self._mask, 0, 1)
85 changes: 33 additions & 52 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ def simulate_dataset(config):
return ds_prop, mask


def prep_trainable_mask(config, dataset):
mask = None
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

if config.trainable_mask.initial_value == "random":
initial_mask = torch.rand_like(dataset.psf)
elif config.trainable_mask.initial_value == "psf":
initial_mask = dataset.psf.clone()
else:
raise ValueError(
f"Initial PSF value {config.trainable_mask.initial_value} not supported"
)

if config.trainable_mask.grayscale:
initial_mask = rgb2gray(initial_mask)

mask = mask_class(
initial_mask,
optimizer="Adam",
lr=config.trainable_mask.mask_lr,
is_rgb=not config.trainable_mask.grayscale,
)

return mask


@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def train_unrolled(config):

Expand All @@ -171,37 +198,6 @@ def train_unrolled(config):
print("Using CPU for training.")
device = "cpu"

# # benchmarking dataset:
# eval_path = os.path.join(get_original_cwd(), config.files.eval_dataset)
# benchmark_dataset = DiffuserCamTestDataset(
# data_dir=eval_path, downsample=config.files.downsample, n_files=config.files.n_files
# )

# diffusercam_psf = benchmark_dataset.psf.to(device)
# # background = benchmark_dataset.background

# # convert psf from BGR to RGB
# diffusercam_psf = diffusercam_psf[..., [2, 1, 0]]

# # create mask
# if config.trainable_mask.mask_type is not None:
# mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)
# if config.trainable_mask.initial_value == "random":
# mask = mask_class(
# torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr
# )
# elif config.trainable_mask.initial_value == "DiffuserCam":
# mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr)
# elif config.trainable_mask.initial_value == "DiffuserCam_gray":
# mask = mask_class(
# diffusercam_psf[:, :, :, 0, None],
# optimizer="Adam",
# lr=config.trainable_mask.mask_lr,
# is_rgb=not config.simulation.grayscale,
# )
# else:
# mask = None

# load dataset and create dataloader
train_set = None
test_set = None
Expand Down Expand Up @@ -229,41 +225,26 @@ def train_unrolled(config):
print("Test test size : ", len(test_set))

# -- if learning mask
mask = None
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

if config.trainable_mask.initial_value == "random":
mask = mask_class(
torch.rand_like(dataset.psf), optimizer="Adam", lr=config.trainable_mask.mask_lr
)
# TODO : change to PSF
elif config.trainable_mask.initial_value == "DiffuserCam":
mask = mask_class(dataset.psf, optimizer="Adam", lr=config.trainable_mask.mask_lr)
elif config.trainable_mask.initial_value == "DiffuserCam_gray":
# TODO convert to grayscale
mask = mask_class(
dataset.psf[:, :, :, 0, None],
optimizer="Adam",
lr=config.trainable_mask.mask_lr,
is_rgb=not config.simulation.grayscale,
)

mask = prep_trainable_mask(config, dataset)
if mask is not None:
# plot initial PSF
psf_np = mask.get_psf().detach().cpu().numpy()[0, ...]
if config.trainable_mask.grayscale:
psf_np = psf_np[:, :, -1]

save_image(psf_np, os.path.join(save, "psf_initial.png"))
plot_image(psf_np, gamma=config.display.gamma)
plt.savefig(os.path.join(save, "psf_initial_plot.png"))

else:

# Use a simulated dataset
if config.trainable_mask.use_mask_in_dataset:
train_set, mask = simulate_dataset(config)
# the mask use will differ from the one in the benchmark dataset
print("Trainable Mask will be used in the test dataset")
test_set = None
else:
# TODO handlge case where finetuning PSF
train_set, mask = simulate_dataset(config)

start_time = time.time()
Expand Down

0 comments on commit 91923f9

Please sign in to comment.