From 7cace0e7336fd4705a75ad2d8617f834e74bd6cd Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Tue, 9 Jul 2024 17:36:41 +0200
Subject: [PATCH] Add exp decay of learning rate.

---
 configs/train_unrolledADMM.yaml       |  1 +
 lensless/recon/utils.py               | 35 ++++++++++++++++++++++------
 scripts/recon/train_learning_based.py |  1 +
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index ec4bf28e..52481bd7 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -175,6 +175,7 @@ optimizer:
   type: Adam  # Adam, SGD... (Pytorch class)
   lr: 1e-4
   final_lr: False  # if set, linearly decay to this value
+  exp_decay: False  # if set, exponentially decay with this value
   slow_start: False  #float how much to reduce lr for first epoch
   # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
   step: False  # int, period of learning rate decay. False to not apply
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index fbf19c07..f7df5e1c 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -491,6 +491,7 @@ def __init__(
         post_process_delay=None,
         post_process_freeze=None,
         post_process_unfreeze=None,
+        n_epoch=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__.
@@ -687,6 +688,7 @@ def __init__(
         # optimizer
         self.clip_grad_norm = clip_grad
         self.optimizer_config = optimizer
+        self.n_epoch = n_epoch
         self.set_optimizer()
 
         # metrics
@@ -750,10 +752,14 @@ def set_optimizer(self, last_epoch=-1):
             print("USING ADAMW")
             self.optimizer = torch.optim.AdamW(
                 [
-                    {'params': [p for p in self.recon.parameters() if p.dim() > 1]},
-                    {'params': [p for p in self.recon.parameters() if p.dim() <= 1], 'weight_decay': 0}  # no weight decay on bias terms
+                    {"params": [p for p in self.recon.parameters() if p.dim() > 1]},
+                    {
+                        "params": [p for p in self.recon.parameters() if p.dim() <= 1],
+                        "weight_decay": 0,
+                    },  # no weight decay on bias terms
                 ],
-                lr=self.optimizer_config.lr, weight_decay=0.01
+                lr=self.optimizer_config.lr,
+                weight_decay=0.01,
             )
         else:
             print(f"USING {self.optimizer_config.type}")
@@ -780,11 +786,28 @@ def learning_rate_function(epoch):
 
         elif self.optimizer_config.final_lr:
 
             assert self.optimizer_config.final_lr < self.optimizer_config.lr
+            assert self.n_epoch is not None
 
-            # linearly decrease learning rate to final_lr
+            # # linear decay
+            # def learning_rate_function(epoch):
+            #     slope = (start / final - 1) / (n_epoch)
+            #     return 1 / (1 + slope * epoch)
+
+            # exponential decay
             def learning_rate_function(epoch):
-                return 1 - (epoch / self.optimizer_config.final_lr)
-
+                final_decay = self.optimizer_config.final_lr / self.optimizer_config.lr
+                final_decay = final_decay ** (1 / (self.n_epoch - 1))
+                return final_decay**epoch
+
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch
+            )
+
+        elif self.optimizer_config.exp_decay:
+
+            def learning_rate_function(epoch):
+                return self.optimizer_config.exp_decay**epoch
+
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch
             )
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 21b1c004..3d5688b8 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -569,6 +569,7 @@ def train_learned(config):
         pre_proc_aux=config.pre_proc_aux,
         extra_eval_sets=extra_eval_sets if config.files.extra_eval is not None else None,
         use_wandb=True if config.wandb_project is not None else False,
+        n_epoch=config.training.epoch,
     )
 
     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)
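
Note on the scheduling logic introduced above (illustrative sketch, not part of the patch): when `final_lr` is set, the per-epoch multiplier is chosen so that the learning rate decays geometrically from `lr` to `final_lr` by the last epoch, i.e. multiplier = (final_lr / lr) ** (1 / (n_epoch - 1)); when `exp_decay` is set, that fixed multiplier is applied directly each epoch. The standalone sketch below mirrors this use of `torch.optim.lr_scheduler.LambdaLR`; the concrete values (lr=1e-4, final_lr=1e-6, n_epoch=10) and the placeholder linear model are hypothetical, not taken from the patch.

import torch

# Hypothetical values standing in for the Hydra optimizer config used in the patch.
lr, final_lr, n_epoch = 1e-4, 1e-6, 10

model = torch.nn.Linear(8, 8)  # placeholder model, not the lensless reconstruction network
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# final_lr branch: geometric decay whose multiplier reaches final_lr / lr at epoch n_epoch - 1.
final_decay = (final_lr / lr) ** (1 / (n_epoch - 1))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: final_decay**epoch)
# exp_decay branch would instead apply a fixed per-epoch factor, e.g. lr_lambda=lambda epoch: 0.9**epoch

for epoch in range(n_epoch):
    optimizer.step()   # stand-in for one epoch of training
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # last printed value equals final_lr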