From 2969103d83d9cfd578eef79cf3bc1f8f7554a615 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 23 Mar 2022 22:36:23 -0700
Subject: [PATCH] cleanup conv-like attention

---
 dalle_pytorch/attention.py | 14 +++++++-------
 setup.py                   |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py
index c5977774..188cb1c4 100644
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
         dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)
 
-        # calculate causal attention for local convolution
+        # use padding of 0 on tensor of 1s and unfold for padding mask
 
         i, j = dots_image.shape[-2:]
-        img_seq = torch.arange(img_seq_len, device = device)
-        k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
-        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
-        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+        ones = torch.ones((img_seq_len,), device = device)
+        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+        ones = F.pad(ones, causal_padding, value = 0.)
+        ones = F.unfold(ones, kernel_size, dilation = dilation)
+        ones = rearrange(ones, 'b j i -> b i j')
 
         # mask image attention
 
-        padding_mask = k_img_indices == img_seq_len
+        padding_mask = ones == 0.
 
         # concat text mask with image causal mask
 
diff --git a/setup.py b/setup.py
index bdd2b569..7fa756db 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.5.0',
+  version = '1.5.1',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
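
Illustrative note, not part of the patch: a minimal, self-contained sketch of the
padding-mask construction that the attention.py hunk switches to. The values of
img_size, kernel_size, dilation, and causal_padding below are assumptions chosen
for the example (the real module derives them from its configuration); only the
ones / pad / unfold / compare-to-zero pattern mirrors the patched code.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    # assumed example values, not taken from the library
    img_size = 4                      # image is a 4 x 4 grid of tokens
    img_seq_len = img_size ** 2
    kernel_size = 3                   # 3 x 3 local attention window
    dilation = 1
    causal_padding = (kernel_size - 1, 0, kernel_size - 1, 0)  # pad left/top only (assumption)

    # start from a tensor of 1s over the real image positions
    ones = torch.ones((img_seq_len,))
    ones = rearrange(ones, '(h w) -> () () h w', h = img_size)

    # padded positions become 0, real positions stay 1
    ones = F.pad(ones, causal_padding, value = 0.)

    # unfold gathers each position's local window: shape (1, kernel area, num positions)
    ones = F.unfold(ones, kernel_size, dilation = dilation)
    ones = rearrange(ones, 'b j i -> b i j')

    # any window slot that unfolded to 0 came from padding and is masked out
    padding_mask = ones == 0.
    print(padding_mask.shape)  # torch.Size([1, 16, 9])

Compared with the previous approach of padding an index tensor with a sentinel value
of img_seq_len, padding a tensor of 1s with 0 expresses the same mask with one less
tensor and no sentinel bookkeeping.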