Skip to content

Commit

Permalink
ControlNet+Adapter pipeline, and ControlNet+Adapter+Inpaint pipeline (h…
Browse files Browse the repository at this point in the history
…uggingface#5869)

* ControlNet+Adapter pipeline, and +Inpaint pipeline


---------

Co-authored-by: andres <[email protected]>
  • Loading branch information
affromero and andres authored Nov 21, 2023
1 parent 13d73d9 commit 93f1a14
Show file tree
Hide file tree
Showing 3 changed files with 3,497 additions and 0 deletions.
138 changes: 138 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2343,3 +2343,141 @@ images = pipe(

assert len(images) == (len(prompts) - 1) * num_interpolation_steps
```

### ControlNet + T2I Adapter Pipeline
This pipelines combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.

```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image

from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
StableDiffusionXLControlNetAdapterPipeline,
)

controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)

pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")

prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"

image = load_image(img_url).resize((1024, 1024))

depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)

strength = 0.5

images = pipe(
prompt,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
).images
images[0].save("controlnet_and_adapter.png")

```

### ControlNet + T2I Adapter + Inpainting Pipeline
```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image

from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
StableDiffusionXLControlNetAdapterInpaintPipeline,
)

controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)

pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")

prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))

depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)

strength = 0.4

images = pipe(
prompt,
image=image,
mask_image=mask_image,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
strength=0.7,
).images
images[0].save("controlnet_and_adapter_inpaint.png")

```
Loading

0 comments on commit 93f1a14

Please sign in to comment.