use logits masking to ensure image is produced after text
lucidrains committed Jan 7, 2021
1 parent 3302a76 commit 8995720
Showing 2 changed files with 30 additions and 9 deletions.
37 changes: 29 additions & 8 deletions dalle_pytorch/dalle_pytorch.py
@@ -20,8 +20,9 @@ def masked_mean(t, mask, dim = 1):

# sampling helpers

-def top_k(logits, thres = 0.9):
-    k = int((1 - thres) * logits.shape[-1])
+def top_k(logits, thres = 0.5):
+    num_logits = logits.shape[-1]
+    k = max(int((1 - thres) * num_logits), 1)
     val, ind = torch.topk(logits, k)
     probs = torch.full_like(logits, float('-inf'))
     probs.scatter_(1, ind, val)
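
Note: with the new default thres = 0.5, the max(..., 1) guard keeps torch.topk from being asked for zero elements when the vocabulary is small. A minimal standalone sketch of the updated helper (the toy tensor sizes are hypothetical):

    import torch

    def top_k(logits, thres = 0.5):
        num_logits = logits.shape[-1]
        k = max(int((1 - thres) * num_logits), 1)  # keep at least one candidate
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)  # everything outside the top-k stays at -inf
        return probs

    logits = torch.randn(1, 10)
    filtered = top_k(logits)                       # 5 finite entries, 5 at -inf
    assert torch.isfinite(filtered).sum() == 5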
@@ -34,16 +35,16 @@ def generate_images(
     text,
     clipper = None,
     mask = None,
-    filter_thres = 0.9,
+    filter_thres = 0.5,
     temperature = 1.
 ):
     x = text
+    out = x
 
     text_seq_len = model.text_seq_len
     image_seq_len = model.image_seq_len
     total_len = text_seq_len + model.image_seq_len - text.shape[1]
 
-    out = x
     for _ in range(total_len):
         text, image = x[:, :text_seq_len], x[:, text_seq_len:]
         logits = model(text, image, mask = mask)[:, -1, :]
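
Each pass through this loop samples one token from the last-position logits and appends it to the running sequence. A sketch of that sampling step, assuming the top_k helper above (generate_images itself also re-splits x into text and image halves every iteration):

    import torch
    import torch.nn.functional as F

    temperature = 1.
    logits = torch.randn(2, 10)                  # hypothetical last-position logits
    filtered_logits = top_k(logits, thres = 0.5)

    probs = F.softmax(filtered_logits / temperature, dim = -1)  # -inf entries become 0
    sample = torch.multinomial(probs, 1)         # one sampled token id per batch row
    # generate_images concatenates this sample onto the running sequence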
@@ -57,9 +58,9 @@ def generate_images(
             mask = F.pad(mask, (0, 1), value=True)
 
     text_seq = torch.cat((x[:, :1], out[:, :(text_seq_len - 1)]), dim = 1)
-    img_seq = out[:, -(image_seq_len + 1):-1]
-
+    img_seq = out[:, -image_seq_len:]
     img_seq -= model.num_text_tokens
+    img_seq.clamp_(min = 0, max = (model.num_image_tokens - 1)) # extra insurance - todo: get rid of this at a future date and rely only on masking of logits
 
     images = vae.decode(img_seq)

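The combined vocabulary packs text token ids first and image token ids after them, so the sampled image sequence must be shifted back into the VAE codebook range before decoding. The new clamp_ line is the "extra insurance" against a stray text id surviving sampling. A toy illustration (the sizes are hypothetical):

    import torch

    num_text_tokens, num_image_tokens = 10000, 512

    img_seq = torch.tensor([10003, 9999])        # second id is a leaked text token
    img_seq = img_seq - num_text_tokens          # tensor([ 3, -1])
    img_seq.clamp_(min = 0, max = num_image_tokens - 1)
    assert img_seq.tolist() == [3, 0]            # valid codebook indices for vae.decode
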
@@ -74,7 +75,7 @@
 class DiscreteVAE(nn.Module):
     def __init__(
         self,
-        num_tokens,
+        num_tokens = 512,
         dim = 512,
         hidden_dim = 64,
         num_layers = 3
@@ -253,7 +254,9 @@ def __init__(
         self.num_image_tokens = num_image_tokens
         self.text_seq_len = text_seq_len
         self.image_seq_len = image_seq_len
-        self.total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
+
+        total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
+        self.total_tokens = total_tokens
 
         self.vae = vae
         if exists(self.vae):
@@ -267,6 +270,20 @@ def __init__(
             nn.Linear(dim, self.total_tokens),
         )
 
+        seq_range = torch.arange(text_seq_len + image_seq_len)
+        logits_range = torch.arange(total_tokens)
+
+        seq_range = rearrange(seq_range, 'n -> () n ()')
+        logits_range = rearrange(logits_range, 'd -> () () d')
+
+        logits_mask = (
+            ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
+            ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
+            (logits_range >= (total_tokens - 1))
+        )
+
+        self.register_buffer('logits_mask', logits_mask)
+
     def forward(
         self,
         text,
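
This is the heart of the commit. At sequence position i the transformer predicts token i + 1, so positions before text_seq_len - 1 may only emit text tokens, positions from text_seq_len - 1 onward may only emit image tokens, and the extra EOS-style token (the +1 in total_tokens) is never sampled at all. A small sketch with hypothetical toy sizes to inspect the precomputed buffer (True marks a forbidden logit):

    import torch
    from einops import rearrange

    text_seq_len, image_seq_len = 4, 3
    num_text_tokens, num_image_tokens = 6, 5
    total_tokens = num_text_tokens + num_image_tokens + 1

    seq_range = rearrange(torch.arange(text_seq_len + image_seq_len), 'n -> () n ()')
    logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')

    logits_mask = (
        ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
        ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
        (logits_range >= (total_tokens - 1))
    )

    print(logits_mask.shape)  # torch.Size([1, 7, 12]), broadcasts over (batch, seq, vocab)
    # rows 0..2 allow only text ids 0..5; rows 3..6 allow only image ids 6..10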
@@ -299,6 +316,10 @@ def forward(
         logits = self.to_logits(out)
 
         if not return_loss:
+            seq_len = tokens.shape[1]
+            mask = self.logits_mask[:, :seq_len]
+            max_neg_value = -torch.finfo(logits.dtype).max
+            logits.masked_fill_(mask, max_neg_value)
             return logits
 
         assert exists(image), 'when training, image must be supplied'
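
At inference time the buffer is sliced to the current sequence length and the forbidden logits are filled with the most negative finite value for the dtype, a common alternative to a literal -inf that avoids NaNs downstream. Continuing the toy logits_mask example above:

    logits = torch.randn(1, 7, 12)               # hypothetical (batch, seq, vocab) logits

    seq_len = logits.shape[1]
    mask = logits_mask[:, :seq_len]
    max_neg_value = -torch.finfo(logits.dtype).max
    logits.masked_fill_(mask, max_neg_value)     # softmax now assigns these ~0 probability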
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
