Commit
Merge pull request #139 from huggingface/main
Merge changes
Skquark authored Jan 16, 2024
2 parents 2f06855 + 8842bca commit b489020
Showing 53 changed files with 4,737 additions and 696 deletions.
14 changes: 5 additions & 9 deletions docs/source/en/using-diffusers/callback.md
@@ -64,19 +64,15 @@ With callbacks, you can implement features such as dynamic CFG without having to

</Tip>

## Interrupt the diffusion process

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

<Tip>

The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).

</Tip>

This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
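
For example, a minimal sketch of an interruption callback (the checkpoint and prompt here are placeholders) might look like this:

```py
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
    # stop the denoising loop after 10 of the 50 scheduled steps
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```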

2 changes: 1 addition & 1 deletion docs/source/en/using-diffusers/controlnet.md
@@ -429,7 +429,7 @@ image = pipe(
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
```

## MultiControlNet

<Tip>

73 changes: 68 additions & 5 deletions docs/source/en/using-diffusers/inpaint.md
@@ -77,12 +77,42 @@ Throughout this guide, the mask image is provided in all of the code examples for
Upload a base image to inpaint on and use the sketch tool to draw a mask. Once you're done, click **Run** to generate and download the mask image.

<iframe
	src="https://stevhliu-inpaint-mask-maker.hf.space"
	frameborder="0"
	width="850"
	height="450"
></iframe>

### Mask blur

The [`~VaeImageProcessor.blur`] method provides an option for how to blend the original image and inpaint area. The amount of blur is determined by the `blur_factor` parameter. Increasing the `blur_factor` increases the amount of blur applied to the mask edges, softening the transition between the original image and inpaint area. A low or zero `blur_factor` preserves the sharper edges of the mask.

To use this, create a blurred mask with the image processor.

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
from PIL import Image

pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to('cuda')

mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")
blurred_mask = pipeline.mask_processor.blur(mask, blur_factor=33)
blurred_mask
```

<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask with no blur</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mask_blurred.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask with blur applied</figcaption>
</div>
</div>
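
The blurred mask can then be passed to the pipeline like any other mask image. A short sketch reusing the `pipeline` and `blurred_mask` from above (the prompt is illustrative):

```py
base = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png")
image = pipeline("a sailboat drifting past the shore", image=base, mask_image=blurred_mask).images[0]
image
```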

## Popular models

[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
@@ -318,7 +348,7 @@ make_image_grid([init_image, image], rows=1, cols=2)

The trade-off of using a non-inpaint specific checkpoint is the overall image quality may be lower, but it generally tends to preserve the mask area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.

If preserving the unmasked area is important for your task, you can use the [`VaeImageProcessor.apply_overlay`] method to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.

```py
import PIL
# ...
```
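
The rest of this example is collapsed in the diff above. As a minimal sketch of how the overlay step fits together (the checkpoint, image URLs, and prompt are illustrative assumptions, not the original example):

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

# assumed checkpoint and test images -- substitute your own
pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png")
mask_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")

repainted_image = pipeline(prompt="boat", image=init_image, mask_image=mask_image).images[0]

# paste the original unmasked pixels back over the inpainted output
unmasked_unchanged_image = pipeline.image_processor.apply_overlay(mask_image, init_image, repainted_image)
make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
```

@@ -475,6 +505,39 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)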
</figure>
</div>

### Padding mask crop

A method for increasing the inpainting image quality is to use the [`padding_mask_crop`](https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline.__call__.padding_mask_crop) parameter. When enabled, this option crops the masked area with some user-specified padding and it'll also crop the same area from the original image. Both the image and mask are upscaled to a higher resolution for inpainting, and then overlaid on the original image. This is a quick and easy way to improve image quality without using a separate pipeline like [`StableDiffusionUpscalePipeline`].

Add the `padding_mask_crop` parameter to the pipeline call and set it to the desired padding value.

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
from PIL import Image

generator = torch.Generator(device='cuda').manual_seed(0)
pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to('cuda')

base = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")

image = pipeline("boat", image=base, mask_image=mask, strength=0.75, generator=generator, padding_mask_crop=32).images[0]
image
```

<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/baseline_inpaint.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">default inpaint image</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/padding_mask_crop_inpaint.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">inpaint image with `padding_mask_crop` enabled</figcaption>
</div>
</div>

## Chained inpainting pipelines

[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
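
As a rough sketch of the pattern (the checkpoints, prompts, and mask below are illustrative assumptions), a text-to-image output can be fed directly into an inpainting pipeline that reuses the same components:

```py
import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
text2image = pipeline("concept art of a seaside village, detailed, 8k").images[0]

# reuse the loaded components instead of instantiating a second pipeline
pipeline = AutoPipelineForInpainting.from_pipe(pipeline).to("cuda")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png")
image = pipeline("a lighthouse on the cliff", image=text2image, mask_image=mask).images[0]
```
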
@@ -38,7 +38,7 @@
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -58,12 +58,13 @@
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
    check_min_version,
    convert_all_state_dict_to_peft,
    convert_state_dict_to_diffusers,
    convert_state_dict_to_kohya,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
@@ -1292,17 +1293,6 @@ def main(args):
            else:
                param.requires_grad = False

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
@@ -1358,17 +1348,34 @@ def load_model_hook(models, input_dir):
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)

        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        if args.train_text_encoder:
            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

            _set_state_dict_into_text_encoder(
                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
            )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            models = [unet_]
            if args.train_text_encoder:
                models.extend([text_encoder_one_, text_encoder_two_])
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1383,6 +1390,13 @@ def load_model_hook(models, input_dir):
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [unet]
        if args.train_text_encoder:
            models.extend([text_encoder_one, text_encoder_two])
        cast_training_params(models, dtype=torch.float32)

    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

    if args.train_text_encoder: