From 18ff256be55df0e2739eefa77def03483ac97cce Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 12:47:40 +0900 Subject: [PATCH 1/6] Implement Text Encoder Training --- trainer/diffusers_trainer.py | 71 +++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b13056178..452f4620e 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -92,6 +92,8 @@ parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with _cropped.') parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.') +parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + args = parser.parse_args() @@ -731,11 +733,15 @@ def main(): # Freeze vae and text_encoder vae.requires_grad_(False) - text_encoder.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - + + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + if args.use_xformers: unet.set_use_memory_efficient_attention_xformers(True) @@ -753,7 +759,15 @@ def main(): output_device=rank, gradient_as_bucket_view=True ) - + + if args.train_text_encoder: + text_encoder = torch.nn.parallel.DistributedDataParallel( + text_encoder, + device_ids=[rank], + output_device=rank, + gradient_as_bucket_view=True + ) + if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. try: import bitsandbytes as bnb @@ -774,10 +788,12 @@ def main(): ) """ + optimizer_parameters = unet.parameters() if not args.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters()) + # Create distributed optimizer from torch.distributed.optim import ZeroRedundancyOptimizer optimizer = ZeroRedundancyOptimizer( - unet.parameters(), + optimizer_parameters, optimizer_class=optimizer_cls, parameters_as_bucket_view=True, lr=args.lr, @@ -866,6 +882,8 @@ def save_checkpoint(global_step): loss = torch.tensor(0.0, device=device, dtype=weight_dtype) for epoch in range(args.epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() for _, batch in enumerate(train_dataloader): if args.resume and global_step < target_global_step: if rank == 0: @@ -898,20 +916,37 @@ def save_checkpoint(global_step): else: raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}") - with unet.join(): - # Predict the noise residual and compute loss - with torch.autocast('cuda', enabled=args.fp16): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - - # backprop and update - scaler.scale(loss).backward() - torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) - scaler.step(optimizer) - scaler.update() - lr_scheduler.step() - optimizer.zero_grad() + if not args.train_text_encoder: + with unet.join(): + # Predict the noise residual and compute loss + with torch.autocast('cuda', enabled=args.fp16): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + # backprop and update + scaler.scale(loss).backward() + torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) + scaler.step(optimizer) + scaler.update() + lr_scheduler.step() + optimizer.zero_grad() + else: + with unet.join(), text_encoder.join(): + # Predict the noise residual and compute loss + with torch.autocast('cuda', enabled=args.fp16): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + # backprop and update + scaler.scale(loss).backward() + torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) + torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0) + scaler.step(optimizer) + scaler.update() + lr_scheduler.step() + optimizer.zero_grad() # Update EMA if args.use_ema: From 3cefb57fc68f12459cfa95d36f5ccde7525e3f55 Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:42:50 +0900 Subject: [PATCH 2/6] fp32 Update --- trainer/diffusers_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 452f4620e..21f28e507 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -751,7 +751,7 @@ def main(): # move models to device vae = vae.to(device, dtype=weight_dtype) unet = unet.to(device, dtype=torch.float32) - text_encoder = text_encoder.to(device, dtype=weight_dtype) + text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32) unet = torch.nn.parallel.DistributedDataParallel( unet, From 34715bcc9726f89f81858273800431b9e6522cab Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:57:00 +0900 Subject: [PATCH 3/6] Access Underlying Model --- trainer/diffusers_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 21f28e507..64e4021ee 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -500,6 +500,9 @@ def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CL self.device = device self.ucg = ucg + if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel: + self.text_encoder = self.text_encoder.module + self.transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.ToTensor(), From bf264d0ff02230e202e6075fe7b1659338147364 Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 20:06:46 +0900 Subject: [PATCH 4/6] Update Save Checkpoint --- trainer/diffusers_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 64e4021ee..732a6b592 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -866,7 +866,7 @@ def save_checkpoint(global_step): ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionPipeline( - text_encoder=text_encoder, + text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module, vae=vae, unet=unet.module, tokenizer=tokenizer, From 31dd4f643379a81536b508dfa89434702aaa57fd Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 20:49:53 +0900 Subject: [PATCH 5/6] Update Samples --- trainer/diffusers_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 732a6b592..6c62ba255 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -998,7 +998,7 @@ def save_checkpoint(global_step): scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token) pipeline = StableDiffusionPipeline( - text_encoder=text_encoder, + text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module, vae=vae, unet=unet.module, tokenizer=tokenizer, From 29e7df519b4abe1457ffd90da63e8160af927ee2 Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 21:24:38 +0900 Subject: [PATCH 6/6] Fix Gradient Checkpointing --- trainer/diffusers_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 6c62ba255..0e08746f7 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -741,9 +741,8 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() if args.use_xformers: unet.set_use_memory_efficient_attention_xformers(True)