Skip to content

Commit

Permalink
[train_custom_diffusion.py] Fix the LR schedulers when `num_train_epo…
Browse files Browse the repository at this point in the history
…chs` is passed in a distributed training env (huggingface#9308)

* Update train_custom_diffusion.py to fix the LR schedulers for `num_train_epochs`

* Fix saving text embeddings during safe serialization

* Fixed formatting
  • Loading branch information
AnandK27 authored Aug 29, 2024
1 parent 2a3fbc2 commit 40c13fe
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions examples/custom_diffusion/train_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di
for x, y in zip(modifier_token_id, args.modifier_token):
learned_embeds_dict = {}
learned_embeds_dict[y] = learned_embeds[x]
filename = f"{output_dir}/{y}.bin"

if safe_serialization:
filename = f"{output_dir}/{y}.safetensors"
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
else:
filename = f"{output_dir}/{y}.bin"
torch.save(learned_embeds_dict, filename)


Expand Down Expand Up @@ -1040,17 +1041,22 @@ def main(args):
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)

# Prepare everything with our `accelerator`.
Expand All @@ -1065,8 +1071,14 @@ def main(args):

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

Expand Down

0 comments on commit 40c13fe

Please sign in to comment.