the magnitude preserving unet works best with inverse square root decay learning schedule

lucidrains committed Feb 7, 2024
1 parent fd5abb9 commit 9295685
Showing 3 changed files with 18 additions and 2 deletions.

denoising_diffusion_pytorch/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -8,4 +8,4 @@

from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D

-from denoising_diffusion_pytorch.karras_unet import KarrasUnet
+from denoising_diffusion_pytorch.karras_unet import KarrasUnet, InvSqrtDecayLRSched

denoising_diffusion_pytorch/karras_unet.py (16 changes: 16 additions & 0 deletions)
@@ -9,6 +9,7 @@
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
+from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack
@@ -680,6 +681,21 @@ def forward(self, x):

return x

+# works best with inverse square root decay schedule
+
+def InvSqrtDecayLRSched(
+    optimizer,
+    t_ref = 70000,
+    sigma_ref = 0.01
+):
+    """
+    refer to equation 67 and Table 1
+    """
+    def inv_sqrt_decay_fn(t: int):
+        return sigma_ref / sqrt(max(t / t_ref, 1.))
+
+    return LambdaLR(optimizer, lr_lambda = inv_sqrt_decay_fn)
+
# example

if __name__ == '__main__':
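
For context, a minimal usage sketch of the scheduler added above, assuming a stand-in model and a dummy training loop (the Adam settings and the loop below are illustrative, not part of this commit). The LambdaLR multiplier is sigma_ref / sqrt(max(step / t_ref, 1)), so the effective learning rate sits at base_lr * sigma_ref for the first t_ref steps and then decays as the inverse square root of the step count.

import torch
from torch import nn
from denoising_diffusion_pytorch import InvSqrtDecayLRSched

# stand-in network; in practice this would be the magnitude preserving KarrasUnet
model = nn.Linear(16, 16)

# the base lr set here is multiplied by the schedule's output each time the scheduler steps
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2)

# multiplier = sigma_ref / sqrt(max(step / t_ref, 1)):
# constant at sigma_ref for the first t_ref steps, then inverse square root decay
scheduler = InvSqrtDecayLRSched(optimizer, t_ref = 70000, sigma_ref = 0.01)

for step in range(100_000):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # advance the decay schedule once per optimizer step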

denoising_diffusion_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.10.5'
+__version__ = '1.10.7'
