Skip to content

Commit

Permalink
Add support for training from measured CelebA.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 28, 2023
1 parent 816f050 commit ab9e351
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 211 deletions.
9 changes: 8 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)


def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
def benchmark(model, dataset, batchsize=1, metrics=None, mask_crop=None, **kwargs):
"""
Compute multiple metrics for a reconstruction algorithm.
Expand All @@ -36,6 +36,8 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
Batch size for processing. For maximum compatibility use 1 (batchsize above 1 are not supported on all algorithm), by default 1
metrics : dict, optional
Dictionary of metrics to compute. If None, MSE, MAE, SSIM, LPIPS and PSNR are computed.
mask_crop : torch.Tensor, optional
Mask to apply to the output of the reconstruction algorithm, by default None.
Returns
-------
Expand Down Expand Up @@ -80,6 +82,11 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

if mask_crop is not None:
prediction = prediction * mask_crop
lensed = lensed * mask_crop

# normalization
prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
Expand Down
27 changes: 26 additions & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def __init__(
save_every=None,
gamma=None,
logger=None,
crop=None,
):
"""
Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace <https://huggingface.co/docs/transformers/main_classes/trainer>`__.
Expand Down Expand Up @@ -309,6 +310,8 @@ def __init__(
Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
logger : :py:class:`logging.Logger`, optional
Logger to use for logging. If None, just print to terminal. Default is None.
crop : dict, optional
Crop to apply to images before computing loss (by applying a mask). If None, no crop is applied. Default is None.
"""
Expand Down Expand Up @@ -370,6 +373,21 @@ def __init__(
"lpips package is need for LPIPS loss. Install using : pip install lpips"
)

if crop is not None:
datashape = train_dataset[0][0].shape
# create binary mask to multiply with before computing loss
self.mask_crop = torch.zeros(datashape, dtype=torch.bool).to(self.device)

# move channel dimension to third to last
self.mask_crop = self.mask_crop.movedim(-1, -3)

# set values to True in mask
self.mask_crop[
:, :, crop.vertical[0] : crop.vertical[1], crop.horizontal[0] : crop.horizontal[1]
] = True
else:
self.mask_crop = None

# optimizer
if optimizer == "Adam":
# the parameters of the base model and non torch.Module process must be added separatly
Expand Down Expand Up @@ -495,6 +513,11 @@ def train_epoch(self, data_loader, disp=-1):
y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# crop
if self.mask_crop is not None:
y_pred = y_pred * self.mask_crop
y = y * self.mask_crop

loss_v = self.Loss(y_pred, y)
if self.lpips:

Expand Down Expand Up @@ -556,7 +579,9 @@ def evaluate(self, mean_loss, save_pt):
if self.test_dataset is None:
return
# benchmarking
current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size)
current_metrics = benchmark(
self.recon, self.test_dataset, batchsize=self.eval_batch_size, mask_crop=self.mask_crop
)

# update metrics with current metrics
self.metrics["LOSS"].append(mean_loss)
Expand Down
Loading

0 comments on commit ab9e351

Please sign in to comment.