Skip to content

Commit

Permalink
Merge branch 'train_measured' into rpi_gs
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Oct 5, 2023
2 parents 18e46ad + 2444094 commit 242d4ee
Show file tree
Hide file tree
Showing 11 changed files with 384 additions and 211 deletions.
4 changes: 4 additions & 0 deletions configs/compute_metrics_from_original.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
hydra:
job:
chdir: True # change to output folder

files:
# Can be downloaded here: https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Freconstruction
recon: data/reconstruction/admm_thumbs_up_rgb.npy
Expand Down
7 changes: 7 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ files:
diffusercam_psf: True
n_files: null # null to use all for both train/test
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
test_size: 0.15

torch: True
torch_device: 'cuda'
Expand Down Expand Up @@ -67,6 +68,7 @@ target: "object_plane" # "original" or "object_plane" or "label"
#for simulated dataset
simulation:
grayscale: False
output_dim: null # should be set if no PSF is used
# random variations
object_height: 0.04 # range for random height or scalar
flip: True # change the orientation of the object (from vertical to horizontal)
Expand Down Expand Up @@ -94,12 +96,17 @@ simulation:
training:
batch_size: 8
epoch: 50
eval_batch_size: 10
metric_for_best_model: null # e.g. LPIPS_Vgg, null does test loss
save_every: null
#In case of instable training
skip_NAN: True
slow_start: False #float how much to reduce lr for first epoch

crop: null # crop region for computing loss
# crop:
# vertical: [30, 560]
# horizontal: [275, 710]

optimizer:
type: Adam
Expand Down
17 changes: 16 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,16 @@
)


def benchmark(model, dataset, batchsize=1, metrics=None, save_idx=None, output_dir=None, **kwargs):
def benchmark(
model,
dataset,
batchsize=1,
metrics=None,
mask_crop=None,
save_idx=None,
output_dir=None,
**kwargs,
):
"""
Compute multiple metrics for a reconstruction algorithm.
Expand All @@ -43,6 +52,8 @@ def benchmark(model, dataset, batchsize=1, metrics=None, save_idx=None, output_d
List of indices to save the predictions, by default None (not to save any).
output_dir : str, optional
Directory to save the predictions, by default save in working directory if save_idx is provided.
mask_crop : torch.Tensor, optional
Mask to apply to the output of the reconstruction algorithm, by default None.
Returns
-------
Expand Down Expand Up @@ -103,6 +114,10 @@ def benchmark(model, dataset, batchsize=1, metrics=None, save_idx=None, output_d
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

if mask_crop is not None:
prediction = prediction * mask_crop
lensed = lensed * mask_crop

# normalization
prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
Expand Down
2 changes: 2 additions & 0 deletions lensless/eval/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def extract(
estimate = rotate(
estimate[vertical_crop[0] : vertical_crop[1], horizontal_crop[0] : horizontal_crop[1]],
angle=rotation,
mode="nearest",
reshape=False,
)
estimate /= estimate.max()
estimate = np.clip(estimate, 0, 1)
Expand Down
24 changes: 24 additions & 0 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def __init__(
save_every=None,
gamma=None,
logger=None,
crop=None,
):
"""
Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace <https://huggingface.co/docs/transformers/main_classes/trainer>`__.
Expand Down Expand Up @@ -309,6 +310,8 @@ def __init__(
Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
logger : :py:class:`logging.Logger`, optional
Logger to use for logging. If None, just print to terminal. Default is None.
crop : dict, optional
Crop to apply to images before computing loss (by applying a mask). If None, no crop is applied. Default is None.
"""
Expand Down Expand Up @@ -370,6 +373,21 @@ def __init__(
"lpips package is need for LPIPS loss. Install using : pip install lpips"
)

if crop is not None:
datashape = train_dataset[0][0].shape
# create binary mask to multiply with before computing loss
self.mask_crop = torch.zeros(datashape, dtype=torch.bool).to(self.device)

# move channel dimension to third to last
self.mask_crop = self.mask_crop.movedim(-1, -3)

# set values to True in mask
self.mask_crop[
:, :, crop.vertical[0] : crop.vertical[1], crop.horizontal[0] : crop.horizontal[1]
] = True
else:
self.mask_crop = None

# optimizer
if optimizer == "Adam":
# the parameters of the base model and non torch.Module process must be added separatly
Expand Down Expand Up @@ -484,6 +502,11 @@ def train_epoch(self, data_loader):
y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# crop
if self.mask_crop is not None:
y_pred = y_pred * self.mask_crop
y = y * self.mask_crop

loss_v = self.Loss(y_pred, y)
if self.lpips:

Expand Down Expand Up @@ -560,6 +583,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
batchsize=self.eval_batch_size,
save_idx=disp,
output_dir=output_dir,
mask_crop=self.mask_crop,
)

# update metrics with current metrics
Expand Down
Loading

0 comments on commit 242d4ee

Please sign in to comment.