diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index fea145d0b1e3..cf558f082018 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -1290,6 +1290,7 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) + else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again @@ -1981,7 +1982,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) - save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") + save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors") save_model_card( model_id if not args.push_to_hub else repo_id, diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f29b2e0b5225..9d06ce6cba16 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2425,7 +2425,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) - save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") + save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors") save_model_card( model_id if not args.push_to_hub else repo_id,