Add exp decay of learning rate.
ebezzam committed Jul 9, 2024
1 parent b962a7b commit 7cace0e
Showing 3 changed files with 31 additions and 6 deletions.
1 change: 1 addition & 0 deletions configs/train_unrolledADMM.yaml
@@ -175,6 +175,7 @@ optimizer:
   type: Adam # Adam, SGD... (Pytorch class)
   lr: 1e-4
   final_lr: False # if set, linearly decay to this value
+  exp_decay: False # if set, exponentially decay with this value
   slow_start: False #float how much to reduce lr for first epoch
   # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
   step: False # int, period of learning rate decay. False to not apply
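
Note (not part of the diff): `exp_decay` takes a float in (0, 1) and the learning rate is multiplied by `exp_decay**epoch`, while `final_lr` now decays the rate geometrically so that the last epoch runs at `final_lr` (see the scheduler changes in lensless/recon/utils.py below). A minimal sketch of the resulting schedules, with illustrative values:

    lr, n_epoch = 1e-4, 25

    # exp_decay = 0.9: multiply the learning rate by 0.9 every epoch
    exp_decay = 0.9
    lr_exp = [lr * exp_decay**e for e in range(n_epoch)]  # 1e-4, 9e-5, 8.1e-5, ...

    # final_lr = 1e-6: per-epoch factor chosen so the last epoch runs at final_lr
    final_lr = 1e-6
    factor = (final_lr / lr) ** (1 / (n_epoch - 1))
    lr_final = [lr * factor**e for e in range(n_epoch)]  # lr_final[-1] == 1e-6 (up to float error)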
35 changes: 29 additions & 6 deletions lensless/recon/utils.py
@@ -491,6 +491,7 @@ def __init__(
         post_process_delay=None,
         post_process_freeze=None,
         post_process_unfreeze=None,
+        n_epoch=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace <https://huggingface.co/docs/transformers/main_classes/trainer>`__.
@@ -687,6 +688,7 @@ def __init__(
         # optimizer
         self.clip_grad_norm = clip_grad
         self.optimizer_config = optimizer
+        self.n_epoch = n_epoch
         self.set_optimizer()
 
         # metrics
@@ -750,10 +752,14 @@ def set_optimizer(self, last_epoch=-1):
             print("USING ADAMW")
             self.optimizer = torch.optim.AdamW(
                 [
-                    {'params': [p for p in self.recon.parameters() if p.dim() > 1]},
-                    {'params': [p for p in self.recon.parameters() if p.dim() <= 1], 'weight_decay': 0} # no weight decay on bias terms
+                    {"params": [p for p in self.recon.parameters() if p.dim() > 1]},
+                    {
+                        "params": [p for p in self.recon.parameters() if p.dim() <= 1],
+                        "weight_decay": 0,
+                    },  # no weight decay on bias terms
                 ],
-                lr=self.optimizer_config.lr, weight_decay=0.01
+                lr=self.optimizer_config.lr,
+                weight_decay=0.01,
             )
         else:
             print(f"USING {self.optimizer_config.type}")
@@ -780,11 +786,28 @@ def learning_rate_function(epoch):
         elif self.optimizer_config.final_lr:
 
             assert self.optimizer_config.final_lr < self.optimizer_config.lr
+            assert self.n_epoch is not None
 
-            # linearly decrease learning rate to final_lr
+            # # linear decay
+            # def learning_rate_function(epoch):
+            #     slope = (start / final - 1) / (n_epoch)
+            #     return 1 / (1 + slope * epoch)
+
+            # exponential decay
             def learning_rate_function(epoch):
-                return 1 - (epoch / self.optimizer_config.final_lr)
+                final_decay = self.optimizer_config.final_lr / self.optimizer_config.lr
+                final_decay = final_decay ** (1 / (self.n_epoch - 1))
+                return final_decay**epoch
 
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch
             )
+
+        elif self.optimizer_config.exp_decay:
+
+            def learning_rate_function(epoch):
+                return self.optimizer_config.exp_decay**epoch
+
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch
+            )
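
As a sanity check (a standalone sketch, not code from the repo): `LambdaLR` multiplies the base learning rate by the lambda's return value, so `final_decay**epoch` starts at the configured lr and reaches `final_lr` on the last epoch, and `exp_decay**epoch` is a fixed multiplicative decay per epoch. With illustrative values:

    import torch

    lr, exp_decay, n_epoch = 1e-4, 0.9, 5
    model = torch.nn.Linear(4, 4)  # stand-in for the reconstruction model
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: exp_decay**epoch)

    for epoch in range(n_epoch):
        print(epoch, optimizer.param_groups[0]["lr"])  # 1e-4, 9e-5, 8.1e-5, ...
        optimizer.step()   # placeholder for one training epoch
        scheduler.step()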
1 change: 1 addition & 0 deletions scripts/recon/train_learning_based.py
@@ -569,6 +569,7 @@ def train_learned(config):
         pre_proc_aux=config.pre_proc_aux,
         extra_eval_sets=extra_eval_sets if config.files.extra_eval is not None else None,
         use_wandb=True if config.wandb_project is not None else False,
+        n_epoch=config.training.epoch,
     )
 
     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)