Commit

pad to block size of sparse attention
lucidrains committed Jan 12, 2021
1 parent dc1c6d5 commit c5555da
Showing 2 changed files with 19 additions and 3 deletions.
20 changes: 18 additions & 2 deletions dalle_pytorch/transformer.py
@@ -1,4 +1,5 @@
 import torch
+from inspect import isfunction
 from torch import nn, einsum
 import torch.nn.functional as F
 from einops import rearrange, repeat
@@ -10,6 +11,11 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
 def cast_tuple(val, depth):
     return val if isinstance(val, tuple) else (val,) * depth
 
@@ -86,11 +92,12 @@ class SparseAttention(Attention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
+        self.block_size = 16
 
         self.attn_fn = SparseSelfAttention(
             sparsity_config = VariableSparsityConfig(
                 num_heads = self.heads,
-                block = 16,
+                block = self.block_size,
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
             max_seq_length = self.seq_len,
@@ -99,6 +106,14 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, x, mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
+        remainder = n % self.block_size
+        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
+
+        if remainder > 0:
+            padding = self.block_size - remainder
+            x = F.pad(x, (0, 0, 0, padding), value = 0)
+            mask = F.pad(mask, (0, padding), value = False)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -116,7 +131,8 @@ def forward(self, x, mask = None):
 
         out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+        out = self.to_out(out)
+        return out[:, :n]
 
 class Transformer(nn.Module):
     def __init__(
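
The change pads the input (and its mask) up to the next multiple of self.block_size before handing it to DeepSpeed's block-sparse attention, then slices the output back to the original sequence length. Below is a minimal standalone sketch of that pad-then-truncate pattern; the block size of 16 mirrors the default set in __init__, while the tensor shapes and the stand-in attention step are only illustrative, not the repository's exact code:

import torch
import torch.nn.functional as F

block_size = 16                                      # illustrative; matches the default in __init__
x = torch.randn(2, 50, 512)                          # (batch, seq_len, dim); 50 is not a multiple of 16
mask = torch.ones(2, 50).bool()                      # per-token mask, True = attend

b, n, _ = x.shape
remainder = n % block_size
if remainder > 0:
    padding = block_size - remainder                 # 16 - (50 % 16) = 14
    x = F.pad(x, (0, 0, 0, padding), value = 0)      # zero-pad along the sequence dimension
    mask = F.pad(mask, (0, padding), value = False)  # padded positions are masked out

out = x                                              # stand-in for the block-sparse attention output

out = out[:, :n]                                     # drop the padded positions again
assert out.shape[1] == 50
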
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.35',
+  version = '0.0.36',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
