From 40c13fe5b41872d99dde75bf32f5c31626cf0043 Mon Sep 17 00:00:00 2001
From: Anand Kumar <63339285+AnandK27@users.noreply.github.com>
Date: Thu, 29 Aug 2024 01:53:36 -0700
Subject: [PATCH] [train_custom_diffusion.py] Fix the LR schedulers when
 `num_train_epochs` is passed in a distributed training env (#9308)

* Update train_custom_diffusion.py to fix the LR schedulers for `num_train_epochs`

* Fix saving text embeddings during safe serialization

* Fixed formatting
---
 .../train_custom_diffusion.py                 | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 8dddcd0ca706..e498ca98b1c7 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di
     for x, y in zip(modifier_token_id, args.modifier_token):
         learned_embeds_dict = {}
         learned_embeds_dict[y] = learned_embeds[x]
-        filename = f"{output_dir}/{y}.bin"

         if safe_serialization:
+            filename = f"{output_dir}/{y}.safetensors"
             safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
         else:
+            filename = f"{output_dir}/{y}.bin"
             torch.save(learned_embeds_dict, filename)


@@ -1040,17 +1041,22 @@ def main(args):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )

     # Prepare everything with our `accelerator`.
@@ -1065,8 +1071,14 @@ def main(args):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
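
For context on the step arithmetic: when the LR scheduler is passed through accelerator.prepare, each call to lr_scheduler.step() advances the underlying scheduler once per process, which is why both the warmup and total step counts handed to get_scheduler are scaled by accelerator.num_processes. Below is a minimal sketch of the computation the patched branch performs when --max_train_steps is not passed; all concrete values (num_processes = 4, a 1000-batch dataloader, and so on) are illustrative assumptions, not values from the patch.

    import math

    # Illustrative assumptions only -- these values are not from the patch.
    num_processes = 4                 # e.g. 4 GPUs under `accelerate`
    len_train_dataloader = 1000       # batches before accelerate shards the dataloader
    gradient_accumulation_steps = 2
    num_train_epochs = 3
    lr_warmup_steps = 100

    # Mirrors the patched branch taken when --max_train_steps is None:
    len_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 250
    num_update_steps_per_epoch = math.ceil(
        len_after_sharding / gradient_accumulation_steps
    )  # 125
    num_training_steps_for_scheduler = (
        num_train_epochs * num_update_steps_per_epoch * num_processes
    )
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

    print(num_training_steps_for_scheduler)  # 1500 (375 optimizer steps x 4 scheduler ticks)
    print(num_warmup_steps_for_scheduler)    # 400

The old code computed num_update_steps_per_epoch from the pre-sharding dataloader length, so with N processes the scheduler's horizon was roughly N times too long and the learning rate never finished its decay; the new warning after accelerator.prepare flags any remaining mismatch between the anticipated and actual dataloader lengths.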