diff --git a/README.md b/README.md
index 7dfb6329f..3f02f9a15 100644
--- a/README.md
+++ b/README.md
@@ -156,7 +156,7 @@ print("done: see output.png")
 
 You should get:
 
-![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
+![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/expected_dropy_slime_9752.png)
 
 ### Training
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index 15bcd5003..39973b3b3 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -209,6 +209,16 @@ def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
     download_sd_tokenizer(hf_repo_id, "tokenizer_2")
 
 
+def download_vae_fp16_fix():
+    download_files(
+        urls=[
+            "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
+            "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
+        ],
+        dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
+    )
+
+
 def download_vae_ft_mse():
     download_files(
         urls=[
@@ -433,6 +443,17 @@ def convert_vae_ft_mse():
     )
 
 
+def convert_vae_fp16_fix():
+    run_conversion_script(
+        "convert_diffusers_autoencoder_kl.py",
+        "tests/weights/madebyollin/sdxl-vae-fp16-fix",
+        "tests/weights/sdxl-lda-fp16-fix.safetensors",
+        additional_args=["--subfolder", "''"],
+        half=True,
+        expected_hash="98c7e998",
+    )
+
+
 def convert_lora():
     os.makedirs("tests/weights/loras", exist_ok=True)
     run_conversion_script(
@@ -610,6 +631,7 @@ def download_all():
     download_sd15("runwayml/stable-diffusion-inpainting")
     download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
     download_vae_ft_mse()
+    download_vae_fp16_fix()
     download_lora()
     download_preprocessors()
     download_controlnet()
@@ -624,6 +646,7 @@ def convert_all():
     convert_sd15()
     convert_sdxl()
     convert_vae_ft_mse()
+    convert_vae_fp16_fix()
     convert_lora()
     convert_preprocessors()
     convert_controlnet()
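Note: `download_vae_fp16_fix` above fetches the two VAE files by raw URL through the script's existing `download_files` helper. As a point of comparison only (not part of this patch), the same files could be pulled with the Hugging Face Hub client; a minimal sketch, assuming `huggingface_hub` is installed:

```python
# Sketch only: fetch the fp16-fix VAE with huggingface_hub instead of raw URLs.
# Assumes `pip install huggingface_hub`; prepare_test_weights.py itself does not use it.
from huggingface_hub import hf_hub_download

for filename in ("config.json", "diffusion_pytorch_model.safetensors"):
    local_path = hf_hub_download(
        repo_id="madebyollin/sdxl-vae-fp16-fix",
        filename=filename,
        local_dir="tests/weights/madebyollin/sdxl-vae-fp16-fix",
    )
    print(f"downloaded {local_path}")
```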
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index 532c68ff1..716fea31e 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -89,10 +89,18 @@ def compute_self_attention_guidance(
             classifier_free_guidance=True,
         )
 
-        negative_embedding, _ = clip_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
-        degraded_noise = self.unet(degraded_latents)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)
 
         return sag.scale * (noise - degraded_noise)
 
@@ -160,14 +168,23 @@ def compute_self_attention_guidance(
             step=step,
             classifier_free_guidance=True,
         )
-
-        negative_embedding, _ = clip_text_embedding.chunk(2)
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
-        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
         x = torch.cat(
             tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
             dim=1,
         )
-        degraded_noise = self.unet(x)
+
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
+
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)
 
         return sag.scale * (noise - degraded_noise)
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
index 0cb979b94..1971b32b2 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -138,17 +138,25 @@ def compute_self_attention_guidance(
             classifier_free_guidance=True,
         )
 
-        negative_embedding, _ = clip_text_embedding.chunk(2)
+        negative_text_embedding, _ = clip_text_embedding.chunk(2)
         negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
         time_ids, _ = time_ids.chunk(2)
+
         self.set_unet_context(
             timestep=timestep,
-            clip_text_embedding=negative_embedding,
+            clip_text_embedding=negative_text_embedding,
             pooled_text_embedding=negative_pooled_embedding,
             time_ids=time_ids,
-            **kwargs,
         )
-        degraded_noise = self.unet(degraded_latents)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)
 
         return sag.scale * (noise - degraded_noise)
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 4ced66ceb..04bb180eb 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -242,6 +242,20 @@ def expected_freeu(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
 
 
+@pytest.fixture
+def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
+    assets = Path(__file__).parent.parent.parent / "assets"
+    dropy = assets / "dropy_logo.png"
+    image_prompt = assets / "dragon_quest_slime.jpg"
+    condition_image = assets / "dropy_canny.png"
+    return (
+        Image.open(fp=dropy).convert(mode="RGB"),
+        Image.open(fp=image_prompt).convert(mode="RGB"),
+        Image.open(fp=condition_image).convert(mode="RGB"),
+        Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
+    )
+
+
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
     return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
@@ -488,6 +502,15 @@ def sdxl_lda_weights(test_weights_path: Path) -> Path:
     return sdxl_lda_weights
 
 
+@pytest.fixture
+def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
+    sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
+    if not sdxl_lda_weights.is_file():
+        warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sdxl_lda_weights
+
+
 @pytest.fixture
 def sdxl_unet_weights(test_weights_path: Path) -> Path:
     sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
@@ -524,6 +547,24 @@ def sdxl_ddim(
     return sdxl
 
 
+@pytest.fixture
+def sdxl_ddim_lda_fp16_fix(
+    sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
+) -> StableDiffusion_XL:
+    if test_device.type == "cpu":
+        warn(message="not running on CPU, skipping")
+        pytest.skip()
+
+    scheduler = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+
+    return sdxl
+
+
 @no_grad()
 def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@@ -1702,3 +1743,62 @@ def test_freeu(
     predicted_image = sd15.lda.decode_latents(x)
 
     ensure_similar_images(predicted_image, expected_freeu)
+
+
+@no_grad()
+def test_hello_world(
+    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
+    t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
+    sdxl_ip_adapter_weights: Path,
+    image_encoder_weights: Path,
+    hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
+) -> None:
+    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
+    sdxl.dtype = torch.float16  # FIXME: should not be necessary
+
+    name, _, _, weights_path = t2i_adapter_xl_data_canny
+    init_image, image_prompt, condition_image, expected_image = hello_world_assets
+
+    if not weights_path.is_file():
+        warn(f"could not find weights at {weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
+    ip_adapter.set_clip_image_embedding(image_embedding)
+
+    # Note: default text prompts for IP-Adapter
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
+    )
+    time_ids = sdxl.default_time_ids
+
+    t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
+
+    condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
+    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
+
+    first_step = 1
+    ip_adapter.set_scale(0.85)
+    t2i_adapter.set_scale(0.8)
+    sdxl.set_num_inference_steps(50)
+    sdxl.set_self_attention_guidance(enable=True, scale=0.75)
+
+    manual_seed(9752)
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
+        device=sdxl.device, dtype=sdxl.dtype
+    )
+    for step in sdxl.steps[first_step:]:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+        )
+    predicted_image = sdxl.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image)
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index 848da4a52..42e9809e9 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -47,6 +47,7 @@ Special cases:
   - `expected_cutecat_sdxl_ddim_random_init_sag.png`
   - `expected_restart.png`
   - `expected_freeu.png`
+  - `expected_dropy_slime_9752.png` (located in `/assets`)
 
 ## Other images
diff --git a/tests/e2e/test_diffusion_ref/expected_dropy_slime_9752.png b/tests/e2e/test_diffusion_ref/expected_dropy_slime_9752.png
new file mode 100644
index 000000000..5628cb87c
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_dropy_slime_9752.png differ
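Note on the `ip_adapter` branch that appears three times in the `compute_self_attention_guidance` hunks above: each copy carries the same "a bit hacky, it should be refactored" comment. One possible shape for that refactor (a sketch only, not part of this patch) is a context manager that swaps in the negative half of the image embedding and restores it afterwards. It relies on the `use_context` and `provider.contexts` APIs already used in the hunks; the helper name is hypothetical:

```python
# Sketch only: factor the duplicated IP-Adapter embedding swap into a context
# manager. `negative_image_embedding` is a hypothetical name; `use_context`
# and `provider.contexts` are the APIs shown in the diff above.
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def negative_image_embedding(unet: Any) -> Iterator[None]:
    if "ip_adapter" not in unet.provider.contexts:
        yield  # no IP-Adapter injected: nothing to swap
        return
    ip_adapter_context = unet.use_context("ip_adapter")
    saved = ip_adapter_context["clip_image_embedding"].clone()
    # keep only the negative (unconditional) half, as the inline blocks do
    ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
    try:
        yield
    finally:
        ip_adapter_context["clip_image_embedding"] = saved


# Each SAG method could then reduce to:
#     with negative_image_embedding(self.unet):
#         degraded_noise = self.unet(degraded_latents)
```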