Skip to content

Commit

Permalink
address lucidrains#293
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 19, 2024
1 parent 7a77b45 commit c59ebf4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions denoising_diffusion_pytorch/karras_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def forward(self, x):
# forced weight normed conv2d and linear
# algorithm 1 in paper

def normalize_weight(weight, eps = 1e-4):
weight, ps = pack_one(weight, 'o *')
normed_weight = l2norm(weight, eps = eps)
normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
return unpack_one(normed_weight, ps, 'o *')

class Conv2d(Module):
def __init__(
self,
Expand All @@ -142,14 +148,13 @@ def __init__(
self.concat_ones_to_input = concat_ones_to_input

def forward(self, x):

if self.training:
with torch.no_grad():
weight, ps = pack_one(self.weight, 'o *')
normed_weight = l2norm(weight, eps = self.eps)
normed_weight = unpack_one(normed_weight, ps, 'o *')
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)

weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

if self.concat_ones_to_input:
x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.)
Expand All @@ -167,10 +172,10 @@ def __init__(self, dim_in, dim_out, eps = 1e-4):
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = l2norm(self.weight, eps = self.eps)
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)

weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.linear(x, weight)

# mp fourier embeds
Expand Down
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.10.10'
__version__ = '1.10.12'

0 comments on commit c59ebf4

Please sign in to comment.