From 907737fdeba55349393d9312cff14256289de4f9 Mon Sep 17 00:00:00 2001
From: "r.beaumont"
Date: Fri, 11 Jun 2021 09:48:44 +0000
Subject: [PATCH] add vqgan_path and vqgan_config_path parameters for custom
 vqgan support

load the vqgan model from the provided path and config when not None
---
 README.md                      |  6 ++++--
 dalle_pytorch/__init__.py      |  2 +-
 dalle_pytorch/dalle_pytorch.py |  5 ++---
 dalle_pytorch/vae.py           | 33 ++++++++++++++++++++-------------
 generate.py                    | 10 ++++++++--
 train_dalle.py                 | 22 +++++++++++++++++-----
 6 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 4372c2d2..aeb2db9f 100644
--- a/README.md
+++ b/README.md
@@ -162,13 +162,15 @@ In contrast to OpenAI's VAE, it also has an extra layer of downsampling, so the
 Update - it works!
 
 ```python
-from dalle_pytorch import VQGanVAE1024
+from dalle_pytorch import VQGanVAE
 
-vae = VQGanVAE1024()
+vae = VQGanVAE()
 
 # the rest is the same as the above example
 ```
 
+The default VQGAN is the one with codebook size 1024, trained on ImageNet. If you wish to use a different one, pass the .ckpt file via `vqgan_model_path` and the .yaml file via `vqgan_config_path`. These options can be used both in the train_dalle.py script and as arguments to the `VQGanVAE` class. Other pretrained VQGANs are listed in the [taming transformers readme](https://github.com/CompVis/taming-transformers#overview-of-pretrained-models). If you want to train a custom one, you can [follow this guide](https://github.com/CompVis/taming-transformers/pull/54).
+
 ## Ranking the generations
 
 Train CLIP
diff --git a/dalle_pytorch/__init__.py b/dalle_pytorch/__init__.py
index 88d77499..c242583d 100644
--- a/dalle_pytorch/__init__.py
+++ b/dalle_pytorch/__init__.py
@@ -1,2 +1,2 @@
 from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
-from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE1024
+from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index ab257871..135c2a01 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -7,8 +7,7 @@ from einops import rearrange
 
 from dalle_pytorch import distributed_utils
-from dalle_pytorch.vae import OpenAIDiscreteVAE
-from dalle_pytorch.vae import VQGanVAE1024
+from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 from dalle_pytorch.transformer import Transformer, DivideMax
 
 # helpers
 
@@ -325,7 +324,7 @@ def __init__(
         stable = False
     ):
         super().__init__()
-        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'
+        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
 
         image_size = vae.image_size
         num_image_tokens = vae.num_tokens
diff --git a/dalle_pytorch/vae.py b/dalle_pytorch/vae.py
index fa58a348..09b88d23 100644
--- a/dalle_pytorch/vae.py
+++ b/dalle_pytorch/vae.py
@@ -10,7 +10,7 @@
 import yaml
 from pathlib import Path
 from tqdm import tqdm
-from math import sqrt
+from math import sqrt, log
 from omegaconf import OmegaConf
 from taming.models.vqgan import VQModel
 
@@ -129,27 +129,34 @@ def forward(self, img):
 # VQGAN from Taming Transformers paper
 # https://arxiv.org/abs/2012.09841
 
-class VQGanVAE1024(nn.Module):
-    def __init__(self):
+class VQGanVAE(nn.Module):
+    def __init__(self, vqgan_model_path = None, vqgan_config_path = None):
         super().__init__()
 
-        model_filename = 'vqgan.1024.model.ckpt'
-        config_filename = 'vqgan.1024.config.yml'
-
-        download(VQGAN_VAE_CONFIG_PATH, config_filename)
-        download(VQGAN_VAE_PATH, model_filename)
-
-        config = OmegaConf.load(str(Path(CACHE_PATH) / config_filename))
+        if vqgan_model_path is None:
+            model_filename = 'vqgan.1024.model.ckpt'
+            config_filename = 'vqgan.1024.config.yml'
+            download(VQGAN_VAE_CONFIG_PATH, config_filename)
+            download(VQGAN_VAE_PATH, model_filename)
+            config_path = str(Path(CACHE_PATH) / config_filename)
+            model_path = str(Path(CACHE_PATH) / model_filename)
+        else:
+            model_path = vqgan_model_path
+            config_path = vqgan_config_path
+
+        config = OmegaConf.load(config_path)
 
         model = VQModel(**config.model.params)
 
-        state = torch.load(str(Path(CACHE_PATH) / model_filename), map_location = 'cpu')['state_dict']
+        state = torch.load(model_path, map_location = 'cpu')['state_dict']
         model.load_state_dict(state, strict = False)
 
+        print(f"Loaded VQGAN from {model_path} and {config_path}")
+
         self.model = model
 
-        self.num_layers = 4
+        self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
         self.image_size = 256
-        self.num_tokens = 1024
+        self.num_tokens = config.model.params.n_embed
 
         self._register_external_parameters()
diff --git a/generate.py b/generate.py
index 3d5c9ac0..1e606085 100644
--- a/generate.py
+++ b/generate.py
@@ -15,7 +15,7 @@
 
 # dalle related classes and utils
 
-from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
+from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
 from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
 
 # argument parsing
@@ -25,6 +25,12 @@
 parser.add_argument('--dalle_path', type = str, required = True,
                     help='path to your trained DALL-E')
 
+parser.add_argument('--vqgan_model_path', type=str, default = None,
+                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
+
+parser.add_argument('--vqgan_config_path', type=str, default = None,
+                    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
+
 parser.add_argument('--text', type = str, required = True,
                     help='your text prompt')
 
@@ -80,7 +86,7 @@ def exists(val):
 elif not args.taming:
     vae = OpenAIDiscreteVAE()
 else:
-    vae = VQGanVAE1024()
+    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
 
 dalle = DALLE(vae = vae, **dalle_params).cuda()
 
diff --git a/train_dalle.py b/train_dalle.py
index 0fd28d2a..41f41bd1 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader
 
-from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
+from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
 from dalle_pytorch import distributed_utils
 from dalle_pytorch.loader import TextImageDataset
 from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
 
@@ -29,6 +29,12 @@
 group.add_argument('--dalle_path', type=str,
                    help='path to your partially trained DALL-E')
 
+parser.add_argument('--vqgan_model_path', type=str, default = None,
+                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
+
+parser.add_argument('--vqgan_config_path', type=str, default = None,
+                    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
+
 parser.add_argument('--image_text_folder', type=str, required=True,
                     help='path to your folder of images and text for learning the DALL-E')
 
@@ -132,6 +138,8 @@ def cp_path_to_dir(cp_path, tag):
 DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"
 
 VAE_PATH = args.vae_path
+VQGAN_MODEL_PATH = args.vqgan_model_path
+VQGAN_CONFIG_PATH = args.vqgan_config_path
 DALLE_PATH = args.dalle_path
 RESUME = exists(DALLE_PATH)
 
@@ -190,8 +198,10 @@ def cp_path_to_dir(cp_path, tag):
     if vae_params is not None:
         vae = DiscreteVAE(**vae_params)
     else:
-        vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
-        vae = vae_klass()
+        if args.taming:
+            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+        else:
+            vae = OpenAIDiscreteVAE()
 
     dalle_params = dict(
         **dalle_params
@@ -218,8 +228,10 @@ def cp_path_to_dir(cp_path, tag):
         print('using pretrained VAE for encoding images to tokens')
         vae_params = None
 
-        vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
-        vae = vae_klass()
+        if args.taming:
+            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+        else:
+            vae = OpenAIDiscreteVAE()
 
     IMAGE_SIZE = vae.image_size
 
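
As a quick check of the interface added above, here is a minimal usage sketch of the new `VQGanVAE` class together with `DALLE`. The checkpoint and config paths are hypothetical placeholders (any taming-transformers .ckpt/.yaml pair should work), and calling `VQGanVAE()` with no arguments falls back to downloading the ImageNet VQGAN with codebook size 1024.

```python
from dalle_pytorch import DALLE, VQGanVAE

# custom VQGAN: point at a taming-transformers checkpoint and its config
# ('last.ckpt' and 'model.yaml' are placeholder paths, not files shipped with the repo)
vae = VQGanVAE(
    vqgan_model_path = 'logs/my_vqgan/checkpoints/last.ckpt',
    vqgan_config_path = 'logs/my_vqgan/configs/model.yaml'
)

# default behaviour: vae = VQGanVAE() downloads the pretrained ImageNet VQGAN (codebook size 1024)

dalle = DALLE(
    dim = 512,
    vae = vae,              # image size and number of image tokens are read from the vae
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 8
)
```

For training and generation, the same files can be passed on the command line, e.g. `python train_dalle.py --taming --vqgan_model_path last.ckpt --vqgan_config_path model.yaml --image_text_folder ./data` (paths again placeholders; the remaining flags follow the existing scripts).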