Skip to content

Commit

Permalink
fix decoder resnet blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 11, 2021
1 parent b96aff5 commit d591971
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
super().__init__()
assert log2(image_size).is_integer(), 'image size must be a power of 2'
assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
has_resblocks = num_resnet_blocks > 0

self.image_size = image_size
self.num_tokens = num_tokens
Expand All @@ -87,10 +88,12 @@ def __init__(
hdim = hidden_dim

enc_chans = [hidden_dim] * num_layers
dec_chans = reversed(enc_chans)
dec_chans = list(reversed(enc_chans))

enc_chans = [channels, *enc_chans]
dec_chans = [codebook_dim, *dec_chans]

dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]

enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

Expand All @@ -102,8 +105,11 @@ def __init__(
dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(dec_chans[1]))
enc_layers.append(ResBlock(enc_chans[-1]))
dec_layers.append(ResBlock(dec_chans[-1]))

if num_resnet_blocks > 0:
dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.34',
version = '0.0.35',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d591971

Please sign in to comment.