Commit
Fix saving text encoder weights and kohya weights in advanced dreambooth lora script (huggingface#8766)

* update

* update

* update
DN6 authored Jul 5, 2024
1 parent 0bab9d6 commit 85c4a32
Showing 2 changed files with 3 additions and 2 deletions.
@@ -1290,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
@@ -1981,7 +1982,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
- save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")

save_model_card(
model_id if not args.push_to_hub else repo_id,
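For illustration only (not part of the commit): a minimal sketch of what switching to Path(args.output_dir).name changes when the output directory is a nested path. The directory name lora-out/sdxl-run below is a made-up example standing in for args.output_dir.

    from pathlib import Path

    output_dir = "lora-out/sdxl-run"  # hypothetical nested output directory (stand-in for args.output_dir)

    # Old pattern: the full directory string is reused as the file stem, so the
    # target path points inside a directory that does not exist.
    old_path = f"{output_dir}/{output_dir}.safetensors"
    print(old_path)  # lora-out/sdxl-run/lora-out/sdxl-run.safetensors

    # Fixed pattern: only the final path component names the file, so the file
    # lands directly inside the existing output directory.
    new_path = f"{output_dir}/{Path(output_dir).name}.safetensors"
    print(new_path)  # lora-out/sdxl-run/sdxl-run.safetensors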
@@ -2425,7 +2425,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
- save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")

save_model_card(
model_id if not args.push_to_hub else repo_id,
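As a sketch of the Kohya export step these hunks adjust (assumptions: the helper functions are importable from diffusers.utils as in the training scripts, and lora-out/sdxl-run is a made-up output directory):

    from pathlib import Path

    from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
    from safetensors.torch import load_file, save_file

    output_dir = "lora-out/sdxl-run"  # stand-in for args.output_dir

    # Load the LoRA weights that were saved in the diffusers key layout.
    lora_state_dict = load_file(f"{output_dir}/pytorch_lora_weights.safetensors")

    # Convert diffusers-style keys to PEFT-style keys, then to the Kohya layout.
    peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
    kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)

    # Name the Kohya file after the final directory component (the fixed behavior),
    # rather than repeating the full directory string.
    save_file(kohya_state_dict, f"{output_dir}/{Path(output_dir).name}.safetensors")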
