Commit 321154f

Set wavelength and optimizer param through config.

ebezzam committed Dec 18, 2023
1 parent d2cc322 commit 321154f
Showing 5 changed files with 25 additions and 10 deletions.
15 changes: 13 additions & 2 deletions configs/train_coded_aperture.yaml
@@ -15,18 +15,29 @@ files:

 torch_device: "cuda:1"
 
+optimizer:
+  # type: Adam # Adam, SGD...
+  # lr: 1e-4
+  type: SGD
+  lr: 0.01
 
 #Trainable Mask
 trainable_mask:
   mask_type: TrainableCodedAperture
-  optimizer: Adam
-  mask_lr: 1e-3
+  # optimizer: Adam
+  # mask_lr: 1e-3
+  optimizer: SGD
+  mask_lr: 0.01
   L1_strength: False
   binary: False
   initial_value:
     psf_wavelength: [550e-9]
     method: MLS
     n_bits: 8 # (2**n_bits-1, 2**n_bits-1)
+    # method: MURA
+    # n_bits: 25 # (4*nbits*1, 4*nbits*1)
+    # # -- applicable for phase masks
+    # design_wv: 550e-9
 
 simulation:
   grayscale: True
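For context, a minimal sketch of how a config block like this is typically consumed (assuming OmegaConf, which Hydra-style configs such as this one commonly use; this is illustrative, not code from this commit):

    from omegaconf import OmegaConf

    # Load the training config (path as in this repo's configs/ directory).
    config = OmegaConf.load("configs/train_coded_aperture.yaml")

    # Optimizer settings for the mask now come from the config rather than being hardcoded.
    opt_type = config.trainable_mask.optimizer   # "SGD"
    opt_lr = config.trainable_mask.mask_lr       # 0.01

    # The PSF wavelength(s) are likewise config-driven, as a list.
    wavelengths = config.trainable_mask.initial_value.psf_wavelength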
4 changes: 2 additions & 2 deletions configs/train_unrolledADMM.yaml
@@ -84,7 +84,7 @@ trainable_mask:
   initial_value: psf
   grayscale: False
   mask_lr: 1e-3
-  optimizer: Adam
+  optimizer: Adam # Adam, SGD... (Pytorch class)
   L1_strength: 1.0 #False or float
 
 target: "object_plane" # "original" or "object_plane" or "label"
@@ -130,7 +130,7 @@ training:
   crop_preloss: True # crop region for computing loss
 
 optimizer:
-  type: Adam
+  type: Adam # Adam, SGD... (Pytorch class)
   lr: 1e-4
   slow_start: False #float how much to reduce lr for first epoch
   # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
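The "Decay LR in step fashion" comment points at PyTorch's StepLR scheduler. A minimal standalone sketch of how it pairs with a config-driven optimizer (model and values are illustrative, not from this repo):

    import torch

    model = torch.nn.Linear(4, 4)  # stand-in for the reconstruction network

    # Build the optimizer by name, as lensless/recon/utils.py below does via getattr.
    optimizer = getattr(torch.optim, "Adam")(model.parameters(), lr=1e-4)

    # Step decay: multiply the LR by gamma every step_size epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    for epoch in range(30):
        # ... training pass ...
        optimizer.step()
        scheduler.step()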
1 change: 1 addition & 0 deletions lensless/hardware/mask.py
@@ -113,6 +113,7 @@ def __init__(
         self.shape = self.mask.shape
 
         # PSF
+        assert hasattr(psf_wavelength, "__len__"), "psf_wavelength should be a list"
         self.psf_wavelength = psf_wavelength
         self.psf = None
         self.compute_psf()
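A sketch of what the new guard catches (hypothetical standalone example; hasattr(x, "__len__") is True for lists and arrays but False for a bare float):

    # Hypothetical illustration of the new assert in Mask.__init__
    def check_wavelength(psf_wavelength):
        assert hasattr(psf_wavelength, "__len__"), "psf_wavelength should be a list"
        return psf_wavelength

    check_wavelength([550e-9])          # OK: a one-element list
    check_wavelength([460e-9, 550e-9])  # OK: multiple wavelengths
    try:
        check_wavelength(550e-9)        # a bare float has no __len__
    except AssertionError as e:
        print(e)                        # "psf_wavelength should be a list"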
1 change: 0 additions & 1 deletion lensless/hardware/trainable_mask.py
@@ -285,7 +285,6 @@ def __init__(
         self._mask_obj = CodedAperture.from_sensor(
             sensor_name,
             downsample,
-            psf_wavelength=[460e-9],
             is_torch=True,
             torch_device=torch_device,
             **kwargs,
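With the hardcoded psf_wavelength=[460e-9] removed, the wavelength reaches CodedAperture.from_sensor through **kwargs, i.e. from the config's initial_value block. A hedged sketch of the call flow (the stub function and sensor name are illustrative stand-ins, not the real API):

    # Hypothetical stand-in for CodedAperture.from_sensor, to show the kwargs flow.
    def from_sensor_stub(sensor_name, downsample, psf_wavelength=None, **kwargs):
        # Previously the caller forced psf_wavelength=[460e-9]; now it arrives
        # via **kwargs from the config's initial_value block.
        return psf_wavelength

    initial_value = {"psf_wavelength": [550e-9], "method": "MLS", "n_bits": 8}
    print(from_sensor_stub("rpi_hq", 8, **initial_value))  # [5.5e-07]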
14 changes: 9 additions & 5 deletions lensless/recon/utils.py
@@ -471,11 +471,15 @@ def detect_nan(grad):

     def set_optimizer(self, last_epoch=-1):
 
-        if self.optimizer_config.type == "Adam":
-            parameters = [{"params": self.recon.parameters()}]
-            self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr)
-        else:
-            raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}")
+        # if self.optimizer_config.type == "Adam":
+        #     parameters = [{"params": self.recon.parameters()}]
+        #     self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr)
+        # else:
+        #     raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}")
+        parameters = [{"params": self.recon.parameters()}]
+        self.optimizer = getattr(torch.optim, self.optimizer_config.type)(
+            parameters, lr=self.optimizer_config.lr
+        )
 
         # Scheduler
         if self.optimizer_config.slow_start:
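The getattr lookup means any torch.optim class named in the config works without touching this code. A minimal standalone sketch of the same pattern (model and values are illustrative):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for self.recon
    optimizer_type = "SGD"         # from config: optimizer.type
    lr = 0.01                      # from config: optimizer.lr

    parameters = [{"params": model.parameters()}]
    optimizer = getattr(torch.optim, optimizer_type)(parameters, lr=lr)
    print(type(optimizer).__name__)  # SGD

    # Note: a typo such as "SGDD" now raises AttributeError from getattr,
    # instead of the explicit ValueError the old branch raised.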
