Training Issue (Implemented with SDXL Inpainting as base model) - Unable to obtain good outputs even when loss is converging #65
We also conducted experiments on SDXL, but SDXL does not have an official inpainting model. The SDXL inpainting model in Diffusers is defective, so if the refiner is not used, the final result is slightly worse than on SD1.5. However, it does not output only noise.
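For reference, a minimal sketch of the usual base-plus-refiner pattern from the Diffusers SDXL docs (`prompt`, `init_image`, and `mask_image` are placeholders; the checkpoint IDs are the public Hub ones):

```python
import torch
from diffusers import StableDiffusionXLInpaintPipeline

# Base SDXL inpainting checkpoint plus the refiner, following the usual
# Diffusers ensemble-of-experts pattern; `prompt`, `init_image`, and
# `mask_image` are placeholders.
base = StableDiffusionXLInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")

# Run the first 80% of denoising in the base model, then hand the latents
# to the refiner for the remaining steps.
latents = base(
    prompt=prompt, image=init_image, mask_image=mask_image,
    denoising_end=0.8, output_type="latent",
).images
result = refiner(
    prompt=prompt, image=latents, mask_image=mask_image,
    denoising_start=0.8,
).images[0]
```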
This is the inference code that I used, mostly taken from your `pipeline.py`. Could you please let me know if there are any obvious issues with it?

```python
with tqdm.tqdm(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Duplicate the latents for classifier-free guidance
        non_inpainting_latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        # Scale the model input for the current timestep
        non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(
            non_inpainting_latent_model_input, t
        )
        non_inpainting_latent_model_input = non_inpainting_latent_model_input.to(self.device)
        mask_latent_concat = mask_latent_concat.to(self.device)
        masked_latent_concat = masked_latent_concat.to(self.device)

        # Channel-concatenate latents, mask, and masked-image latents into the
        # 9-channel input the inpainting UNet expects
        inpainting_latent_model_input = torch.cat(
            [non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat],
            dim=1,
        )

        # Additional SDXL conditions for the UNet
        unet_added_conditions = {"time_ids": add_time_ids}  # add_time_ids: [1024, 1024, 0, 0, 1024, 1024]
        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})  # pooled_prompt_embeds: torch.zeros(1, 77, 1280)

        # Predict the noise residual with the UNet
        noise_pred = self.unet(
            inpainting_latent_model_input,
            t.to(self.device),
            encoder_hidden_states=None,  # FIXME
            added_cond_kwargs=unet_added_conditions,
            return_dict=False,
        )[0]

        # Classifier-free guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample x_t -> x_{t-1}
        latents = self.noise_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        # Update the progress bar
        if i == len(timesteps) - 1 or (
            (i + 1) > num_warmup_steps and (i + 1) % self.noise_scheduler.order == 0
        ):
            progress_bar.update()

# Split off the person half of the spatially concatenated latents and decode
latents_temp = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
latents_temp = (1 / self.vae.config.scaling_factor) * latents_temp
image = self.vae.decode(latents_temp.to(self.device, dtype=self.weight_dtype)).sample
decoded_images = (image / 2 + 0.5).clamp(0, 1)
```

A few things I would like to get verified are flagged inline above: the `add_time_ids` and `pooled_prompt_embeds` values, and the `encoder_hidden_states=None` FIXME.
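In case a comparison helps, here is a sketch of the conditioning shapes the stock Diffusers SDXL pipelines hand to the UNet (`unet`, `latent_input`, and `t` are placeholders; the zero tensors stand in for real text-encoder outputs):

```python
import torch

# Shapes the stock Diffusers SDXL pipelines pass for batch size 1; the zero
# tensors are illustrative stand-ins for real text-encoder outputs.
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]])  # (1, 6): original size, crop, target size
pooled_prompt_embeds = torch.zeros(1, 1280)   # (1, 1280): pooled output of text encoder 2
prompt_embeds = torch.zeros(1, 77, 2048)      # (1, 77, 2048): per-token states of both encoders

# The stock pipelines always pass the per-token embeddings as
# encoder_hidden_states; the added conditions carry the text_time path.
noise_pred = unet(
    latent_input,
    t,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
    return_dict=False,
)[0]
```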
Have you tried your training code on SD1.5?
I implemented the CatVTON approach with SDXL Inpainting as the base model, including DREAM. The loss curve looks good and drops to ~0.001 after several epochs. However, the resulting images are just noise in the shape of the person. I also tried applying noise to the "unmasked person + garment" condition instead of the masked one, and the results were slightly better, but still just noise. Apart from this, I also trained (i) the entire UNet and (ii) only the attention parameters; the results are provided below, along with a sketch of setup (ii).
Training curve (image omitted)

Approach based on the CatVTON paper: DREAM training + VITON-HD data + SDXL Inpainting (images omitted)
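For clarity, this is roughly how setup (ii) restricts training to the attention parameters (a sketch, assuming a diffusers `UNet2DConditionModel`; the learning rate is illustrative):

```python
import torch

# Sketch of setup (ii): train only the attention parameters of a diffusers
# UNet2DConditionModel. "attn1" names self-attention, "attn2" cross-attention;
# the CatVTON paper itself trains only the self-attention ("attn1") layers.
unet.requires_grad_(False)
for name, param in unet.named_parameters():
    if "attn1" in name or "attn2" in name:
        param.requires_grad = True

trainable = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5)  # learning rate is illustrative
```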
As you can see, the outputs are filled with noise even though the loss is converging. Can you please provide some insight into why this is occurring? Since the SDXL UNet isn't drastically different from that of SD1.5, I don't understand what's causing these issues. @Zheng-Chong, any feedback or suggestions are appreciated!
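For context on where the two UNets do differ, a quick config comparison (checkpoint IDs are illustrative; substitute whichever SD1.5 inpainting weights you use):

```python
from diffusers import UNet2DConditionModel

# Checkpoint IDs are illustrative stand-ins for the weights actually used.
sd15_unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="unet"
)
sdxl_unet = UNet2DConditionModel.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet"
)

print(sd15_unet.config.cross_attention_dim)  # 768  (single CLIP text encoder)
print(sdxl_unet.config.cross_attention_dim)  # 2048 (two text encoders, concatenated)
print(sd15_unet.config.addition_embed_type)  # None
print(sdxl_unet.config.addition_embed_type)  # "text_time" -> expects text_embeds + time_ids
```

So even if the backbone blocks look similar, the `text_time` added-condition path is an SDXL-only interface that SD1.5 training code never exercises.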