Merge pull request #293 from rom1504/vqgan_custom
add vqgan_model_path and vqgan_config_path parameters for custom vqgan support
lucidrains authored Jun 15, 2021
2 parents cec0797 + 907737f commit 50fb971
Showing 6 changed files with 52 additions and 26 deletions.
README.md (6 changes: 4 additions & 2 deletions)
@@ -162,13 +162,15 @@ In contrast to OpenAI's VAE, it also has an extra layer of downsampling, so the
Update - <a href="https://github.com/lucidrains/DALLE-pytorch/discussions/131">it works!</a>

```python
-from dalle_pytorch import VQGanVAE1024
+from dalle_pytorch import VQGanVAE

-vae = VQGanVAE1024()
+vae = VQGanVAE()

# the rest is the same as the above example
```

+The default VQGAN is the one with codebook size 1024, trained on ImageNet. If you wish to use a different one, you can pass the .ckpt and .yaml files via the `vqgan_model_path` and `vqgan_config_path` options. These options can be used both in the train_dalle script and as arguments to the VQGanVAE class. Other pretrained VQGANs can be found in the [taming transformers readme](https://github.com/CompVis/taming-transformers#overview-of-pretrained-models). If you want to train a custom one, you can [follow this guide](https://github.com/CompVis/taming-transformers/pull/54).

## Ranking the generations

Train CLIP
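As a usage sketch of the new README paragraph above (the checkpoint and config paths below are hypothetical placeholders, and the constructor is assumed to default both path arguments to None):

```python
from dalle_pytorch import VQGanVAE

# default: downloads the codebook-size-1024 VQGAN trained on ImageNet
vae = VQGanVAE()

# custom: point at your own taming-transformers checkpoint and config
# (both paths are placeholders for wherever your files actually live)
vae = VQGanVAE(
    vqgan_model_path = './checkpoints/custom_vqgan.ckpt',
    vqgan_config_path = './configs/custom_vqgan.yaml'
)
```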
dalle_pytorch/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
-from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE1024
+from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
dalle_pytorch/dalle_pytorch.py (5 changes: 2 additions & 3 deletions)
@@ -7,8 +7,7 @@
from einops import rearrange

from dalle_pytorch import distributed_utils
-from dalle_pytorch.vae import OpenAIDiscreteVAE
-from dalle_pytorch.vae import VQGanVAE1024
+from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
from dalle_pytorch.transformer import Transformer, DivideMax

# helpers
@@ -325,7 +324,7 @@ def __init__(
        stable = False
    ):
        super().__init__()
-        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'
+        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE, or VQGanVAE'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
dalle_pytorch/vae.py (33 changes: 20 additions & 13 deletions)
@@ -10,7 +10,7 @@
import yaml
from pathlib import Path
from tqdm import tqdm
-from math import sqrt
+from math import sqrt, log
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

@@ -129,27 +129,34 @@ def forward(self, img):
# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841

-class VQGanVAE1024(nn.Module):
-    def __init__(self):
+class VQGanVAE(nn.Module):
+    def __init__(self, vqgan_model_path = None, vqgan_config_path = None):
        super().__init__()

-        model_filename = 'vqgan.1024.model.ckpt'
-        config_filename = 'vqgan.1024.config.yml'
-
-        download(VQGAN_VAE_CONFIG_PATH, config_filename)
-        download(VQGAN_VAE_PATH, model_filename)
-
-        config = OmegaConf.load(str(Path(CACHE_PATH) / config_filename))
+        if vqgan_model_path is None:
+            model_filename = 'vqgan.1024.model.ckpt'
+            config_filename = 'vqgan.1024.config.yml'
+            download(VQGAN_VAE_CONFIG_PATH, config_filename)
+            download(VQGAN_VAE_PATH, model_filename)
+            config_path = str(Path(CACHE_PATH) / config_filename)
+            model_path = str(Path(CACHE_PATH) / model_filename)
+        else:
+            model_path = vqgan_model_path
+            config_path = vqgan_config_path
+
+        config = OmegaConf.load(config_path)
        model = VQModel(**config.model.params)

-        state = torch.load(str(Path(CACHE_PATH) / model_filename), map_location = 'cpu')['state_dict']
+        state = torch.load(model_path, map_location = 'cpu')['state_dict']
        model.load_state_dict(state, strict = False)

+        print(f"Loaded VQGAN from {model_path} and {config_path}")
+
        self.model = model

-        self.num_layers = 4
+        self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0])/log(2))
        self.image_size = 256
-        self.num_tokens = 1024
+        self.num_tokens = config.model.params.n_embed

        self._register_external_parameters()
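A note on the derived hyperparameters in this hunk: `num_layers` is now computed from the config rather than hardcoded to 4, and `num_tokens` is read from `n_embed`. A minimal sanity check, assuming the default ImageNet VQGAN config, whose `ddconfig.attn_resolutions` is `[16]`:

```python
from math import log

# assumed value of config.model.params.ddconfig.attn_resolutions[0]
# for the default 1024-codebook ImageNet VQGAN
attn_resolution = 16

num_layers = int(log(attn_resolution) / log(2))
assert num_layers == 4  # recovers the previously hardcoded value
```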
generate.py (10 changes: 8 additions & 2 deletions)
@@ -15,7 +15,7 @@

# dalle related classes and utils

-from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
+from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

# argument parsing
@@ -25,6 +25,12 @@
parser.add_argument('--dalle_path', type = str, required = True,
                    help='path to your trained DALL-E')

+parser.add_argument('--vqgan_model_path', type=str, default = None,
+                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
+
+parser.add_argument('--vqgan_config_path', type=str, default = None,
+                    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
+
parser.add_argument('--text', type = str, required = True,
                    help='your text prompt')

@@ -80,7 +86,7 @@ def exists(val):
elif not args.taming:
    vae = OpenAIDiscreteVAE()
else:
-    vae = VQGanVAE1024()
+    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)


dalle = DALLE(vae = vae, **dalle_params).cuda()
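With these flags, generating from a DALL-E trained on a custom VQGAN might look like `python generate.py --dalle_path ./dalle.pt --taming --vqgan_model_path ./vqgan.ckpt --vqgan_config_path ./vqgan.yaml --text 'a sketch of a cat'` (all paths here are hypothetical); without the two VQGAN flags, the default ImageNet checkpoint is downloaded as before.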
train_dalle.py (22 changes: 17 additions & 5 deletions)
@@ -12,7 +12,7 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

-from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
+from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
from dalle_pytorch import distributed_utils
from dalle_pytorch.loader import TextImageDataset
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
@@ -29,6 +29,12 @@
group.add_argument('--dalle_path', type=str,
                   help='path to your partially trained DALL-E')

+parser.add_argument('--vqgan_model_path', type=str, default = None,
+                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
+
+parser.add_argument('--vqgan_config_path', type=str, default = None,
+                    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
+
parser.add_argument('--image_text_folder', type=str, required=True,
                    help='path to your folder of images and text for learning the DALL-E')

@@ -135,6 +141,8 @@ def cp_path_to_dir(cp_path, tag):
DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"

VAE_PATH = args.vae_path
+VQGAN_MODEL_PATH = args.vqgan_model_path
+VQGAN_CONFIG_PATH = args.vqgan_config_path
DALLE_PATH = args.dalle_path
RESUME = exists(DALLE_PATH)

@@ -195,8 +203,10 @@ def cp_path_to_dir(cp_path, tag):
    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    else:
-        vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
-        vae = vae_klass()
+        if args.taming:
+            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+        else:
+            vae = OpenAIDiscreteVAE()

dalle_params = dict(
**dalle_params
@@ -223,8 +233,10 @@ def cp_path_to_dir(cp_path, tag):
    print('using pretrained VAE for encoding images to tokens')
    vae_params = None

-    vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
-    vae = vae_klass()
+    if args.taming:
+        vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+    else:
+        vae = OpenAIDiscreteVAE()

IMAGE_SIZE = vae.image_size

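Training against a custom VQGAN might then be launched as `python train_dalle.py --taming --vqgan_model_path ./vqgan.ckpt --vqgan_config_path ./vqgan.yaml --image_text_folder ./data` (paths hypothetical); omitting the two new flags preserves the previous behavior of downloading the default ImageNet checkpoint.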
