From c977966502b70f4758c83ee5a855b48398042b03 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Mon, 26 Aug 2024 14:59:58 +0300
Subject: [PATCH] [Dreambooth flux] bug fix for dreambooth script (align with
 dreambooth lora) (#9257)

* fix shape

* fix prompt encoding

* style

* fix device

* add comment
---
 examples/dreambooth/train_dreambooth_flux.py | 130 ++++++++++---------
 1 file changed, 72 insertions(+), 58 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index ece12e289e0c..da571cc46c57 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -842,7 +842,7 @@ def __getitem__(self, index):
         return example
 
 
-def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -863,20 +863,26 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
     dtype = text_encoder.dtype
@@ -896,22 +902,28 @@ def _encode_prompt_with_clip(
     tokenizer,
     prompt: str,
     device=None,
+    text_input_ids=None,
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_overflowing_tokens=False,
-        return_length=False,
-        return_tensors="pt",
-    )
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
 
-    text_input_ids = text_inputs.input_ids
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
     # Use pooled output of CLIPTextModel
@@ -932,17 +944,19 @@ def encode_prompt(
     max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
+    text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
     dtype = text_encoders[0].dtype
-
+    device = device if device is not None else text_encoders[1].device
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
-        device=device if device is not None else text_encoders[0].device,
+        device=device,
         num_images_per_prompt=num_images_per_prompt,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
     prompt_embeds = _encode_prompt_with_t5(
@@ -951,7 +965,8 @@ def encode_prompt(
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        device=device if device is not None else text_encoders[1].device,
+        device=device,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
     text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
@@ -1499,7 +1514,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         )
                 else:
                     tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
-                    tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
+                    tokens_two = tokenize_prompt(
+                        tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                    )
+                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                        text_encoders=[text_encoder_one, text_encoder_two],
+                        tokenizers=[None, None],
+                        text_input_ids_list=[tokens_one, tokens_two],
+                        max_sequence_length=args.max_sequence_length,
+                        prompt=prompts,
+                    )
+            else:
+                if args.train_text_encoder:
+                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                        text_encoders=[text_encoder_one, text_encoder_two],
+                        tokenizers=[None, None],
+                        text_input_ids_list=[tokens_one, tokens_two],
+                        max_sequence_length=args.max_sequence_length,
+                        prompt=args.instance_prompt,
+                    )
 
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1553,41 +1586,22 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     guidance = None
 
                 # Predict the noise residual
-                if not args.train_text_encoder:
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-                else:
-                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two],
-                        tokenizers=None,
-                        prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two],
-                    )
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-
+                model_pred = transformer(
+                    hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
+                # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]),
-                    width=int(model_input.shape[3]),
+                    height=int(model_input.shape[2] * vae_scale_factor / 2),
+                    width=int(model_input.shape[3] * vae_scale_factor / 2),
                     vae_scale_factor=vae_scale_factor,
                 )
 
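
Note (not part of the patch): the height/width change in the final hunk exists because the pipeline's unpack helper expects pixel-space dimensions and divides them back down internally, so passing the raw latent shape under-sizes the result. Below is a minimal, self-contained sketch of that arithmetic. The pack_latents/unpack_latents functions are simplified stand-ins written for this note, not the library code, and vae_scale_factor = 16 assumes the script's 2 ** len(vae.config.block_out_channels) definition for the Flux VAE.

import torch


def pack_latents(latents):
    # (B, C, H, W) latents -> (B, (H/2)*(W/2), C*4): each 2x2 latent patch becomes one token
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)


def unpack_latents(latents, height, width, vae_scale_factor):
    # height/width are pixel-space sizes; they are divided back down to patch counts here
    b, num_patches, channels = latents.shape
    h, w = height // vae_scale_factor, width // vae_scale_factor
    latents = latents.view(b, h, w, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, channels // 4, h * 2, w * 2)


vae_scale_factor = 16  # assumed: 2 ** len(vae.config.block_out_channels) for the Flux VAE
model_input = torch.randn(1, 16, 64, 64)  # VAE latents for a 512x512 image

packed = pack_latents(model_input)  # -> (1, 32 * 32, 64)

# Old call: height=64 -> 64 // 16 = 4 patches per side, which cannot hold the 1024
# packed tokens (shape mismatch). Fixed call: 64 * 16 / 2 = 512 (pixel size)
# -> 512 // 16 = 32 patches per side -> (1, 16, 64, 64), matching model_input.
unpacked = unpack_latents(
    packed,
    height=int(model_input.shape[2] * vae_scale_factor / 2),
    width=int(model_input.shape[3] * vae_scale_factor / 2),
    vae_scale_factor=vae_scale_factor,
)
assert unpacked.shape == model_input.shape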