add a new image_size parameter in train_dalle and generate
VAE models can be used with patch grids of any size.
For example, a model trained on a 16x16 grid of patches can still be used with a 32x32 grid,
which increases the image sequence length in DALL-E from 256 to 1024.
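
A quick sketch of the arithmetic behind that example (illustration only, not part of the commit): a VAE with num_layers downsampling stages turns an image of side image_size into a token grid of side image_size / 2**num_layers, and DALL-E's image sequence length is that side squared.

    # Illustration only: the sequence-length arithmetic implied by the commit message.
    # Assumes a VAE with `num_layers` 2x downsampling stages (3 for the OpenAI VAE).
    def image_seq_len(image_size, num_layers = 3):
        side = image_size // (2 ** num_layers)  # token-grid side length
        return side * side

    assert image_seq_len(128) == 256    # 16x16 grid
    assert image_seq_len(256) == 1024   # 32x32 grid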
rom1504 committed Jun 16, 2021
1 parent d6107cc commit 19f0aa5
Showing 3 changed files with 21 additions and 18 deletions.
8 changes: 4 additions & 4 deletions dalle_pytorch/vae.py
@@ -96,14 +96,14 @@ def download(url, filename = None, root = CACHE_PATH):
 # pretrained Discrete VAE from OpenAI
 
 class OpenAIDiscreteVAE(nn.Module):
-    def __init__(self):
+    def __init__(self, image_size=256):
         super().__init__()
 
         self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
         self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
 
         self.num_layers = 3
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = 8192
 
     @torch.no_grad()
@@ -130,7 +130,7 @@ def forward(self, img):
 # https://arxiv.org/abs/2012.09841
 
 class VQGanVAE(nn.Module):
-    def __init__(self, vqgan_model_path, vqgan_config_path):
+    def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
         super().__init__()
 
         if vqgan_model_path is None:
@@ -155,7 +155,7 @@ def __init__(self, vqgan_model_path, vqgan_config_path):
         self.model = model
 
         self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0])/log(2))
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = config.model.params.n_embed
 
         self._register_external_parameters()
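With this change both pretrained wrappers take the resolution from the caller while defaulting to 256, so existing call sites keep working. A minimal usage sketch (hypothetical size, assuming the patched classes above):

    from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE

    vae = OpenAIDiscreteVAE(image_size = 512)   # 512 / 2**3 = 64 -> 64x64 = 4096 image tokens
    # With no paths given, VQGanVAE presumably falls back to a default checkpoint:
    # vae = VQGanVAE(image_size = 512)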
13 changes: 8 additions & 5 deletions generate.py
@@ -43,6 +43,9 @@
 parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                     help='top k filter threshold')
 
+parser.add_argument('--image_size', type = int, default = 256, required = False,
+                    help='image size')
+
 parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                     help='output directory')
 
@@ -81,12 +84,14 @@ def exists(val):
 
 dalle_params.pop('vae', None) # cleanup later
 
+IMAGE_SIZE = args.image_size
+
 if vae_params is not None:
-    vae = DiscreteVAE(**vae_params)
+    vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # CLI size overrides the stored hparam
 elif not args.taming:
-    vae = OpenAIDiscreteVAE()
+    vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 else:
-    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
+    vae = VQGanVAE(IMAGE_SIZE, args.vqgan_model_path, args.vqgan_config_path)
 
 
 dalle = DALLE(vae = vae, **dalle_params).cuda()
@@ -95,8 +100,6 @@ def exists(val):
 
 # generate images
 
-image_size = vae.image_size
-
 texts = args.text.split('|')
 
 for text in tqdm(texts):
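In the DiscreteVAE branch above, vae_params is the hparams dict restored from the checkpoint, so the CLI value has to win over the stored image_size before unpacking. A minimal sketch of that merge idiom (hypothetical values):

    saved_hparams = {'image_size': 128, 'num_tokens': 8192}   # hypothetical checkpoint hparams
    IMAGE_SIZE = 256
    merged = {**saved_hparams, 'image_size': IMAGE_SIZE}      # right-most key wins
    assert merged == {'image_size': 256, 'num_tokens': 8192}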
18 changes: 9 additions & 9 deletions train_dalle.py
@@ -108,6 +108,8 @@
 
 model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
 
+model_group.add_argument('--image_size', default = 256, type = int, help = 'Image size')
+
 model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
 
 args = parser.parse_args()
@@ -155,6 +157,7 @@ def cp_path_to_dir(cp_path, tag):
 SAVE_EVERY_N_STEPS = args.save_every_n_steps
 KEEP_N_CHECKPOINTS = args.keep_n_checkpoints
 
+IMAGE_SIZE = args.image_size
 MODEL_DIM = args.dim
 TEXT_SEQ_LEN = args.text_seq_len
 DEPTH = args.depth
@@ -202,17 +205,16 @@ def cp_path_to_dir(cp_path, tag):
     scheduler_state = loaded_obj.get('scheduler_state')
 
     if vae_params is not None:
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # CLI size overrides the stored hparam
     else:
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         **dalle_params
     )
-    IMAGE_SIZE = vae.image_size
     resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
@@ -228,19 +230,17 @@ def cp_path_to_dir(cp_path, tag):
 
         vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
 
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # CLI size overrides the stored hparam
         vae.load_state_dict(weights)
     else:
         if distr_backend.is_root_worker():
             print('using pretrained VAE for encoding images to tokens')
         vae_params = None
 
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
-
-    IMAGE_SIZE = vae.image_size
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         num_text_tokens=tokenizer.vocab_size,
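Setting IMAGE_SIZE from the CLI before any VAE is constructed also keeps the data pipeline consistent: the images fed to the model must be resized to the same resolution the VAE is built for. A rough sketch of that dependency, with an assumed torchvision transform (the actual preprocessing lives elsewhere in train_dalle.py):

    from torchvision import transforms as T

    IMAGE_SIZE = 256  # args.image_size
    transform = T.Compose([      # assumed preprocessing, not part of this diff
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
    ])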
