stability measure 3
lucidrains committed May 28, 2021
1 parent f794ba6 commit e8c2d99
Showing 3 changed files with 27 additions and 13 deletions.
30 changes: 22 additions & 8 deletions dalle_pytorch/attention.py
@@ -22,16 +22,22 @@ def default(val, d):
 def max_neg_value(t):
     return -torch.finfo(t.dtype).max

+def stable_softmax(t, dim = -1, alpha = 32 ** 2):
+    t = t / alpha
+    t = t - torch.amax(t, dim = dim, keepdim = True)
+    return (t * alpha).softmax(dim = dim)
+
 # classes

 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5

+        self.stable = stable
         self.causal = causal

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -42,6 +48,8 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou

     def forward(self, x, mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

@@ -60,7 +68,7 @@ def forward(self, x, mask = None):
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim=-1)
+        attn = softmax(dots, dim=-1)

         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -70,7 +78,7 @@ def forward(self, x, mask = None):
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation

 class SparseConvCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert kernel_size % 2 == 1, 'kernel size must be odd'

@@ -82,6 +90,8 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         self.kernel_size = kernel_size
         self.dilation = dilation

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -91,6 +101,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,

     def forward(self, x, mask = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -121,7 +132,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -163,7 +174,7 @@ def forward(self, x, mask = None):
         dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
         dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

@@ -185,7 +196,7 @@ def forward(self, x, mask = None):
 # sparse axial causal attention

 class SparseAxialCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
         self.axis = axis
@@ -196,6 +207,8 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         self.scale = dim_head ** -0.5
         self.image_size = image_size

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -205,6 +218,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head

     def forward(self, x, mask = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -235,7 +249,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -267,7 +281,7 @@ def forward(self, x, mask = None):

         # attention.

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

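Note: stable_softmax above is mathematically equivalent to a plain softmax, since softmax is invariant to subtracting a per-row constant and alpha = 32 ** 2 = 1024 is a power of two, so dividing and re-multiplying are exact in floating point. The rescaling only shrinks the range of the intermediate values around the max-subtraction, which appears to be the "stability measure" the commit title refers to. A minimal standalone sketch (not part of this commit) that checks the equivalence in float32:

    import torch

    def stable_softmax(t, dim = -1, alpha = 32 ** 2):
        # scale down, subtract the row max, scale back up before exponentiating
        t = t / alpha
        t = t - torch.amax(t, dim = dim, keepdim = True)
        return (t * alpha).softmax(dim = dim)

    logits = torch.randn(2, 8, 64, 64) * 50  # exaggerated attention scores
    assert torch.allclose(stable_softmax(logits), logits.softmax(dim = -1))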
8 changes: 4 additions & 4 deletions dalle_pytorch/transformer.py
@@ -108,15 +108,15 @@ def __init__(

         for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
             if attn_type == 'full':
-                attn_class = Attention
+                attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
-                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size)
+                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
                 attn_class = partial(gMLPBlock, seq_len = seq_len)
             else:
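Note: the transformer change only touches the attn_class assignments: each attention constructor now has the stable flag pre-bound via functools.partial, and the rest of the layer-building loop is untouched. A minimal sketch (standalone, with made-up hyperparameters, not the project's actual wiring) of the pattern:

    from functools import partial
    from dalle_pytorch.attention import Attention

    # pre-bind the flag, as the diff does for the 'full' attention type
    attn_class = partial(Attention, stable = True)

    # later, instantiate it with the remaining per-layer arguments
    attn = attn_class(dim = 512, seq_len = 1024, heads = 8, dim_head = 64)
    print(attn.stable)  # True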
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.12.2',
+  version = '0.12.4',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
