From 321154fab09cb8269d61e2b61b9f1b8085e80d08 Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Mon, 18 Dec 2023 14:20:59 +0100
Subject: [PATCH] Set wavelength and optimizer param through config.

---
 configs/train_coded_aperture.yaml   | 15 +++++++++++++--
 configs/train_unrolledADMM.yaml     |  4 ++--
 lensless/hardware/mask.py           |  1 +
 lensless/hardware/trainable_mask.py |  1 -
 lensless/recon/utils.py             | 14 +++++++++-----
 5 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml
index e9d80c24..ea39b6ab 100644
--- a/configs/train_coded_aperture.yaml
+++ b/configs/train_coded_aperture.yaml
@@ -15,18 +15,29 @@ files:
 
 torch_device: "cuda:1"
 
+optimizer:
+  # type: Adam  # Adam, SGD...
+  # lr: 1e-4
+  type: SGD
+  lr: 0.01
+
 #Trainable Mask
 trainable_mask:
   mask_type: TrainableCodedAperture
-  optimizer: Adam
-  mask_lr: 1e-3
+  # optimizer: Adam
+  # mask_lr: 1e-3
+  optimizer: SGD
+  mask_lr: 0.01
   L1_strength: False
   binary: False
   initial_value:
+    psf_wavelength: [550e-9]
     method: MLS
     n_bits: 8  # (2**n_bits-1, 2**n_bits-1)
     # method: MURA
     # n_bits: 25  # (4*nbits*1, 4*nbits*1)
+    # # -- applicable for phase masks
+    # design_wv: 550e-9
 
 simulation:
   grayscale: True
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 16c6040d..918f2eec 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -84,7 +84,7 @@ trainable_mask:
   initial_value: psf
   grayscale: False
   mask_lr: 1e-3
-  optimizer: Adam
+  optimizer: Adam  # Adam, SGD... (Pytorch class)
   L1_strength: 1.0  #False or float
 
 target: "object_plane"  # "original" or "object_plane" or "label"
@@ -130,7 +130,7 @@ training:
   crop_preloss: True  # crop region for computing loss
 
 optimizer:
-  type: Adam
+  type: Adam  # Adam, SGD... (Pytorch class)
   lr: 1e-4
   slow_start: False  #float how much to reduce lr for first epoch
   # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py
index 4b16ef0d..c34e8bf8 100644
--- a/lensless/hardware/mask.py
+++ b/lensless/hardware/mask.py
@@ -113,6 +113,7 @@ def __init__(
         self.shape = self.mask.shape
 
         # PSF
+        assert hasattr(psf_wavelength, "__len__"), "psf_wavelength should be a list"
         self.psf_wavelength = psf_wavelength
         self.psf = None
         self.compute_psf()
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index 4618acf3..ae74472b 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -285,7 +285,6 @@ def __init__(
         self._mask_obj = CodedAperture.from_sensor(
             sensor_name,
             downsample,
-            psf_wavelength=[460e-9],
             is_torch=True,
             torch_device=torch_device,
             **kwargs,
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 7daf44ed..902d0212 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -471,11 +471,15 @@ def detect_nan(grad):
 
     def set_optimizer(self, last_epoch=-1):
-        if self.optimizer_config.type == "Adam":
-            parameters = [{"params": self.recon.parameters()}]
-            self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr)
-        else:
-            raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}")
+        # if self.optimizer_config.type == "Adam":
+        #     parameters = [{"params": self.recon.parameters()}]
+        #     self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr)
+        # else:
+        #     raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}")
+        parameters = [{"params": self.recon.parameters()}]
+        self.optimizer = getattr(torch.optim, self.optimizer_config.type)(
+            parameters, lr=self.optimizer_config.lr
+        )
 
         # Scheduler
         if self.optimizer_config.slow_start:
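
Note on the lensless/recon/utils.py change: set_optimizer now resolves the
optimizer class by name with getattr(torch.optim, ...), so any torch.optim
class (Adam, SGD, ...) can be selected through the config's optimizer.type
and optimizer.lr fields. A minimal standalone sketch of the same pattern
(the toy model and config values below are hypothetical, for illustration
only):

    import torch

    # Stand-in for the reconstruction model whose parameters get optimized.
    model = torch.nn.Linear(4, 1)

    # Values as they would come from the YAML config (optimizer.type, optimizer.lr).
    optimizer_type = "SGD"
    lr = 0.01

    # Resolve the optimizer class on torch.optim by its string name.
    # An unknown name (e.g. a typo in the config) raises AttributeError,
    # replacing the explicit ValueError of the old if/else branch.
    parameters = [{"params": model.parameters()}]
    optimizer = getattr(torch.optim, optimizer_type)(parameters, lr=lr)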