From 40f41199b3f4a355108c64db3ef018d9271bb131 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 5 Feb 2021 14:22:41 -0800 Subject: [PATCH] fix causal masking in sparse conv attention --- dalle_pytorch/attention.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py index 23aec7cf..98b3c103 100644 --- a/dalle_pytorch/attention.py +++ b/dalle_pytorch/attention.py @@ -143,7 +143,8 @@ def forward(self, x, mask = None): # mask image attention - mask = rearrange(img_seq, 'i -> () i ()') <= k_img_indices + q_img_indices = rearrange(img_seq, 'i -> () i ()') + mask = q_img_indices >= k_img_indices # image can attend to all of text diff --git a/setup.py b/setup.py index 424f9ca0..ddaaee39 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'dalle-pytorch', packages = find_packages(), - version = '0.0.52', + version = '0.0.53', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',