Skip to content

Commit

Permalink
Trainable amplitude mask (#105)
Browse files Browse the repository at this point in the history
* Start interface for trainable coded aperture.

* Update trainable mask interface.

* Improve trainable mask API.

* Fix MURA.

* Fix coded aperture training (fashion mnist).

* Set coded aperture optimization to grayscale.

* Correctly set torch device.

* Move prep trainable mask into package.

* Set wavelength and optimizer param through config.

* Subset files before train-test split.

* Add multilens array.

* Clean up.

* Update changelog.

* Add utility for simulating dataset with mask/psf.
  • Loading branch information
ebezzam authored Feb 23, 2024
1 parent 467c927 commit 9e1e8f1
Show file tree
Hide file tree
Showing 13 changed files with 976 additions and 527 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ Added
~~~~~

- Script to upload measured datasets to Hugging Face: ``scripts/data/upload_dataset_huggingface.py``
- Pytorch support for simulating PSFs of masks.
- ``lensless.hardware.mask.MultiLensArray`` class for simulating multi-lens arrays.
- ``lensless.hardware.trainable_mask.TrainableCodedAperture`` class for training a coded aperture mask pattern.
- Support for other optimizers in ``lensless.utils.Trainer.set_optimizer``.
- ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF.

Changed
~~~~~

- Dataset reconstruction script uses datasets from Hugging Face: ``scripts/recon/dataset.py``
- For trainable masks, set trainable parameters inside the child class.

Bugfix
~~~~~
Expand Down
3 changes: 2 additions & 1 deletion configs/train_celeba_digicam_mask.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# fine-tune mask for PSF, but don't re-simulate
# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask
defaults:
- train_celeba_digicam
Expand Down Expand Up @@ -78,7 +79,7 @@ trainable_mask:
# horizontal_shift: -100 # [px]


initial_value: adafruit_random_pattern_20231004_174047.npy
initial_value: /home/bezzam/LenslessPiCam/adafruit_random_pattern_20231004_174047.npy
ap_center: [58, 76]
ap_shape: [19, 25]
rotate: 0 # rotation in degrees
Expand Down
55 changes: 55 additions & 0 deletions configs/train_coded_aperture.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# python scripts/recon/train_unrolled.py -cn train_coded_aperture
defaults:
- train_unrolledADMM
- _self_

# Train Dataset
files:
dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: /scratch/bezzam
downsample: 16 # TODO use downsample simulation instead?
n_files: 100
crop:
vertical: [810, 2240]
horizontal: [1310, 2750]

torch_device: "cuda:1"

optimizer:
# type: Adam # Adam, SGD...
# lr: 1e-4
type: SGD
lr: 0.01

#Trainable Mask
trainable_mask:
mask_type: TrainableCodedAperture
# optimizer: Adam
# mask_lr: 1e-3
optimizer: SGD
mask_lr: 0.01
L1_strength: False
binary: False
initial_value:
psf_wavelength: [550e-9]
method: MLS
n_bits: 8 # (2**n_bits-1, 2**n_bits-1)
# method: MURA
# n_bits: 25 # (4*nbits*1, 4*nbits*1)
# # -- applicable for phase masks
# design_wv: 550e-9

simulation:
grayscale: True
flip: False
scene2mask: 40e-2
mask2sensor: 2e-3
sensor: "rpi_hq"
object_height: 0.30

training:
crop_preloss: True # crop region for computing loss
batch_size: 4
epoch: 25
eval_batch_size: 16
save_every: 1
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ trainable_mask:
initial_value: psf
grayscale: False
mask_lr: 1e-3
optimizer: Adam # Adam, SGD... (Pytorch class)
L1_strength: 1.0 #False or float

target: "object_plane" # "original" or "object_plane" or "label"
Expand Down Expand Up @@ -135,7 +136,7 @@ training:
crop_preloss: False # crop region for computing loss, files.crop should be set

optimizer:
type: Adam
type: Adam # Adam, SGD... (Pytorch class)
lr: 1e-4
slow_start: False #float how much to reduce lr for first epoch
# Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
Expand Down
184 changes: 93 additions & 91 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,17 @@ def benchmark(
dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device != "cpu"))
model.reset()
idx = 0
for lensless, lensed in tqdm(dataloader):
lensless = lensless.to(device)
lensed = lensed.to(device)
with torch.no_grad():
for lensless, lensed in tqdm(dataloader):
lensless = lensless.to(device)
lensed = lensed.to(device)

# add shot noise
if snr is not None:
for i in range(lensless.shape[0]):
lensless[i] = add_shot_noise(lensless[i], float(snr))
# add shot noise
if snr is not None:
for i in range(lensless.shape[0]):
lensless[i] = add_shot_noise(lensless[i], float(snr))

