# Image mask example: "the exact same cat on the top as the sketch on the bottom"
# (GitHub wiki page, edited by afiaka87 on Apr 18, 2021 - 1 revision)
import argparse
from pathlib import Path

from tqdm import tqdm

# torch
import torch
from einops import repeat
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer
# Command-line interface. Options without an explicit `required` are optional,
# and boolean flags store into an attribute named after the flag itself.
parser = argparse.ArgumentParser()

# --- model & prompt ---------------------------------------------------------
parser.add_argument('--dalle_path', type=str, required=True, help='path to your trained DALL-E')
parser.add_argument('--text', type=str, required=True, help='your text prompt')

# --- sampling ---------------------------------------------------------------
parser.add_argument('--num_images', type=int, default=128, help='number of images')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
parser.add_argument('--top_k', type=float, default=0.9, help='top k filter threshold')

# --- output & tokenizer / VAE selection --------------------------------------
parser.add_argument('--outputs_dir', type=str, default='./outputs')
parser.add_argument('--bpe_path', type=str, help='path to your huggingface BPE json file')
parser.add_argument('--chinese', action='store_true')
parser.add_argument('--taming', action='store_true')

args = parser.parse_args()
def exists(val):
    """Return True when *val* is not None.

    Falsy-but-present values such as 0, '' or [] still count as existing;
    only None is treated as missing.
    """
    return val is not None
# Pick the text tokenizer. A HuggingFace BPE file (--bpe_path) takes precedence
# over --chinese; with neither option, the default `tokenizer` imported from
# dalle_pytorch.tokenizer is kept as-is.
if exists(args.bpe_path):
    tokenizer = HugTokenizer(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()
# load DALL-E
dalle_path = Path(args.dalle_path)
# Explicit check rather than `assert`, which is stripped when run with `python -O`.
if not dalle_path.exists():
    raise FileNotFoundError(f'trained DALL-E must exist: {dalle_path}')

# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load checkpoints from sources you trust.
load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')

# Drop any stored 'vae' entry: the VAE is rebuilt below and passed to DALLE
# explicitly as `vae=`, so leaving it in dalle_params would duplicate the keyword.
dalle_params.pop('vae', None)

# Rebuild the VAE: a DiscreteVAE from the saved hyperparameters when the
# checkpoint carries them, otherwise OpenAIDiscreteVAE, or VQGanVAE1024 with --taming.
if vae_params is not None:
    vae = DiscreteVAE(**vae_params)
elif not args.taming:
    vae = OpenAIDiscreteVAE()
else:
    vae = VQGanVAE1024()
# Priming image for the "image mask": it is passed to dalle.generate_images(img=...)
# during generation, presumably so sampling continues from this picture -- confirm
# against the dalle_pytorch version in use.
# NOTE(review): hard-coded path; consider exposing it as a CLI option.
img_path = "the_dog_picture.jpg"

# Crop to the resolution the loaded VAE reports via `image_size` rather than a
# fixed 256, so the priming tensor matches whatever VAE the checkpoint selected.
image_size = vae.image_size
tf = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB") if im.mode != "RGB" else im),
    transforms.RandomResizedCrop(image_size, scale=(0.95, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
])

# The `with` block closes the image file once the transform has read the pixels.
with Image.open(img_path) as pil_img:
    img = tf(pil_img).cuda()

# One copy of the priming image per sample in a generation batch.
imgs = img.repeat(args.batch_size, 1, 1, 1)

# Build the model once on the GPU and load the trained weights.
dalle = DALLE(vae = vae, **dalle_params).cuda()
dalle.load_state_dict(weights)
# generate images
# Each '|'-separated segment of --text is an independent prompt; surrounding
# whitespace is stripped and empty segments (e.g. from a trailing '|') skipped.
texts = [t.strip() for t in args.text.split('|')]
texts = [t for t in texts if t]

for prompt in tqdm(texts):
    # Tokenize this prompt (not the whole --text string) and repeat it once per requested image.
    text_tokens = tokenizer.tokenize([prompt], dalle.text_seq_len).cuda()
    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    outputs = []
    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {prompt}'):
        # The last chunk is smaller than batch_size when num_images is not a
        # multiple of it; trim the priming batch to the same number of rows
        # (assumes generate_images expects matching text/img batch sizes).
        chunk_imgs = imgs[:text_chunk.shape[0]]
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k, img = chunk_imgs)
        outputs.append(output)
    outputs = torch.cat(outputs)

    # save all images into a per-prompt directory; '/' is replaced so the prompt
    # cannot create nested paths, and the name is capped to stay within
    # common filesystem name-length limits.
    dir_name = prompt.replace(' ', '_').replace('/', '_')[:100]
    outputs_dir = Path(args.outputs_dir) / dir_name
    outputs_dir.mkdir(parents = True, exist_ok = True)
    for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
        save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
    print(f'created {args.num_images} images at "{str(outputs_dir)}"')