Skip to content

Commit

Permalink
fix broken self-attention guidance with ip-adapter
Browse files Browse the repository at this point in the history
The #168 and #177 refactorings caused this regression. A new end-to-end
test has been added for proper coverage.

(This fix will be revisited at some point)
  • Loading branch information
limiteinductive authored and deltheil committed Jan 16, 2024
1 parent d9ae7ca commit f8c09a7
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 11 deletions.
23 changes: 23 additions & 0 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
download_sd_tokenizer(hf_repo_id, "tokenizer_2")


def download_vae_fp16_fix():
download_files(
urls=[
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
],
dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
)


def download_vae_ft_mse():
download_files(
urls=[
Expand Down Expand Up @@ -433,6 +443,17 @@ def convert_vae_ft_mse():
)


def convert_vae_fp16_fix():
run_conversion_script(
"convert_diffusers_autoencoder_kl.py",
"tests/weights/madebyollin/sdxl-vae-fp16-fix",
"tests/weights/sdxl-lda-fp16-fix.safetensors",
additional_args=["--subfolder", "''"],
half=True,
expected_hash="98c7e998",
)


def convert_lora():
os.makedirs("tests/weights/loras", exist_ok=True)
run_conversion_script(
Expand Down Expand Up @@ -610,6 +631,7 @@ def download_all():
download_sd15("runwayml/stable-diffusion-inpainting")
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
download_vae_ft_mse()
download_vae_fp16_fix()
download_lora()
download_preprocessors()
download_controlnet()
Expand All @@ -624,6 +646,7 @@ def convert_all():
convert_sd15()
convert_sdxl()
convert_vae_ft_mse()
convert_vae_fp16_fix()
convert_lora()
convert_preprocessors()
convert_controlnet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,18 @@ def compute_self_attention_guidance(
classifier_free_guidance=True,
)

negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
degraded_noise = self.unet(degraded_latents)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)

return sag.scale * (noise - degraded_noise)

Expand Down Expand Up @@ -160,14 +168,23 @@ def compute_self_attention_guidance(
step=step,
classifier_free_guidance=True,
)

negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
degraded_noise = self.unet(x)

timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)

if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)

return sag.scale * (noise - degraded_noise)
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,25 @@ def compute_self_attention_guidance(
classifier_free_guidance=True,
)

negative_embedding, _ = clip_text_embedding.chunk(2)
negative_text_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)

self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
clip_text_embedding=negative_text_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
degraded_noise = self.unet(degraded_latents)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)

return sag.scale * (noise - degraded_noise)
100 changes: 100 additions & 0 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ def expected_freeu(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")


@pytest.fixture
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
assets = Path(__file__).parent.parent.parent / "assets"
dropy = assets / "dropy_logo.png"
image_prompt = assets / "dragon_quest_slime.jpg"
condition_image = assets / "dropy_canny.png"
return (
Image.open(fp=dropy).convert(mode="RGB"),
Image.open(fp=image_prompt).convert(mode="RGB"),
Image.open(fp=condition_image).convert(mode="RGB"),
Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
)


@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
Expand Down Expand Up @@ -488,6 +502,15 @@ def sdxl_lda_weights(test_weights_path: Path) -> Path:
return sdxl_lda_weights


@pytest.fixture
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights


@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
Expand Down Expand Up @@ -524,6 +547,24 @@ def sdxl_ddim(
return sdxl


@pytest.fixture
def sdxl_ddim_lda_fp16_fix(
sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()

scheduler = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)

sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)

return sdxl


@no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
Expand Down Expand Up @@ -1702,3 +1743,62 @@ def test_freeu(
predicted_image = sd15.lda.decode_latents(x)

ensure_similar_images(predicted_image, expected_freeu)


@no_grad()
def test_hello_world(
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
sdxl_ip_adapter_weights: Path,
image_encoder_weights: Path,
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
) -> None:
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
sdxl.dtype = torch.float16 # FIXME: should not be necessary

name, _, _, weights_path = t2i_adapter_xl_data_canny
init_image, image_prompt, condition_image, expected_image = hello_world_assets

if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)

ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()

image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
ip_adapter.set_clip_image_embedding(image_embedding)

# Note: default text prompts for IP-Adapter
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
)
time_ids = sdxl.default_time_ids

t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()

condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))

first_step = 1
ip_adapter.set_scale(0.85)
t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)

manual_seed(9752)
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
device=sdxl.device, dtype=sdxl.dtype
)
for step in sdxl.steps[first_step:]:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)

ensure_similar_images(predicted_image, expected_image)
1 change: 1 addition & 0 deletions tests/e2e/test_diffusion_ref/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Special cases:
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
- `expected_restart.png`
- `expected_freeu.png`
- `expected_dropy_slime_9752.png`

## Other images

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f8c09a7

Please sign in to comment.