add some random blocks to sparse attention, default to a quarter of the blocks
lucidrains committed Jan 15, 2021
1 parent 5c00ebc commit e789cab
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 12 additions & 3 deletions dalle_pytorch/transformer.py
@@ -94,17 +94,26 @@ def forward(self, x, mask = None):
         return out
 
 class SparseAttention(Attention):
-    def __init__(self, *args, sparse_attn_global_indices = [], block_size = 16, **kwargs):
+    def __init__(
+        self,
+        *args,
+        block_size = 16,
+        num_random_blocks = None,
+        sparse_attn_global_indices = [],
+        **kwargs
+    ):
         super().__init__(*args, **kwargs)
         from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
 
         self.block_size = block_size
-        global_blocks = uniq(map(lambda t: t // self.block_size, sparse_attn_global_indices))
+
+        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
+        global_blocks = uniq(map(lambda t: t // block_size, sparse_attn_global_indices))
 
         self.attn_fn = SparseSelfAttention(
             sparsity_config = VariableSparsityConfig(
                 num_heads = self.heads,
                 block = self.block_size,
+                num_random_blocks = num_random_blocks,
                 global_block_indices = global_blocks,
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
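
For context: self.seq_len is set by the Attention base class, so when num_random_blocks is not supplied, the new default works out to a quarter of the row's attention blocks (seq_len // block_size blocks per row, divided by 4). Below is a minimal sketch of that arithmetic and the resulting DeepSpeed config; it is not code from this commit. The default helper and the set comprehension standing in for uniq(map(...)) are assumed equivalents of the repository's own utilities, the numeric settings are hypothetical, and running it requires a DeepSpeed install with sparse attention (triton) support.

    from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

    def default(val, d):
        # assumed stand-in for the repo's helper: fall back to d when val is None
        return val if val is not None else d

    seq_len, block_size, heads = 1024, 16, 8     # hypothetical settings
    num_blocks = seq_len // block_size           # 1024 // 16 = 64 blocks per row
    num_random_blocks = default(None, seq_len // block_size // 4)   # 64 // 4 = 16

    sparse_attn_global_indices = [0]             # hypothetical global token positions
    # equivalent of uniq(map(lambda t: t // block_size, ...)):
    # map token indices to block indices, deduplicated
    global_blocks = list({t // block_size for t in sparse_attn_global_indices})

    attn_fn = SparseSelfAttention(
        sparsity_config = VariableSparsityConfig(
            num_heads = heads,
            block = block_size,
            num_random_blocks = num_random_blocks,
            global_block_indices = global_blocks,
            attention = 'unidirectional'         # causal, as in the class above
        )
    )
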
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.37',
+  version = '0.0.39',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
