diff --git a/.github/workflows/python_pycsou.yml b/.github/workflows/python_pycsou.yml
index 61f89fa5..dfdafb9b 100644
--- a/.github/workflows/python_pycsou.yml
+++ b/.github/workflows/python_pycsou.yml
@@ -20,7 +20,7 @@ jobs:
       fail-fast: false
       max-parallel: 12
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, macos-12, windows-latest]
         python-version: [3.9, "3.10"]
     steps:
       - uses: actions/checkout@v3
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 52481bd7..2c3023dd 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -177,6 +177,7 @@ optimizer:
   final_lr: False # if set, linearly decay to this value
   exp_decay: False # if set, exponentially decay with this value
   slow_start: False #float how much to reduce lr for first epoch
+  cosine_decay_warmup: False # if set, cosine decay with warmup of 5%
   # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
   step: False # int, period of learning rate decay. False to not apply
   gamma: 0.1 # float, factor for learning rate decay
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index f7df5e1c..bf8a74e8 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -812,6 +812,24 @@ def learning_rate_function(epoch):
                 self.optimizer, lr_lambda=learning_rate_function, last_epoch=last_epoch
             )

+        elif self.optimizer_config.cosine_decay_warmup:
+
+            total_iterations = len(self.train_dataloader) * self.n_epoch
+            warmup_steps = int(0.05 * total_iterations)
+
+            def cosine_decay_with_warmup(step, warmup_steps, total_steps):
+                if step < warmup_steps:
+                    return step / warmup_steps
+                progress = (step - warmup_steps) / (total_steps - warmup_steps)
+                return 0.5 * (1 + math.cos(math.pi * progress))
+
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lr_lambda=lambda step: cosine_decay_with_warmup(
+                    step, warmup_steps, total_iterations
+                ),
+            )
+
         elif self.optimizer_config.step:
             self.scheduler = torch.optim.lr_scheduler.StepLR(
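
For reference, below is a minimal standalone sketch of the schedule the last hunk adds. The model, optimizer, and step counts are hypothetical stand-ins (the diff derives total_iterations from self.train_dataloader and self.n_epoch inside the Trainer), but the lambda matches the added code. Note the added code relies on math already being imported in lensless/recon/utils.py.

# Sketch of the cosine-decay-with-warmup schedule from the diff above.
# Hypothetical model/optimizer/sizes; only the schedule logic is taken
# from the added code.
import math

import torch

model = torch.nn.Linear(10, 1)  # hypothetical model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

steps_per_epoch = 100  # assumed len(train_dataloader)
n_epoch = 10           # assumed self.n_epoch
total_iterations = steps_per_epoch * n_epoch
warmup_steps = int(0.05 * total_iterations)  # 5% warmup, as in the diff


def cosine_decay_with_warmup(step, warmup_steps, total_steps):
    # Linear ramp from 0 to 1 over the warmup phase...
    if step < warmup_steps:
        return step / warmup_steps
    # ...then cosine decay from 1 down to 0 over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_decay_with_warmup(
        step, warmup_steps, total_iterations
    ),
)

# LambdaLR multiplies the base lr by the returned factor, so
# scheduler.step() must be called once per training iteration (not per
# epoch) for the step count to line up with total_iterations.
for _ in range(total_iterations):
    optimizer.step()
    scheduler.step()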