From e789cab9c112b0e7134a4da48a7d96dcde5b6220 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 15 Jan 2021 12:52:00 -0800
Subject: [PATCH] add some random blocks to sparse attention, default to a quarter of the blocks

---
 dalle_pytorch/transformer.py | 15 ++++++++++++---
 setup.py                     |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 315aab1f..e75b1dfb 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -94,17 +94,26 @@ def forward(self, x, mask = None):
         return out
 
 class SparseAttention(Attention):
-    def __init__(self, *args, sparse_attn_global_indices = [], block_size = 16, **kwargs):
+    def __init__(
+        self,
+        *args,
+        block_size = 16,
+        num_random_blocks = None,
+        sparse_attn_global_indices = [],
+        **kwargs
+    ):
         super().__init__(*args, **kwargs)
         from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
-        self.block_size = block_size
-        global_blocks = uniq(map(lambda t: t // self.block_size, sparse_attn_global_indices))
+
+        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
+        global_blocks = uniq(map(lambda t: t // block_size, sparse_attn_global_indices))
 
         self.attn_fn = SparseSelfAttention(
             sparsity_config = VariableSparsityConfig(
                 num_heads = self.heads,
                 block = self.block_size,
+                num_random_blocks = num_random_blocks,
                 global_block_indices = global_blocks,
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
diff --git a/setup.py b/setup.py
index 1a725420..a221bb84 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.37',
+  version = '0.0.39',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
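
Note on the new default: when num_random_blocks is left as None, the sparsity config falls back to a quarter of the sequence's blocks, i.e. seq_len // block_size // 4. Below is a minimal sketch of that fallback, not part of the patch: the helper name resolve_num_random_blocks is hypothetical, and default() is reimplemented here on the assumption that it returns its first argument whenever that argument is not None.

    # Illustrative sketch only, not part of the patch above.
    # Assumes default(val, d) returns val if it is not None, otherwise d,
    # mirroring how the patch computes the num_random_blocks fallback.

    def default(val, d):
        return val if val is not None else d

    # hypothetical helper name, used only for this example
    def resolve_num_random_blocks(seq_len, block_size = 16, num_random_blocks = None):
        # total blocks = seq_len // block_size; the default takes a quarter of them
        return default(num_random_blocks, seq_len // block_size // 4)

    print(resolve_num_random_blocks(1024))                           # 64 blocks -> 16 random blocks
    print(resolve_num_random_blocks(1024, num_random_blocks = 8))    # explicit value wins -> 8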