diff --git a/dalle_pytorch/vae.py b/dalle_pytorch/vae.py
index 09b88d23..990a3fee 100644
--- a/dalle_pytorch/vae.py
+++ b/dalle_pytorch/vae.py
@@ -96,14 +96,14 @@ def download(url, filename = None, root = CACHE_PATH):
 # pretrained Discrete VAE from OpenAI
 
 class OpenAIDiscreteVAE(nn.Module):
-    def __init__(self):
+    def __init__(self, image_size=256):
         super().__init__()
 
         self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
         self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
 
         self.num_layers = 3
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = 8192
 
     @torch.no_grad()
@@ -130,7 +130,7 @@ def forward(self, img):
 # https://arxiv.org/abs/2012.09841
 
 class VQGanVAE(nn.Module):
-    def __init__(self, vqgan_model_path, vqgan_config_path):
+    def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
         super().__init__()
 
         if vqgan_model_path is None:
@@ -155,7 +155,7 @@ def __init__(self, vqgan_model_path, vqgan_config_path):
         self.model = model
 
         self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0])/log(2))
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = config.model.params.n_embed
 
         self._register_external_parameters()
diff --git a/generate.py b/generate.py
index 1e606085..7d4adb89 100644
--- a/generate.py
+++ b/generate.py
@@ -43,6 +43,9 @@
 parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                     help='top k filter threshold')
 
+parser.add_argument('--image_size', type = int, default = 256, required = False,
+                    help='image size')
+
 parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                     help='output directory')
 
@@ -81,12 +84,14 @@ def exists(val):
 
 dalle_params.pop('vae', None) # cleanup later
 
+IMAGE_SIZE = args.image_size
+
 if vae_params is not None:
-    vae = DiscreteVAE(**vae_params)
+    vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})
 elif not args.taming:
-    vae = OpenAIDiscreteVAE()
+    vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 else:
-    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
+    vae = VQGanVAE(IMAGE_SIZE, args.vqgan_model_path, args.vqgan_config_path)
 
 dalle = DALLE(vae = vae, **dalle_params).cuda()
 
@@ -95,8 +100,6 @@ def exists(val):
 
 # generate images
 
-image_size = vae.image_size
-
 texts = args.text.split('|')
 
 for text in tqdm(texts):
diff --git a/train_dalle.py b/train_dalle.py
index f4b12109..143fe0f7 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -108,6 +108,8 @@
 
 model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
 
+model_group.add_argument('--image_size', default = 256, type = int, help = 'Image size')
+
 model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
 
 args = parser.parse_args()
@@ -155,6 +157,7 @@ def cp_path_to_dir(cp_path, tag):
 SAVE_EVERY_N_STEPS = args.save_every_n_steps
 KEEP_N_CHECKPOINTS = args.keep_n_checkpoints
 
+IMAGE_SIZE = args.image_size
 MODEL_DIM = args.dim
 TEXT_SEQ_LEN = args.text_seq_len
 DEPTH = args.depth
@@ -202,17 +205,16 @@ def cp_path_to_dir(cp_path, tag):
     scheduler_state = loaded_obj.get('scheduler_state')
 
     if vae_params is not None:
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})
     else:
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         **dalle_params
     )
-    IMAGE_SIZE = vae.image_size
     resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
@@ -228,7 +230,7 @@ def cp_path_to_dir(cp_path, tag):
 
         vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
 
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})
         vae.load_state_dict(weights)
     else:
         if distr_backend.is_root_worker():
@@ -236,11 +238,9 @@ def cp_path_to_dir(cp_path, tag):
         vae_params = None
 
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
-
-    IMAGE_SIZE = vae.image_size
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         num_text_tokens=tokenizer.vocab_size,
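
For reference, a minimal usage sketch of what the reworked constructors accept once --image_size is parsed. This is not part of the patch; the hparams keys and values below are illustrative assumptions rather than values fixed by this change.

    from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE

    IMAGE_SIZE = 256  # hypothetically taken from args.image_size

    # pretrained OpenAI dVAE: the image size is now passed in explicitly
    vae = OpenAIDiscreteVAE(IMAGE_SIZE)

    # VQGAN: size first, then the optional checkpoint/config paths
    # (None falls back to the default pretrained VQGAN download)
    vae = VQGanVAE(IMAGE_SIZE, vqgan_model_path = None, vqgan_config_path = None)

    # custom DiscreteVAE restored from saved hparams, with the stored
    # image_size overridden by the command-line value
    vae_params = dict(num_layers = 3, num_tokens = 8192, codebook_dim = 512,
                      hidden_dim = 64, image_size = 128)  # illustrative hparams
    vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})

With the patch applied, the same value is supplied via the new flag, e.g. python train_dalle.py --image_size 256 alongside the existing arguments.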