add a new image_size parameter in train_dalle and generate
VAE models can be used with a patch grid of any size.
For example, a model trained on a 16x16 grid of patches can still be used on a 32x32 grid,
which increases the image sequence length in DALLE from 256 to 1024.
rom1504 committed Jul 2, 2021
1 parent 01e402e commit e68a30f
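
A minimal sketch of the arithmetic behind that example, assuming (as dalle_pytorch does) that the image sequence length is the squared number of patches per side, where the VAE downsamples the image by a factor of 2 ** num_layers; num_layers = 4 here is an assumed value matching a factor-16 VAE such as VQGAN f16:

def image_seq_len(image_size, num_layers):
    # patches per side after the VAE's 2 ** num_layers downsampling
    patches_per_side = image_size // (2 ** num_layers)
    return patches_per_side ** 2

assert image_seq_len(256, 4) == 256   # 16x16 patch grid
assert image_seq_len(512, 4) == 1024  # 32x32 patch grid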
Showing 3 changed files with 21 additions and 18 deletions.
8 changes: 4 additions & 4 deletions dalle_pytorch/vae.py
@@ -96,14 +96,14 @@ def download(url, filename = None, root = CACHE_PATH):
 # pretrained Discrete VAE from OpenAI
 
 class OpenAIDiscreteVAE(nn.Module):
-    def __init__(self):
+    def __init__(self, image_size=256):
         super().__init__()
 
         self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
         self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
 
         self.num_layers = 3
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = 8192
 
     @torch.no_grad()
@@ -142,7 +142,7 @@ def instantiate_from_config(config):
     return get_obj_from_str(config["target"])(**config.get("params", dict()))
 
 class VQGanVAE(nn.Module):
-    def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
+    def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
         super().__init__()
 
         if vqgan_model_path is None:
@@ -170,7 +170,7 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
         # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
         f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
         self.num_layers = int(log(f)/log(2))
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = config.model.params.n_embed
         self.is_gumbel = isinstance(self.model, GumbelVQ)
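A hedged usage sketch of the new image_size parameter (the 512 value and checkpoint paths are illustrative placeholders; VQGanVAE falls back to a default pretrained checkpoint when the paths are left as None). Note that for a VQGAN f16 config (resolution 256, attn_resolutions [16]) the hunk above derives f = 16 and num_layers = 4, matching the sequence-length sketch earlier:

from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE

# reuse the pretrained OpenAI VAE on images larger than its 256x256 default
vae = OpenAIDiscreteVAE(image_size=512)

# same idea for a VQGAN checkpoint (placeholder paths)
vqgan = VQGanVAE(image_size=512,
                 vqgan_model_path='./vqgan.ckpt',
                 vqgan_config_path='./vqgan_config.yaml')
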
13 changes: 8 additions & 5 deletions generate.py
@@ -43,6 +43,9 @@
 parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                     help='top k filter threshold')
 
+parser.add_argument('--image_size', type = int, default = 256, required = False,
+                    help='image size')
+
 parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                     help='output directory')
 
@@ -81,12 +84,14 @@ def exists(val):
 
 dalle_params.pop('vae', None) # cleanup later
 
+IMAGE_SIZE = args.image_size
+
 if vae_params is not None:
-    vae = DiscreteVAE(**vae_params)
+    vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
 elif not args.taming:
-    vae = OpenAIDiscreteVAE()
+    vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 else:
-    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
+    vae = VQGanVAE(IMAGE_SIZE, args.vqgan_model_path, args.vqgan_config_path)
 
 
 dalle = DALLE(vae = vae, **dalle_params).cuda()
@@ -95,8 +100,6 @@ def exists(val):
 
 # generate images
 
-image_size = vae.image_size
-
 texts = args.text.split('|')
 
 for text in tqdm(texts):
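A hedged command-line sketch of the new generate.py flag (the checkpoint path and prompt are illustrative placeholders):

python generate.py --dalle_path ./dalle.pt --image_size 512 --text 'a cat playing chess'
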
18 changes: 9 additions & 9 deletions train_dalle.py
@@ -128,6 +128,8 @@
 
 model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
 
+model_group.add_argument('--image_size', default = 256, type = int, help = 'Image size')
+
 model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
 
 args = parser.parse_args()
@@ -173,6 +175,7 @@ def cp_path_to_dir(cp_path, tag):
 SAVE_EVERY_N_STEPS = args.save_every_n_steps
 KEEP_N_CHECKPOINTS = args.keep_n_checkpoints
 
+IMAGE_SIZE = args.image_size
 MODEL_DIM = args.dim
 TEXT_SEQ_LEN = args.text_seq_len
 DEPTH = args.depth
@@ -242,17 +245,16 @@ def cp_path_to_dir(cp_path, tag):
     scheduler_state = loaded_obj.get('scheduler_state')
 
     if vae_params is not None:
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
     else:
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         **dalle_params
     )
-    IMAGE_SIZE = vae.image_size
     resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
@@ -268,19 +270,17 @@ def cp_path_to_dir(cp_path, tag):
 
         vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
 
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
         vae.load_state_dict(weights)
     else:
         if distr_backend.is_root_worker():
            print('using pretrained VAE for encoding images to tokens')
         vae_params = None
 
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
-
-    IMAGE_SIZE = vae.image_size
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 
     dalle_params = dict(
         num_text_tokens=tokenizer.vocab_size,
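And the matching train_dalle.py invocation, again as a hedged sketch (the dataset path is a placeholder; --image_text_folder and --taming are pre-existing train_dalle.py flags not shown in this diff):

python train_dalle.py --taming --image_text_folder ./my_images_and_captions --image_size 512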
