Skip to content

Commit

Permalink
take a gamble on cosine sim attention
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 18, 2022
1 parent 2b742dd commit 555566c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
13 changes: 8 additions & 5 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def convert_image_to(img_type, image):
return image.convert(img_type)
return image

def l2norm(t):
return F.normalize(t, dim = -1)

# normalization functions

def normalize_to_neg_one_to_one(img):
Expand Down Expand Up @@ -215,9 +218,9 @@ def forward(self, x):
return self.to_out(out)

class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
def __init__(self, dim, heads = 4, dim_head = 32, scale = 16):
super().__init__()
self.scale = dim_head ** -0.5
self.scale = scale
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
Expand All @@ -227,10 +230,10 @@ def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
q = q * self.scale

sim = einsum('b h d i, b h d j -> b h i j', q, k)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
q, k = map(l2norm, (q, k))

sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h d j -> b h i d', attn, v)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.25.3',
version = '0.26.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 555566c

Please sign in to comment.