Merge pull request #58 from harubaru/text-encoder-updates
FEAT: Text Encoder (CLIP) Training
harubaru authored Dec 3, 2022
2 parents c709257 + 29e7df5 commit 27d301c
Showing 1 changed file with 58 additions and 21 deletions.
79 changes: 58 additions & 21 deletions trainer/diffusers_trainer.py
@@ -92,6 +92,8 @@
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")


args = parser.parse_args()

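For context, the new flag uses `action="store_true"`, so it defaults to `False` and flips to `True` when passed, unlike the surrounding options that go through the custom `bool_t` type. A minimal standalone sketch (not from the repository) of that behaviour:

```python
# Standalone sketch: how the new --train_text_encoder flag parses.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_text_encoder", action="store_true",
                    help="Whether to train the text encoder")

print(parser.parse_args([]).train_text_encoder)                        # False
print(parser.parse_args(["--train_text_encoder"]).train_text_encoder)  # True
```
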
@@ -498,6 +500,9 @@ def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CL
self.device = device
self.ucg = ucg

if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
self.text_encoder = self.text_encoder.module

self.transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
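The unwrap added at the top of this hunk is needed because DistributedDataParallel proxies `forward()` but not the wrapped model's other attributes and methods, which the dataset relies on for encoding captions. A single-process sketch of the same check, with a toy module standing in for the CLIP text encoder:

```python
# Single-process sketch of the DDP unwrap; a Linear layer stands in for the
# CLIP text encoder, and a one-rank "gloo" group lets DDP be built on CPU.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
wrapped = torch.nn.parallel.DistributedDataParallel(model)

# DDP exposes forward/__call__ but not arbitrary attributes, so code that needs
# the underlying model reaches through .module, exactly as the dataset does.
text_encoder = wrapped.module if type(wrapped) is torch.nn.parallel.DistributedDataParallel else wrapped
assert text_encoder is model

dist.destroy_process_group()
```
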
@@ -731,11 +736,14 @@ def main():

# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()

if args.use_xformers:
unet.set_use_memory_efficient_attention_xformers(True)

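With this change the text encoder is only frozen when it is not being trained, and gradient checkpointing is extended to it through the Transformers `gradient_checkpointing_enable()` method. A toy sketch of the conditional freeze (stand-in modules, not the real VAE/CLIP):

```python
# Toy sketch of the conditional freeze; Linear layers stand in for the real models.
import torch

vae = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
train_text_encoder = True  # i.e. --train_text_encoder was passed

vae.requires_grad_(False)               # the VAE is always frozen
if not train_text_encoder:
    text_encoder.requires_grad_(False)  # frozen only when not training it

print(any(p.requires_grad for p in text_encoder.parameters()))  # True here
```
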
@@ -745,15 +753,23 @@ def main():
# move models to device
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=weight_dtype)
text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)

unet = torch.nn.parallel.DistributedDataParallel(
unet,
device_ids=[rank],
output_device=rank,
gradient_as_bucket_view=True
)


if args.train_text_encoder:
text_encoder = torch.nn.parallel.DistributedDataParallel(
text_encoder,
device_ids=[rank],
output_device=rank,
gradient_as_bucket_view=True
)

if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
try:
import bitsandbytes as bnb
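Two things happen in this hunk: a trained text encoder stays in float32 (a frozen one may be cast to `weight_dtype`, e.g. fp16, to save memory), and it gets its own DistributedDataParallel wrapper alongside the UNet. A toy sketch of the dtype choice, with a stand-in module:

```python
# Toy sketch of the dtype placement; a Linear layer stands in for the text encoder.
import torch

weight_dtype = torch.float16   # what a frozen model would be cast to under fp16
train_text_encoder = True

text_encoder = torch.nn.Linear(8, 8)
text_encoder = text_encoder.to(
    dtype=weight_dtype if not train_text_encoder else torch.float32)

# Trainable weights stay fp32 so the optimizer updates full-precision parameters;
# half-precision compute comes from torch.autocast in the training step instead.
print(next(text_encoder.parameters()).dtype)  # torch.float32 here
```
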
@@ -774,10 +790,12 @@ def main():
)
"""

optimizer_parameters = unet.parameters() if not args.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters())

# Create distributed optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(
unet.parameters(),
optimizer_parameters,
optimizer_class=optimizer_cls,
parameters_as_bucket_view=True,
lr=args.lr,
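The `itertools.chain` above is what routes both models' parameters into the single ZeroRedundancyOptimizer. A CPU-runnable sketch of the same idea, with plain AdamW standing in for the distributed optimizer so no process group is required (toy modules, not the real UNet/CLIP):

```python
# Sketch of the joint parameter group; AdamW stands in for ZeroRedundancyOptimizer.
import itertools
import torch

unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
train_text_encoder = True

optimizer_parameters = unet.parameters() if not train_text_encoder else itertools.chain(
    unet.parameters(), text_encoder.parameters())

optimizer = torch.optim.AdamW(optimizer_parameters, lr=5e-6)
print(len(optimizer.param_groups[0]["params"]))  # 4: weight and bias of both modules
```
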
@@ -847,7 +865,7 @@ def save_checkpoint(global_step):
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
vae=vae,
unet=unet.module,
tokenizer=tokenizer,
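Reaching through `.module` here matters because `StableDiffusionPipeline` expects the plain `CLIPTextModel`: the DDP wrapper does not provide `save_pretrained`, and its `state_dict` keys carry a `module.` prefix. A tiny sketch of the same check written as a helper (the `unwrap` name is hypothetical, toy module):

```python
# Toy sketch: the inline type check above, as a small helper (name is hypothetical).
import torch

def unwrap(m):
    return m.module if type(m) is torch.nn.parallel.DistributedDataParallel else m

model = torch.nn.Linear(4, 4)
print(unwrap(model) is model)  # True: non-DDP models pass through untouched
```
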
@@ -866,6 +884,8 @@ def save_checkpoint(global_step):
loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
for epoch in range(args.epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for _, batch in enumerate(train_dataloader):
if args.resume and global_step < target_global_step:
if rank == 0:
@@ -898,20 +918,37 @@ def save_checkpoint(global_step):
else:
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")

with unet.join():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# backprop and update
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
optimizer.zero_grad()
if not args.train_text_encoder:
with unet.join():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# backprop and update
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
optimizer.zero_grad()
else:
with unet.join(), text_encoder.join():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# backprop and update
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
optimizer.zero_grad()

# Update EMA
if args.use_ema:
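The training step now branches: with `--train_text_encoder`, both DDP-wrapped models enter `join()`, the loss backpropagates into the text encoder as well, and both models are gradient-clipped before the shared optimizer step. A condensed CPU sketch of that joint update, using toy modules and random data and omitting the AMP/DDP machinery:

```python
# Condensed sketch of the joint UNet + text-encoder update; Linear layers and
# random tensors stand in for the real models and batch.
import itertools
import torch

unet = torch.nn.Linear(16, 16)
text_encoder = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()), lr=5e-6)

latents = torch.randn(4, 16)
target = torch.randn(4, 16)

encoder_hidden_states = text_encoder(torch.randn(4, 16))  # text conditioning
noise_pred = unet(latents + encoder_hidden_states)        # stand-in "denoising"
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

loss.backward()
# Both models are clipped, then updated by the one shared optimizer.
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
```
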
@@ -960,7 +997,7 @@ def save_checkpoint(global_step):
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)

pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
vae=vae,
unet=unet.module,
tokenizer=tokenizer,