Merge pull request #16 from lucidrains/pw/sparse-attention
add sparse attention to DALL-E
lucidrains authored Jan 12, 2021
2 parents 66a3505 + c5555da commit 795dfb3
Showing 5 changed files with 107 additions and 10 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -145,6 +145,36 @@ dalle = DALLE(
)
```

## Sparse Attention

You can also train with Microsoft DeepSpeed's Sparse Attention, using any combination of dense and sparse attention layers that you'd like. However, you will have to endure the installation process.

First, you need to install DeepSpeed with Sparse Attention:

```bash
$ sh install_deepspeed.sh
```
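For reference, this script installs the `llvm-9-dev` and `cmake` build prerequisites, clones DeepSpeed into `/tmp/Deepspeed`, and builds it with `DS_BUILD_SPARSE_ATTN=1` (see `install_deepspeed.sh` in the diff below).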

Next, you need to install the pip package `triton`:

```bash
$ pip install triton
```
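DeepSpeed's sparse attention kernels are built on Triton, which is why this extra dependency is needed.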

If both of the above succeeded, you can now train with Sparse Attention!

```python
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    sparse_attn = (True, False) * 32 # interleave sparse and dense attention for 64 layers
)
```
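
`sparse_attn` also accepts a single boolean, which the `cast_tuple` helper (added in `dalle_pytorch/transformer.py` in this commit) broadcasts to every layer. A minimal sketch, assuming the same `vae` as above:

```python
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    sparse_attn = True  # all 64 layers use sparse attention
)
```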

## Citations

```bibtex
11 changes: 7 additions & 4 deletions dalle_pytorch/dalle_pytorch.py
@@ -179,7 +179,7 @@ def __init__(
        super().__init__()
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(causal = False, dim = dim_text, depth = text_enc_depth, heads = text_heads)
        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads)
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
@@ -189,7 +189,7 @@ def __init__(
        self.visual_patch_size = visual_patch_size
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        self.visual_transformer = Transformer(causal = False, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        self.temperature = nn.Parameter(torch.tensor(1.))
@@ -251,7 +251,8 @@ def __init__(
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0
        ff_dropout = 0,
        sparse_attn = False
    ):
        super().__init__()
        assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -284,12 +285,14 @@ def __init__(
        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
            ff_dropout = ff_dropout,
            sparse_attn = sparse_attn
        )

        self.to_logits = nn.Sequential(
71 changes: 66 additions & 5 deletions dalle_pytorch/transformer.py
@@ -1,7 +1,8 @@
import torch
from inspect import isfunction
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from einops import rearrange, repeat

from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence

@@ -10,6 +11,14 @@
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class PreNorm(nn.Module):
@@ -40,10 +49,11 @@ def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, causal = True, heads = 8, dim_head = 64, dropout = 0.):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim ** -0.5
        self.causal = causal

@@ -78,26 +88,77 @@ def forward(self, x, mask = None):
        out = self.to_out(out)
        return out

class SparseAttention(Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
        self.block_size = 16

        self.attn_fn = SparseSelfAttention(
            sparsity_config = VariableSparsityConfig(
                num_heads = self.heads,
                block = self.block_size,
                attention = 'unidirectional' if self.causal else 'bidirectional'
            ),
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )

    def forward(self, x, mask = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        remainder = n % self.block_size
        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())

        # pad the sequence (and mask) to a multiple of the sparse block size
        if remainder > 0:
            padding = self.block_size - remainder
            x = F.pad(x, (0, 0, 0, padding), value = 0)
            mask = F.pad(mask, (0, padding), value = False)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        key_pad_mask = None
        if exists(mask):
            key_pad_mask = ~mask

        # additive causal mask, matching attn_mask_mode = 'add'
        attn_mask = None
        if self.causal:
            i, j = q.shape[-2], k.shape[-2]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            attn_mask = torch.zeros(i, j, device = device).to(q)
            mask_value = -(torch.finfo(q.dtype).max / 2)
            attn_mask.masked_fill_(mask, mask_value)

        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out[:, :n] # slice out any padding

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.
        ff_dropout = 0.,
        sparse_attn = True
    ):
        super().__init__()
        layers = nn.ModuleList([])
        sparse_layer = cast_tuple(sparse_attn, depth)

        for _, sparse_attn in zip(range(depth), sparse_layer):
            attn_class = Attention if not sparse_attn else SparseAttention

        for _ in range(depth):
            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, causal = causal, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            ]))
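
A minimal usage sketch (not part of this commit) of the new `Transformer` signature; the constructor argument values here are illustrative, and DeepSpeed with sparse attention must be installed for any layer where `sparse_attn` is `True`:

```python
from dalle_pytorch.transformer import Transformer

# 4 layers, alternating sparse and dense attention, for sequences up to 1024 tokens
transformer = Transformer(
    dim = 512,
    depth = 4,
    seq_len = 1024,
    causal = True,
    sparse_attn = (True, False) * 2
)
```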

3 changes: 3 additions & 0 deletions install_deepspeed.sh
@@ -0,0 +1,3 @@
sudo apt-get -y install llvm-9-dev cmake
git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed
cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'dalle-pytorch',
  packages = find_packages(),
  version = '0.0.35',
  version = '0.0.36',
  license='MIT',
  description = 'DALL-E - Pytorch',
  author = 'Phil Wang',
