
[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)
fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>
kashif and sayakpaul authored Dec 26, 2023
1 parent e0d8c91 commit 35b81ff
Showing 1 changed file with 9 additions and 1 deletion.
@@ -527,9 +527,17 @@ def deepspeed_zero_init_disabled_context_manager():
 
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
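
The change does two things: passing lora_alpha=args.rank keeps the effective LoRA scaling (lora_alpha / r) at 1 for any chosen rank instead of depending on PEFT's default lora_alpha, and the new fp16 branch upcasts only the trainable LoRA parameters to float32 so optimizer updates run in full precision while the frozen base weights stay in half precision. Below is a minimal, self-contained sketch of the same pattern using peft's get_peft_model on a toy module; TinyAttention, its layer names, the dim, and the rank are illustrative stand-ins, not code from the training script.

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy stand-in for the Wuerstchen prior: any nn.Module with Linear layers works.
class TinyAttention(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x) + self.to_v(x)

model = TinyAttention().to(torch.float16)  # simulate loading the base model in fp16

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,  # alpha == r keeps the LoRA scaling (alpha / r) at 1.0
    target_modules=["to_q", "to_k", "to_v"],
)
model = get_peft_model(model, lora_config)

# Upcast only the trainable (LoRA) parameters to fp32 so the optimizer accumulates
# in full precision; the frozen fp16 base weights are left alone. Recent PEFT
# versions may already do this for half-precision base models, in which case the
# loop is a harmless no-op.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.dtype)  # only lora_A / lora_B weights, now torch.float32

In the training script itself the equivalent loop runs over prior.parameters() only when args.mixed_precision == "fp16", so fp32 and bf16 runs are unaffected.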
