From 54c15c65ef8e71217a3890722641bf055a6b4c7a Mon Sep 17 00:00:00 2001 From: Clay Mullis Date: Tue, 30 Nov 2021 13:58:25 -0600 Subject: [PATCH 1/4] (fp16/deepspeed w generations) Cast submodules to half except vqgan. --- train_dalle.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) mode change 100644 => 100755 train_dalle.py diff --git a/train_dalle.py b/train_dalle.py old mode 100644 new mode 100755 index 7f07079c..3d66cf31 --- a/train_dalle.py +++ b/train_dalle.py @@ -411,10 +411,12 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available # initialize DALL-E dalle = DALLE(vae=vae, **dalle_params) -if not using_deepspeed: - if args.fp16: - dalle = dalle.half() - dalle = dalle.cuda() +if args.fp16: + dalle.vae.requires_grad_(False).float() + for layer in dalle.modules(): + if not isinstance(layer, VQGanVAE): # VQGanVAE is not FP16 compatible + layer.half() +dalle = dalle.cuda() if RESUME and not using_deepspeed: dalle.load_state_dict(weights) @@ -505,7 +507,6 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available # Prefer scheduler in `deepspeed_config`. if LR_DECAY and distr_scheduler is None: distr_scheduler = scheduler -avoid_model_calls = using_deepspeed and args.fp16 if RESUME and using_deepspeed: distr_dalle.load_checkpoint(str(cp_dir)) @@ -607,16 +608,11 @@ def save_model(path, epoch=0): token_list = sample_text.masked_select(sample_text != 0).tolist() decoded_text = tokenizer.decode(token_list) - if not avoid_model_calls: - # CUDA index errors when we don't guard this - image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9 - - + image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9 log = { **log, } - if not avoid_model_calls: - log['image'] = wandb.Image(image, caption=decoded_text) + log['image'] = wandb.Image(image, caption=decoded_text) if i % 10 == 9 and distr_backend.is_root_worker(): sample_per_sec = BATCH_SIZE * 10 / (time.time() - t) From 31fa1af78bd22e60491823b156099134a32934a0 Mon Sep 17 00:00:00 2001 From: Clay Mullis Date: Tue, 30 Nov 2021 13:59:36 -0600 Subject: [PATCH 2/4] Autocast VQGAN calls to fp32 in fp16 training --- dalle_pytorch/vae.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/dalle_pytorch/vae.py b/dalle_pytorch/vae.py index 6b925c92..a7ce70c7 100644 --- a/dalle_pytorch/vae.py +++ b/dalle_pytorch/vae.py @@ -1,12 +1,5 @@ -import io -import sys import os -import requests -import PIL -import warnings -import hashlib import urllib -import yaml from pathlib import Path from tqdm import tqdm from math import sqrt, log @@ -14,9 +7,11 @@ from taming.models.vqgan import VQModel, GumbelVQ import importlib + import torch from torch import nn import torch.nn.functional as F +from torch.cuda.amp.autocast_mode import autocast from einops import rearrange @@ -196,6 +191,7 @@ def _register_external_parameters(self): self, self.model.quantize.embed.weight if self.is_gumbel else self.model.quantize.embedding.weight) @torch.no_grad() + @autocast(enabled=True, dtype=torch.float32, cache_enabled=True) def get_codebook_indices(self, img): b = img.shape[0] img = (2 * img) - 1 @@ -204,6 +200,7 @@ def get_codebook_indices(self, img): return rearrange(indices, 'b h w -> b (h w)', b=b) return rearrange(indices, '(b n) -> b n', b = b) + @autocast(enabled=True, dtype=torch.float32, cache_enabled=True) def decode(self, img_seq): b, n = img_seq.shape one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float() From 5d843fa5f5b394c9d4e74388b1d5c216bf115c28 Mon Sep 17 00:00:00 2001 From: Clay Mullis Date: Tue, 30 Nov 2021 15:14:29 -0600 Subject: [PATCH 3/4] enable pytorch's built-in automatic mixed precision by default --- dalle_pytorch/dalle_pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index fe294465..f416e23b 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -508,6 +508,8 @@ def generate_images( return images + + @torch.autocast(device_type="cuda", enabled=True) def forward( self, text, From 43ea4a3ced12e9c013b5e7cdae7ba8c33ea3e61c Mon Sep 17 00:00:00 2001 From: Clay Mullis Date: Tue, 30 Nov 2021 16:11:14 -0600 Subject: [PATCH 4/4] remove duplicate code --- train_dalle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_dalle.py b/train_dalle.py index 3d66cf31..fd9c35ca 100755 --- a/train_dalle.py +++ b/train_dalle.py @@ -412,7 +412,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available dalle = DALLE(vae=vae, **dalle_params) if args.fp16: - dalle.vae.requires_grad_(False).float() + dalle.vae.float() for layer in dalle.modules(): if not isinstance(layer, VQGanVAE): # VQGanVAE is not FP16 compatible layer.half()