From 35b81fffaea20cca3e870a834cecef7e52a7d1d9 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 26 Dec 2023 11:40:04 +0100
Subject: [PATCH] [Wuerstchen] fix fp16 training and correct lora args (#6245)

fix fp16 training

Co-authored-by: Sayak Paul
---
 .../text_to_image/train_text_to_image_lora_prior.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 1e67f05abe7a..f1f6b3215201 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -527,9 +527,17 @@ def deepspeed_zero_init_disabled_context_manager():
 
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
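
Note (not part of the patch): the upcast added above follows the common mixed-precision pattern for PEFT/LoRA fine-tuning, where frozen base weights stay in fp16 while only the trainable adapter parameters are promoted to fp32 so the optimizer accumulates in full precision. A minimal standalone sketch of that pattern; the helper name upcast_trainable_params and the toy model below are illustrative, not taken from the training script:

import torch
import torch.nn as nn


def upcast_trainable_params(model: nn.Module) -> None:
    # Promote only trainable (e.g. LoRA) parameters to float32; frozen
    # base weights keep their fp16 storage, preserving memory savings.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)


if __name__ == "__main__":
    # Illustrative stand-in for an fp16 prior with an injected trainable layer.
    base = nn.Linear(8, 8).half()
    base.weight.requires_grad_(False)
    base.bias.requires_grad_(False)
    adapter = nn.Linear(8, 8).half()  # plays the role of the LoRA weights
    model = nn.Sequential(base, adapter)
    upcast_trainable_params(model)
    print(base.weight.dtype, adapter.weight.dtype)  # torch.float16 torch.float32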