Skip to content

Commit

Permalink
redo resnet block only to happen at the lowest resolution feature map
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 11, 2021
1 parent 55fe592 commit b96aff5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 30 deletions.
45 changes: 16 additions & 29 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,19 @@ def top_k(logits, thres = 0.5):

# discrete vae class

def ConvBlock(chan_in, chan_out):
    """Build a small stack of two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use a 3x3 kernel with ``padding = 1`` and the default
    stride, so the spatial size of the input is preserved. Only the first
    convolution changes the channel count.

    Args:
        chan_in: number of channels of the input feature map.
        chan_out: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of ``Conv2d -> ReLU -> Conv2d -> ReLU``.
    """
    return nn.Sequential(
        nn.Conv2d(chan_in, chan_out, 3, padding = 1),
        nn.ReLU(),
        nn.Conv2d(chan_out, chan_out, 3, padding = 1),
        nn.ReLU()
    )

class ResBlock(nn.Module):
    """Residual block that preserves both channel count and spatial size.

    The residual branch is ``Conv2d(3x3) -> ReLU -> Conv2d(3x3) -> ReLU ->
    Conv2d(1x1)``, all mapping ``chan`` channels to ``chan`` channels. The
    3x3 convolutions use ``padding = 1`` so the branch output has the same
    shape as the input and can be added to it directly.

    Args:
        chan: number of channels of the input (and output) feature map.
    """

    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        """Return the residual branch output plus the unmodified input."""
        # identity skip connection: no projection is needed because the
        # branch keeps the shape of x unchanged
        return self.net(x) + x

class DiscreteVAE(nn.Module):
def __init__(
Expand Down Expand Up @@ -115,9 +98,13 @@ def __init__(
dec_layers = []

for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(ResBlock(enc_in, enc_out, num_blocks = num_resnet_blocks))
dec_layers.append(ResBlock(dec_in, dec_out, num_blocks = num_resnet_blocks, upsample = True))

enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

for _ in range(num_resnet_blocks):
enc_layers.append(ResBlock(enc_chans[-1]))
dec_layers.append(ResBlock(dec_chans[-1]))

enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.33',
version = '0.0.34',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b96aff5

Please sign in to comment.