
Training Issue (Implemented with SDXL Inpainting as base model) - Unable to obtain good outputs even when loss is converging #65

Open
badhri-suresh opened this issue Oct 12, 2024 · 3 comments


@badhri-suresh

I implemented the CatVTON approach, including DREAM, with SDXL Inpainting as the base model. The loss curve looks good and drops to ~0.001 after several epochs. However, the resulting images are just noise in the shape of the person. I also tried applying noise to the "unmasked person + garment condition" instead of the masked one, and the results were slightly better, but still just noise. I also trained (i) the entire UNet and (ii) only the attention parameters; the results are shown below.
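The training described above uses DREAM. As a reminder of what that step does, here is a minimal, dependency-free sketch of the epsilon-prediction update, assuming it follows the DREAM paper and the `compute_dream_and_update_latents` helper in diffusers: an extra no-gradient forward pass estimates the noise, and the error against the true noise, scaled by lambda = sqrt(1 - alpha_bar_t) ** p, is folded into both the noisy input and the regression target. Plain Python lists stand in for tensors, and `dream_adjust` and the toy predictors are hypothetical names, not CatVTON code.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]  # toy stand-in for a latent tensor


def dream_adjust(
    noisy: Vector,
    noise: Vector,
    predict_noise: Callable[[Vector], Vector],
    alpha_bar_t: float,
    detail_preservation: float = 1.0,
) -> Tuple[Vector, Vector]:
    """Return DREAM-rectified (noisy input, target) for epsilon prediction.

    lambda  = sqrt(1 - alpha_bar_t) ** p
    delta   = lambda * (noise - predicted_noise)  # detached in a real trainer
    noisy'  = noisy + sqrt(1 - alpha_bar_t) * delta
    target' = noise + delta
    """
    sigma = math.sqrt(1.0 - alpha_bar_t)
    lam = sigma ** detail_preservation
    # Extra forward pass; in PyTorch this would run under torch.no_grad().
    predicted = predict_noise(noisy)
    delta = [lam * (n - p) for n, p in zip(noise, predicted)]
    new_noisy = [x + sigma * d for x, d in zip(noisy, delta)]
    new_target = [n + d for n, d in zip(noise, delta)]
    return new_noisy, new_target


if __name__ == "__main__":
    x0 = [0.2, -0.4, 1.0]
    noise = [0.5, -1.0, 0.25]
    alpha_bar = 0.64
    noisy = [math.sqrt(alpha_bar) * a + math.sqrt(1 - alpha_bar) * n for a, n in zip(x0, noise)]
    # A model that predicts zero noise gets a target amplified by (1 + lambda).
    print(dream_adjust(noisy, noise, lambda x: [0.0] * len(x), alpha_bar))
```

With a perfect noise predictor the correction vanishes and DREAM reduces to the standard epsilon objective; the larger the model's error, the more the training input is pushed toward what the model will actually see at sampling time.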

Training curve:

[Image: training loss curve]

Approach based on the CatVTON paper (DREAM training + VITON-HD data + SDXL Inpainting):

  1. Entire UNet trained

     [Image: outputs with all UNet parameters trained]

  2. Only attention parameters trained

     [Image: outputs with only attention parameters trained]

As you can see, the outputs are filled with noise even though the loss converges. Could you provide some insight into why this happens? Since the SDXL UNet isn't drastically different from the SD 1.5 UNet, I don't understand what's causing these issues. @Zheng-Chong Any feedback or suggestions are appreciated!

@Zheng-Chong
Owner

We also ran experiments on SDXL, but SDXL does not have an official inpainting model. The SDXL Inpainting model from Diffusers is flawed, so without the Refiner the final results are slightly worse than on SD 1.5. However, it does not output pure noise.
If your training is correct, there may be a problem in your inference code that prevents the results from being rendered correctly.

@badhri-suresh
Author

badhri-suresh commented Oct 14, 2024

This is the inference code that I used, mostly taken from your pipeline.py. Could you please let me know if there are any obvious issues with it?

```python
with tqdm.tqdm(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Duplicate the latents when using classifier-free guidance
        non_inpainting_latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # Prepare inputs for the inpainting model
        non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(
            non_inpainting_latent_model_input, t
        )
        non_inpainting_latent_model_input = non_inpainting_latent_model_input.to(self.device)
        mask_latent_concat = mask_latent_concat.to(self.device)
        masked_latent_concat = masked_latent_concat.to(self.device)
        inpainting_latent_model_input = torch.cat(
            [non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1
        )
        inpainting_latent_model_input = inpainting_latent_model_input.to(self.device)

        # Prepare additional conditions for the UNet
        unet_added_conditions = {"time_ids": add_time_ids}
        # add_time_ids : [1024,1024,0,0,1024,1024]
        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
        # pooled_prompt_embeds : torch.zeros(1,77,1280)

        # Predict the noise residual with the UNet
        noise_pred = self.unet(
            inpainting_latent_model_input,
            t.to(self.device),
            encoder_hidden_states=None,  # FIXME
            added_cond_kwargs=unet_added_conditions,
            return_dict=False,
        )[0]

        # Classifier-free guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample x_t -> x_{t-1}
        latents = self.noise_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        # Update the progress bar
        if i == len(timesteps) - 1 or (
            (i + 1) > num_warmup_steps and (i + 1) % self.noise_scheduler.order == 0
        ):
            progress_bar.update()

    # Decode the final latents (keep only the person half of the concatenated latent)
    latents_temp = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
    latents_temp = (1 / self.vae.config.scaling_factor) * latents_temp
    image = self.vae.decode(latents_temp.to(self.device, dtype=self.weight_dtype)).sample

    # Final decoded images
    decoded_images = (image / 2 + 0.5).clamp(0, 1)
```
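Two steps in the loop above are easy to get subtly wrong when porting: the classifier-free guidance combination, which assumes the first half of the batch is the unconditional branch (the order implied by `noise_pred.chunk(2)`), and the final split that keeps only the person half of the spatially concatenated person/garment latent. Below is a toy sketch of both on nested lists; `cfg_combine` and `take_person_half` are hypothetical helpers, and axes 0/1 here stand in for the height/width `concat_dim` of a real `(B, C, H, W)` tensor.

```python
from typing import List, Sequence

Grid = List[List[float]]  # one latent channel as rows x cols (toy tensor)


def cfg_combine(noise_pred_batch: Sequence[Grid], guidance_scale: float) -> List[Grid]:
    """Mimic `uncond, cond = noise_pred.chunk(2)` then `uncond + s * (cond - uncond)`.

    Assumes the first half of the batch is unconditional, the second conditional.
    """
    half = len(noise_pred_batch) // 2
    uncond, cond = noise_pred_batch[:half], noise_pred_batch[half:]
    out: List[Grid] = []
    for u, c in zip(uncond, cond):
        out.append(
            [[uu + guidance_scale * (cc - uu) for uu, cc in zip(ur, cr)] for ur, cr in zip(u, c)]
        )
    return out


def take_person_half(latent: Grid, concat_dim: int) -> Grid:
    """Mimic `latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]`.

    concat_dim=0 splits rows (height), concat_dim=1 splits columns (width).
    """
    if concat_dim == 0:
        return [row[:] for row in latent[: len(latent) // 2]]
    if concat_dim == 1:
        return [row[: len(row) // 2] for row in latent]
    raise ValueError("concat_dim must be 0 or 1 for a 2-D grid")


if __name__ == "__main__":
    uncond = [[1.0, 2.0]]
    cond = [[3.0, 2.0]]
    print(cfg_combine([uncond, cond], guidance_scale=2.5))  # [[[6.0, 2.0]]]
    print(take_person_half([[1, 2, 3, 4], [5, 6, 7, 8]], concat_dim=1))  # [[1, 2], [5, 6]]
```

If the unconditional and conditional halves are stacked in the opposite order from what the chunking assumes, guidance pushes the prediction away from the garment condition rather than toward it, so it is worth checking that the batch order of `masked_latent_concat` matches.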

A few things I would like to verify:

  1. I use init_adapter and get_trainable_modules to obtain the attention parameters to train. I then bypass the cross-attention summation inside BasicTransformerBlock. Finally, I pass torch.zeros() in place of pooled_prompt_embeds. I assume this reproduces what you proposed in the paper. Could you please confirm?

  2. I'm not doing any augmentation for now, so I set time_ids to [1024, 1024, 0, 0, 1024, 1024] for every sample. I assume this shouldn't be an issue, but I would appreciate it if you could confirm.

  3. I only trained on 128 samples for a few epochs, deliberately trying to overfit a small set to make sure the model is learning. Since the loss converges but the images are still just noise, I assumed that scaling up the data and the number of epochs wouldn't help. Could you share your thoughts on this? Thanks!
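Items 1 and 2 concern SDXL's extra conditioning, which SD 1.5 does not have. Assuming the standard SDXL UNet (`addition_embed_type="text_time"`), `added_cond_kwargs` expects `text_embeds` as a pooled embedding of shape `(batch, 1280)` and `time_ids` of shape `(batch, 6)`, i.e. original size, crop top-left, and target size. The snippet's comment lists `(1, 77, 1280)`, which is the per-token shape rather than the pooled one, so that may be worth double-checking. In PyTorch the zero conditions would be roughly `torch.zeros(batch_size, 1280)` and `torch.tensor([[1024, 1024, 0, 0, 1024, 1024]]).repeat(batch_size, 1)`; the dependency-free sketch below only illustrates those shapes, and `build_added_cond` and `shape` are hypothetical helpers.

```python
from typing import Dict, List, Sequence

POOLED_DIM = 1280  # assumed width of SDXL's pooled text embedding ("text_embeds")
TIME_ID_LEN = 6    # original_size (2) + crop top-left (2) + target_size (2)


def build_added_cond(
    batch_size: int,
    original_size: Sequence[int] = (1024, 1024),
    crop_top_left: Sequence[int] = (0, 0),
    target_size: Sequence[int] = (1024, 1024),
) -> Dict[str, List[List[float]]]:
    """Return zero pooled embeddings and repeated time ids as nested lists.

    Shapes mirror what an SDXL UNet's added_cond_kwargs is assumed to expect:
      text_embeds: (batch_size, 1280), time_ids: (batch_size, 6)
    """
    time_ids = [float(v) for v in (*original_size, *crop_top_left, *target_size)]
    if len(time_ids) != TIME_ID_LEN:
        raise ValueError(f"expected {TIME_ID_LEN} time ids, got {len(time_ids)}")
    return {
        "text_embeds": [[0.0] * POOLED_DIM for _ in range(batch_size)],
        "time_ids": [list(time_ids) for _ in range(batch_size)],
    }


def shape(x: object) -> tuple:
    """Shape of a rectangular nested list, e.g. (2, 1280)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


if __name__ == "__main__":
    cond = build_added_cond(batch_size=2)  # e.g. unconditional + conditional for CFG
    print(shape(cond["text_embeds"]), shape(cond["time_ids"]))  # (2, 1280) (2, 6)
```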

@ApolloRay


Have you tried your training code on SD 1.5?
