From e8c2d9948196d6e7d3911767d533dd01f11ea821 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 28 May 2021 14:56:23 -0700
Subject: [PATCH] stability measure 3

---
 dalle_pytorch/attention.py   | 30 ++++++++++++++++++++++--------
 dalle_pytorch/transformer.py |  8 ++++----
 setup.py                     |  2 +-
 3 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py
index 943aa292..6a1180cf 100644
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -22,16 +22,22 @@ def default(val, d):
 def max_neg_value(t):
     return -torch.finfo(t.dtype).max

+def stable_softmax(t, dim = -1, alpha = 32 ** 2):
+    t = t / alpha
+    t = t - torch.amax(t, dim = dim, keepdim = True)
+    return (t * alpha).softmax(dim = dim)
+
 # classes

 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5
+        self.stable = stable

         self.causal = causal

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -42,6 +48,8 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device

+        softmax = torch.softmax if not self.stable else stable_softmax
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

@@ -60,7 +68,7 @@ def forward(self, x, mask = None):
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim=-1)
+        attn = softmax(dots, dim=-1)

         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -70,7 +78,7 @@ def forward(self, x, mask = None):

 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation

 class SparseConvCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert kernel_size % 2 == 1, 'kernel size must be odd'
@@ -82,6 +90,8 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         self.kernel_size = kernel_size
         self.dilation = dilation

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -91,6 +101,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
     def forward(self, x, mask = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len

@@ -121,7 +132,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -163,7 +174,7 @@ def forward(self, x, mask = None):
         dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
         dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

@@ -185,7 +196,7 @@ def forward(self, x, mask = None):
 # sparse axial causal attention

 class SparseAxialCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
         self.axis = axis
@@ -196,6 +207,8 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         self.scale = dim_head ** -0.5
         self.image_size = image_size

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -205,6 +218,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
     def forward(self, x, mask = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len

@@ -235,7 +249,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -267,7 +281,7 @@ def forward(self, x, mask = None):

         # attention.

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 3cb7f758..238ff358 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -108,15 +108,15 @@ def __init__(

         for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
             if attn_type == 'full':
-                attn_class = Attention
+                attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
-                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size)
+                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
                 attn_class = partial(gMLPBlock, seq_len = seq_len)
             else:
diff --git a/setup.py b/setup.py
index b3757b0f..9f992084 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.12.2',
+  version = '0.12.4',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
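Note (editorial addition, not part of the patch): the stable_softmax introduced above is mathematically equivalent to an ordinary softmax, since the division by alpha before the max-shift and the multiplication by alpha inside the final softmax cancel out; only the values flowing through the max-subtraction are kept at the reduced scale. Below is a minimal standalone sketch of the trick; the tensor shapes and the allclose check are illustrative and not taken from the patch.

import torch

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # as added to dalle_pytorch/attention.py: scale the logits down by alpha,
    # subtract the row-wise max, then scale back up inside the softmax
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True)
    return (t * alpha).softmax(dim = dim)

dots = torch.randn(2, 8, 16, 16) * 10  # stand-in attention logits
assert torch.allclose(stable_softmax(dots), dots.softmax(dim = -1), atol = 1e-5)

With stable = True threaded through (the Transformer now builds each attention class via partial(..., stable = stable)), the softmax calls in Attention, SparseConvCausalAttention, and SparseAxialCausalAttention are routed through this function instead of torch.softmax.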