# works best with inverse square root decay schedule

def InvSqrtDecayLRSched(
    optimizer,
    t_ref: int = 70000,
    sigma_ref: float = 0.01
):
    """
    Build an inverse square root decay learning rate schedule.

    refer to equation 67 and Table1
    (NOTE(review): the paper is not named here - presumably Karras et al.,
    "Analyzing and Improving the Training Dynamics of Diffusion Models"; confirm)

    The returned scheduler is a `LambdaLR` whose per-step factor is

        sigma_ref / sqrt(max(t / t_ref, 1))

    i.e. the factor stays at `sigma_ref` while t <= t_ref and then decays with
    the inverse square root of t / t_ref. `LambdaLR` multiplies each param
    group's initial learning rate by this factor, so the effective learning
    rate is (optimizer base lr) * factor rather than the factor itself.

    Args:
        optimizer: the optimizer handed to `LambdaLR`.
        t_ref: step count at which the decay starts; must be positive.
        sigma_ref: value of the factor before the decay starts.

    Returns:
        The `LambdaLR` scheduler wrapping `optimizer`.

    Raises:
        ValueError: if `t_ref` is not positive.
    """
    # t_ref = 0 would divide by zero inside LambdaLR's constructor (which
    # evaluates the lambda immediately), and a negative t_ref would silently
    # yield a constant schedule - reject both up front with a clear message
    if t_ref <= 0:
        raise ValueError(f't_ref must be positive, got {t_ref}')

    def inv_sqrt_decay_fn(t: int):
        # max(..., 1.) keeps the factor flat until t passes t_ref
        return sigma_ref / sqrt(max(t / t_ref, 1.))

    return LambdaLR(optimizer, lr_lambda = inv_sqrt_decay_fn)
+++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.5' +__version__ = '1.10.7'