allow for straight-through gumbel softmax
lucidrains committed Jan 17, 2021
1 parent 49fc35a commit 0c3c84c
Showing 3 changed files with 13 additions and 10 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -28,12 +28,13 @@ from dalle_pytorch import DiscreteVAE
 
 vae = DiscreteVAE(
     image_size = 256,
-    num_layers = 3,            # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
-    num_tokens = 8192,         # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
-    codebook_dim = 512,        # codebook dimension
-    hidden_dim = 64,           # hidden dimension
-    num_resnet_blocks = 1,     # number of resnet blocks
-    temperature = 0.9          # gumbel softmax temperature, the lower this is, the more hard the discretization
+    num_layers = 3,            # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
+    num_tokens = 8192,         # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
+    codebook_dim = 512,        # codebook dimension
+    hidden_dim = 64,           # hidden dimension
+    num_resnet_blocks = 1,     # number of resnet blocks
+    temperature = 0.9,         # gumbel softmax temperature, the lower this is, the more hard the discretization
+    straight_through = False   # straight-through for gumbel softmax. unclear if it is better one way or the other
 )
 
 images = torch.randn(4, 3, 256, 256)
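The only user-facing change here is the new straight_through argument. As a hedged usage sketch (not part of this commit), the same constructor with the flag enabled looks like the following; the latent grid size follows from the settings above, since 256 / 2**3 = 32, and return_logits is the forward() option visible in the next file's diff:

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9,
    straight_through = True    # hard one-hot codes in the forward pass, soft gradients in the backward pass
)

images = torch.randn(4, 3, 256, 256)

# the logits should come out as (4, 8192, 32, 32) - one score per visual token per latent position
logits = vae(images, return_logits = True)

As the comment in the README notes, it is left open whether the straight-through variant actually trains better.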
8 changes: 5 additions & 3 deletions dalle_pytorch/dalle_pytorch.py
@@ -72,7 +72,8 @@ def __init__(
         num_resnet_blocks = 0,
         hidden_dim = 64,
         channels = 3,
-        temperature = 0.9
+        temperature = 0.9,
+        straight_through = False
     ):
         super().__init__()
         assert log2(image_size).is_integer(), 'image size must be a power of 2'
@@ -83,8 +84,9 @@ def __init__(
         self.num_tokens = num_tokens
         self.num_layers = num_layers
         self.temperature = temperature
+        self.straight_through = straight_through
         self.codebook = nn.Embedding(num_tokens, codebook_dim)
 
         hdim = hidden_dim
 
         enc_chans = [hidden_dim] * num_layers
@@ -146,7 +148,7 @@ def forward(
         if return_logits:
             return logits # return logits for getting hard image indices for DALL-E training
 
-        soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1)
+        soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1, hard = self.straight_through)
         sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
         out = self.decoder(sampled)
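The functional change is just the hard flag forwarded to F.gumbel_softmax. A minimal sketch of what that flag does, using made-up shapes rather than code from this commit: with hard = False every latent position mixes all code vectors by their soft gumbel-softmax weights, while with hard = True PyTorch returns a one-hot tensor in the forward pass but routes gradients through the underlying soft probabilities, i.e. the straight-through estimator.

import torch
import torch.nn.functional as F
from torch import einsum

logits = torch.randn(4, 8192, 32, 32)      # b n h w, as produced by the DiscreteVAE encoder
codebook = torch.nn.Embedding(8192, 512)   # num_tokens x codebook_dim

# soft: each latent position is a weighted mixture of all 8192 code vectors
soft_one_hot = F.gumbel_softmax(logits, tau = 0.9, dim = 1, hard = False)

# straight-through: exactly one code per position in the forward pass,
# with gradients taken through the soft distribution in the backward pass
hard_one_hot = F.gumbel_softmax(logits, tau = 0.9, dim = 1, hard = True)

sampled = einsum('b n h w, n d -> b d h w', hard_one_hot, codebook.weight)
print(sampled.shape)   # torch.Size([4, 512, 32, 32])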
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.39',
+  version = '0.0.40',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