# compute predictions
with torch.no_grad():
# compute predictions
if batchsize == 1:
model.set_data(lensless)
prediction = model.apply(
Expand All @@ -126,113 +126,115 @@ def benchmark(
else:
prediction = model.batch_call(lensless, **kwargs)

if unrolled_output_factor:
unrolled_out = prediction[-1]
prediction = prediction[0]

# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

if crop is not None:
prediction = prediction[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]
lensed = lensed[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]

if save_idx is not None:
batch_idx = np.arange(idx, idx + batchsize)

for i, idx in enumerate(batch_idx):
if idx in save_idx:
prediction_np = prediction.cpu().numpy()[i].squeeze()
# switch to [H, W, C]
prediction_np = np.moveaxis(prediction_np, 0, -1)
save_image(prediction_np, fp=os.path.join(output_dir, f"{idx}.png"))

# normalization
prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
prediction = prediction / prediction_max
else:
print("Warning: prediction is zero")
lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True)
lensed = lensed / lensed_max

# compute metrics
for metric in metrics:
if metric == "ReconstructionError":
metrics_values[metric].append(model.reconstruction_error().cpu().item())
else:
if "LPIPS" in metric:
if prediction.shape[1] == 1:
# LPIPS needs 3 channels
metrics_values[metric].append(
metrics[metric](
prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
)
.cpu()
.item()
)
else:
metrics_values[metric].append(
metrics[metric](prediction, lensed).cpu().item()
)
else:
metrics_values[metric].append(metrics[metric](prediction, lensed).cpu().item())
if unrolled_output_factor:
unrolled_out = prediction[-1]
prediction = prediction[0]

# compute metrics for unrolled output
if unrolled_output_factor:
# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

# -- convert to CHW and remove depth
unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

# -- extraction region of interest
if crop is not None:
unrolled_out = unrolled_out[
prediction = prediction[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]
lensed = lensed[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]

# -- normalization
unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True)
if torch.all(unrolled_out_max != 0):
unrolled_out = unrolled_out / unrolled_out_max
if save_idx is not None:
batch_idx = np.arange(idx, idx + batchsize)

for i, idx in enumerate(batch_idx):
if idx in save_idx:
prediction_np = prediction.cpu().numpy()[i]
# switch to [H, W, C] for saving
prediction_np = np.moveaxis(prediction_np, 0, -1)
save_image(prediction_np, fp=os.path.join(output_dir, f"{idx}.png"))

# -- compute metrics
# normalization
prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
prediction = prediction / prediction_max
else:
print("Warning: prediction is zero")
lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True)
lensed = lensed / lensed_max

# compute metrics
for metric in metrics:
if metric == "ReconstructionError":
# only have this for final output
continue
metrics_values[metric].append(model.reconstruction_error().cpu().item())
else:
if "LPIPS" in metric:
if unrolled_out.shape[1] == 1:
if prediction.shape[1] == 1:
# LPIPS needs 3 channels
metrics_values[metric].append(
metrics[metric](
unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
)
.cpu()
.item()
)
else:
metrics_values[metric + "_unrolled"].append(
metrics[metric](unrolled_out, lensed).cpu().item()
metrics_values[metric].append(
metrics[metric](prediction, lensed).cpu().item()
)
else:
metrics_values[metric + "_unrolled"].append(
metrics[metric](unrolled_out, lensed).cpu().item()
metrics_values[metric].append(
metrics[metric](prediction, lensed).cpu().item()
)

model.reset()
idx += batchsize
# compute metrics for unrolled output
if unrolled_output_factor:

# -- convert to CHW and remove depth
unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

# -- extraction region of interest
if crop is not None:
unrolled_out = unrolled_out[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]

# -- normalization
unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True)
if torch.all(unrolled_out_max != 0):
unrolled_out = unrolled_out / unrolled_out_max

# -- compute metrics
for metric in metrics:
if metric == "ReconstructionError":
# only have this for final output
continue
else:
if "LPIPS" in metric:
if unrolled_out.shape[1] == 1:
# LPIPS needs 3 channels
metrics_values[metric].append(
metrics[metric](
unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
)
.cpu()
.item()
)
else:
metrics_values[metric + "_unrolled"].append(
metrics[metric](unrolled_out, lensed).cpu().item()
)
else:
metrics_values[metric + "_unrolled"].append(
metrics[metric](unrolled_out, lensed).cpu().item()
)

model.reset()
idx += batchsize

# average metrics
if return_average:
Expand Down
Loading

0 comments on commit 9e1e8f1

Please sign in to comment.