From 57fde871e117090fc766cc36caddf605f7c72465 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 18 Dec 2023 15:10:01 -1000 Subject: [PATCH 01/42] offload the optional module `image_encoder` (#6151) * offload image_encoder * add test --------- Co-authored-by: yiyixuxu Co-authored-by: Sayak Paul --- .../animatediff/pipeline_animatediff.py | 2 +- .../controlnet/pipeline_controlnet.py | 2 +- .../controlnet/pipeline_controlnet_inpaint.py | 2 +- .../controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- .../pipeline_stable_diffusion_xl.py | 2 +- .../pipeline_stable_diffusion_xl_img2img.py | 2 +- .../pipeline_stable_diffusion_xl_inpaint.py | 2 +- .../test_ip_adapter_stable_diffusion.py | 27 +++++++++++++++++++ 13 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 68b358f7645c..0dab722e51a8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -106,7 +106,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder"] def __init__( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 3de6732be0f2..d7168bec8259 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -176,7 +176,7 @@ class StableDiffusionControlNetPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 270c232b698c..a18468f72c19 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -291,7 +291,7 @@ class StableDiffusionControlNetInpaintPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0e7920708184..02e515c0ff55 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -165,7 +165,7 @@ class StableDiffusionXLControlNetPipeline( """ # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py index 45e82a28d2e0..186efbfc160d 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -155,7 +155,7 @@ class AltDiffusionPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py index 9838bb9e5ba6..5ba1d7afd336 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -195,7 +195,7 @@ class AltDiffusionImg2ImgPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2ad90f049922..b05d0b17dd5a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -151,7 +151,7 @@ class StableDiffusionPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index d7e0952b2aa4..d2538749f3be 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -191,7 +191,7 @@ class StableDiffusionImg2ImgPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a321bb41a7eb..bc6c65f4a654 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -255,7 +255,7 @@ class StableDiffusionInpaintPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a6033b698a41..569668a1686d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -198,7 +198,7 @@ class StableDiffusionXLPipeline( watermarker will be used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 97f99386acef..4f75ce6878ad 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -219,7 +219,7 @@ class StableDiffusionXLImg2ImgPipeline( watermarker will be used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 812f5499f8e6..751823ea4b10 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -364,7 +364,7 @@ class StableDiffusionXLInpaintPipeline( watermarker will be used. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index ff93ecaf003b..dfc39d61bb08 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -182,6 +182,33 @@ def test_inpainting(self): assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_text_to_image_model_cpu_offload(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + pipeline.to(torch_device) + + inputs = self.get_dummy_inputs() + output_without_offload = pipeline(**inputs).images + + pipeline.enable_model_cpu_offload() + inputs = self.get_dummy_inputs() + output_with_offload = pipeline(**inputs).images + max_diff = np.abs(output_with_offload - output_without_offload).max() + self.assertLess(max_diff, 1e-3, "CPU offloading should not affect the inference results") + + offloaded_modules = [ + v + for k, v in pipeline.components.items() + if isinstance(v, torch.nn.Module) and k not in pipeline._exclude_from_cpu_offload + ] + ( + self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)), + f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}", + ) + def test_text_to_image_full_face(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( From 9221da4063d79e5a518ab6c047a28afced38704a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 19 Dec 2023 09:46:57 +0530 Subject: [PATCH 02/42] fix: init for vae during pixart tests (#6215) * fix: init for vae during pixart tests * print the values * add flatten * correct assertion value for test_inference * correct assertion values for test_inference_non_square_images * run styling * debug test_inference_with_multiple_images_per_prompt * fix assertion values for test_inference_with_multiple_images_per_prompt --- tests/pipelines/pixart/test_pixart.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 361bacc298e9..3df4cad1925f 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -64,7 +64,9 @@ def get_dummy_components(self): norm_elementwise_affine=False, norm_eps=1e-6, ) + torch.manual_seed(0) vae = AutoencoderKL() + scheduler = DDIMScheduler() text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") @@ -186,7 +188,7 @@ def test_inference(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 8, 8, 3)) - expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) + expected_slice = np.array([0.6319, 0.3526, 0.3806, 0.6327, 0.4639, 0.483, 0.2583, 0.5331, 0.4852]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -203,7 +205,7 @@ def test_inference_non_square_images(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 32, 48, 3)) - expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416]) + expected_slice = np.array([0.6493, 0.537, 0.4081, 0.4762, 0.3695, 0.4711, 0.3026, 0.5218, 0.5263]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -293,7 +295,7 @@ def test_inference_with_multiple_images_per_prompt(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (2, 8, 8, 3)) - expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) + expected_slice = np.array([0.6319, 0.3526, 0.3806, 0.6327, 0.4639, 0.483, 0.2583, 0.5331, 0.4852]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) From 288ceebea51abb439ca81684d642852c7fcc044a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 19 Dec 2023 09:54:17 +0530 Subject: [PATCH 03/42] [T2I LoRA training] fix: unscale fp16 gradient problem (#6119) * fix: unscale fp16 gradient problem * fix for dreambooth lora sdxl * make the type-casting conditional. * Apply suggestions from code review Co-authored-by: Patrick von Platen --------- Co-authored-by: Patrick von Platen --- .../dreambooth/train_dreambooth_lora_sdxl.py | 11 +++ .../text_to_image/train_text_to_image_lora.py | 69 +++++++++++-------- 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c3a78eae34d7..9992292e30aa 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -991,6 +991,17 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ca699c863eb6..c8efbddd0b44 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -460,7 +460,13 @@ def main(): vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + # Add adapter and make sure the trainable params are in float32. unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -888,39 +894,42 @@ def collate_fn(examples): ignore_patterns=["step_*", "epoch_*"], ) - # Final inference - # Load previous pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype - ) - pipeline = pipeline.to(accelerator.device) + # Final inference + # Load previous pipeline + if args.validation_prompt is not None: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) - # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) - # run inference - generator = torch.Generator(device=accelerator.device) - if args.seed is not None: - generator = generator.manual_seed(args.seed) - images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - if accelerator.is_main_process: - for tracker in accelerator.trackers: - if len(images) != 0: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if len(images) != 0: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) accelerator.end_training() From 32ff4773d4b6662ddbb35c4a75f7178eb2b70cf0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 19 Dec 2023 11:58:34 +0530 Subject: [PATCH 04/42] ControlNetXS fixes. (#6228) update --- src/diffusers/models/controlnetxs.py | 81 ++++++++++++++----- .../controlnetxs/test_controlnetxs.py | 9 ++- 2 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/controlnetxs.py b/src/diffusers/models/controlnetxs.py index 3cc77fe70d72..41fe624b9b59 100644 --- a/src/diffusers/models/controlnetxs.py +++ b/src/diffusers/models/controlnetxs.py @@ -23,9 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .attention_processor import ( - AttentionProcessor, -) +from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor from .autoencoders import AutoencoderKL from .lora import LoRACompatibleConv from .modeling_utils import ModelMixin @@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} norm_kwargs["num_channels"] += by # surgery done here # conv1 - conv1_args = ( - "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ") - ) + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + for a in conv1_args: assert hasattr(old_conv1, a) + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. conv1_kwargs["in_channels"] += by # surgery done here @@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, } # swap old with new modules unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) - unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs) - unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs) + unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): """Increase channels sizes to allow for additional concatted information from base model""" old_down = unet.down_blocks[block_no].downsamplers[0].conv - # conv1 - args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split( - " " - ) + + args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + args.append("lora_layer") + for a in args: assert hasattr(old_down, a) kwargs = {a: getattr(old_down, a) for a in args} kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. kwargs["in_channels"] += by # surgery done here # swap old with new modules - unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs) + unet.down_blocks[block_no].downsamplers[0].conv = ( + nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) + ) unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here @@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): assert hasattr(old_norm1, a) norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} norm_kwargs["num_channels"] += by # surgery done here - # conv1 - conv1_args = ( - "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ") - ) - for a in conv1_args: - assert hasattr(old_conv1, a) + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. conv1_kwargs["in_channels"] += by # surgery done here @@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): } # swap old with new modules unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) - unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs) - unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs) + unet.mid_block.resnets[0].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.mid_block.resnets[0].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) unet.mid_block.resnets[0].in_channels += by # surgery done here diff --git a/tests/pipelines/controlnetxs/test_controlnetxs.py b/tests/pipelines/controlnetxs/test_controlnetxs.py index e3212e9e301c..1f184e5bb14c 100644 --- a/tests/pipelines/controlnetxs/test_controlnetxs.py +++ b/tests/pipelines/controlnetxs/test_controlnetxs.py @@ -34,6 +34,7 @@ enable_full_determinism, load_image, load_numpy, + numpy_cosine_similarity_distance, require_python39_or_higher, require_torch_2, require_torch_gpu, @@ -273,7 +274,9 @@ def test_canny(self): original_image = image[-3:, -3:, -1].flatten() expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701]) - assert np.allclose(original_image, expected_image, atol=1e-04) + + max_diff = numpy_cosine_similarity_distance(original_image, expected_image) + assert max_diff < 1e-4 def test_depth(self): controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth") @@ -298,7 +301,9 @@ def test_depth(self): original_image = image[-3:, -3:, -1].flatten() expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703]) - assert np.allclose(original_image, expected_image, atol=1e-04) + + max_diff = numpy_cosine_similarity_distance(original_image, expected_image) + assert max_diff < 1e-4 @require_python39_or_higher @require_torch_2 From bf40d7d82a732a35e0f5b907e59064ee080cde9f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 19 Dec 2023 13:26:25 +0530 Subject: [PATCH 05/42] add peft dependency to fast push tests (#6229) * add peft dependency * add peft dependency at the correct place. --- .github/workflows/push_tests_fast.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 6ea873d0a79c..2f69b00af982 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -98,6 +98,7 @@ jobs: - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | + python -m pip install peft python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples From 3e71a206502994d6de0b908e36254df2d09b172c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 19 Dec 2023 07:07:24 -1000 Subject: [PATCH 06/42] [refactor embeddings]pixart-alpha (#6212) pixart-alpha Co-authored-by: yiyixuxu --- src/diffusers/models/embeddings.py | 35 +++++-------------- src/diffusers/models/normalization.py | 4 +-- src/diffusers/models/transformer_2d.py | 4 +-- .../pixart_alpha/pipeline_pixart_alpha.py | 5 +++ 4 files changed, 17 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 73abc9869230..db68591bdb44 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -729,7 +729,7 @@ def forward( return objs -class CombinedTimestepSizeEmbeddings(nn.Module): +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. @@ -746,45 +746,27 @@ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool self.use_additional_conditions = use_additional_conditions if use_additional_conditions: - self.use_additional_conditions = True self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): - if size.ndim == 1: - size = size[:, None] - - if size.shape[0] != batch_size: - size = size.repeat(batch_size // size.shape[0], 1) - if size.shape[0] != batch_size: - raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") - - current_batch_size, dims = size.shape[0], size.shape[1] - size = size.reshape(-1) - size_freq = self.additional_condition_proj(size).to(size.dtype) - - size_emb = embedder(size_freq) - size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) - return size_emb - def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) if self.use_additional_conditions: - resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) - aspect_ratio = self.apply_condition( - aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder - ) - conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) else: conditioning = timesteps_emb return conditioning -class CaptionProjection(nn.Module): +class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -796,9 +778,8 @@ def __init__(self, in_features, hidden_size, num_tokens=120): self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) self.act_1 = nn.GELU(approximate="tanh") self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) - self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5)) - def forward(self, caption, force_drop_ids=None): + def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 11d2a344744e..25af4d853b86 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from .activations import get_activation -from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings +from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings class AdaLayerNorm(nn.Module): @@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module): def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): super().__init__() - self.emb = CombinedTimestepSizeEmbeddings( + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions ) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 3aecc43f0f5b..128395cc161a 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -22,7 +22,7 @@ from ..models.embeddings import ImagePositionalEmbeddings from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version from .attention import BasicTransformerBlock -from .embeddings import CaptionProjection, PatchEmbed +from .embeddings import PatchEmbed, PixArtAlphaTextProjection from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin from .normalization import AdaLayerNormSingle @@ -235,7 +235,7 @@ def __init__( self.caption_projection = None if caption_channels is not None: - self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) self.gradient_checkpointing = False diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 090b66915dd0..82a170400068 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -853,6 +853,11 @@ def __call__( aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} # 7. Denoising loop From df476d9f63891406db2d531aa5faf195193e3354 Mon Sep 17 00:00:00 2001 From: raven Date: Wed, 20 Dec 2023 06:14:37 +0900 Subject: [PATCH 07/42] [Docs] Fix a code example in the ControlNet Inpainting documentation (#6236) fix document on masked image in inpainting controlnet --- docs/source/en/using-diffusers/controlnet.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index c50d2e96e8ed..e7f6eb27561d 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -203,7 +203,7 @@ def make_inpaint_condition(image, image_mask): image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 assert image.shape[0:1] == image_mask.shape[0:1] - image[image_mask > 0.5] = 1.0 # set as masked pixel + image[image_mask > 0.5] = -1.0 # set as masked pixel image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) image = torch.from_numpy(image) return image From 54339629927bffafe810b312c2f3526c67a52324 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:50:18 -0800 Subject: [PATCH 08/42] [docs] Batched seeds (#6237) batched seed --- docs/source/en/using-diffusers/reusing_seeds.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md index d2638b469e30..6d0f6ac9837f 100644 --- a/docs/source/en/using-diffusers/reusing_seeds.md +++ b/docs/source/en/using-diffusers/reusing_seeds.md @@ -41,6 +41,20 @@ Now, define four different `Generator`s and assign each `Generator` a seed (`0` generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)] ``` + + +To create a batched seed, you should use a list comprehension that iterates over the length specified in `range()`. This creates a unique `Generator` object for each image in the batch. If you only multiply the `Generator` by the batch size, this only creates one `Generator` object that is used sequentially for each image in the batch. + +For example, if you want to use the same seed to create 4 identical images: + +```py +❌ [torch.Generator().manual_seed(seed)] * 4 + +✅ [torch.Generator().manual_seed(seed) for _ in range(4)] +``` + + + Generate the images and have a look: ```python From ff43dba7eaee6d2055d299cf58183b8d19a35daa Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:07:19 +0900 Subject: [PATCH 09/42] [Fix] Fix Regional Prompting Pipeline (#6188) * Update regional_prompting_stable_diffusion.py * reformat * reformat * reformat * reformat * reformat * reformat * reformat * regormat * reformat * reformat * reformat * reformat * Update regional_prompting_stable_diffusion.py --------- Co-authored-by: Sayak Paul --- .../regional_prompting_stable_diffusion.py | 75 +++++++++++++------ 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 525e75bc68b9..71f24a81bd15 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -73,7 +73,14 @@ def __init__( requires_safety_checker: bool = True, ): super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker, ) self.register_modules( vae=vae, @@ -102,22 +109,22 @@ def __call__( return_dict: bool = True, rp_args: Dict[str, str] = None, ): - active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721 + active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt if negative_prompt is None: - negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721 + negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt) device = self._execution_device regions = 0 self.power = int(rp_args["power"]) if "power" in rp_args else 1 - prompts = prompt if type(prompt) == list else [prompt] # noqa: E721 - n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721 + prompts = prompt if isinstance(prompt, list) else [prompt] + n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) - cn = len(all_prompts_cn) == len(all_n_prompts_cn) + equal = len(all_prompts_cn) == len(all_n_prompts_cn) if Compel: compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder) @@ -129,7 +136,7 @@ def getcompelembs(prps): return torch.cat(embl) conds = getcompelembs(all_prompts_cn) - unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts) + unconds = getcompelembs(all_n_prompts_cn) embs = getcompelembs(prompts) n_embs = getcompelembs(n_prompts) prompt = negative_prompt = None @@ -137,7 +144,7 @@ def getcompelembs(prps): conds = self.encode_prompt(prompts, device, 1, True)[0] unconds = ( self.encode_prompt(n_prompts, device, 1, True)[0] - if cn + if equal else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] ) embs = n_embs = None @@ -206,8 +213,11 @@ def forward( else: px, nx = hidden_states.chunk(2) - if cn: - hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0) + if equal: + hidden_states = torch.cat( + [px for i in range(regions)] + [nx for i in range(regions)], + 0, + ) encoder_hidden_states = torch.cat([conds] + [unconds]) else: hidden_states = torch.cat([px for i in range(regions)] + [nx], 0) @@ -289,9 +299,9 @@ def forward( if any(x in mode for x in ["COL", "ROW"]): reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2]) center = reshaped.shape[0] // 2 - px = reshaped[0:center] if cn else reshaped[0:-batch] - nx = reshaped[center:] if cn else reshaped[-batch:] - outs = [px, nx] if cn else [px] + px = reshaped[0:center] if equal else reshaped[0:-batch] + nx = reshaped[center:] if equal else reshaped[-batch:] + outs = [px, nx] if equal else [px] for out in outs: c = 0 for i, ocell in enumerate(ocells): @@ -321,15 +331,16 @@ def forward( :, ] c += 1 - px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx) hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) hidden_states = hidden_states.reshape(xshape) #### Regional Prompting Prompt mode elif "PRO" in mode: - center = reshaped.shape[0] // 2 - px = reshaped[0:center] if cn else reshaped[0:-batch] - nx = reshaped[center:] if cn else reshaped[-batch:] + px, nx = ( + torch.chunk(hidden_states) if equal else hidden_states[0:-batch], + hidden_states[-batch:], + ) if (h, w) in self.attnmasks and self.maskready: @@ -340,8 +351,8 @@ def mask(input): out[b] = out[b] + out[r * batch + b] return out - px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx) - px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx) + px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx) hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) return hidden_states @@ -378,7 +389,15 @@ def hook_forwards(root_module: torch.nn.Module): save_mask = False if mode == "PROMPT" and save_mask: - saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions) + saveattnmaps( + self, + output, + height, + width, + thresholds, + num_inference_steps // 2, + regions, + ) return output @@ -437,7 +456,11 @@ def startend(cells, array): def make_emblist(self, prompts): with torch.no_grad(): tokens = self.tokenizer( - prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" + prompts, + max_length=self.tokenizer.model_max_length, + padding=True, + truncation=True, + return_tensors="pt", ).input_ids.to(self.device) embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype) return embs @@ -563,7 +586,15 @@ def tokendealer(self, all_prompts): def scaled_dot_product_attention( - self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + getattn=False, ) -> torch.Tensor: # Efficient implementation equivalent to the following: L, S = query.size(-2), key.size(-2) From 457abdf2cf31956a15df7233187b0b358307c7d1 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 19 Dec 2023 23:39:25 -0800 Subject: [PATCH 10/42] EulerAncestral add `rescale_betas_zero_snr` (#6187) * EulerAncestral add `rescale_betas_zero_snr` Uses same infinite sigma fix from EulerDiscrete. Interestingly the ancestral version had the opposite problem: too much contrast instead of too little. * UT for EulerAncestral `rescale_betas_zero_snr` * EulerAncestral upcast samples during step() It helps this scheduler too, particularly when the model is using bf16. While the noise dtype is still the model's it's automatically upcasted for the add so all it affects is determinism. --------- Co-authored-by: Sayak Paul --- .../scheduling_euler_ancestral_discrete.py | 56 +++++++++++++++++++ .../test_scheduler_euler_ancestral.py | 4 ++ 2 files changed, 60 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index e476c329455e..ca188378a38f 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -92,6 +92,43 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. @@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -138,6 +179,7 @@ def __init__( prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -152,9 +194,17 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -327,6 +377,9 @@ def step( sigma = self.sigmas[self.step_index] + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output @@ -357,6 +410,9 @@ def step( prev_sample = prev_sample + noise * sigma_up + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + # upon completion increase step index by one self._step_index += 1 diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index a0818042fad9..9f22ab38ddaf 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -37,6 +37,10 @@ def test_prediction_type(self): for prediction_type in ["epsilon", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() From 22b45304bf85a3c5281753d6b3259ccaf96e5085 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 20 Dec 2023 21:01:33 +0530 Subject: [PATCH 11/42] [Refactor upsamplers and downsamplers] separate out upsamplers and downsamplers. (#6128) * separate out upsamplers and downsamplers. * import all the necessary blocks in resnet for backward comp. * move upsample2d and downsample2d to utils. * move downsample_2d to downsamplers.py * apply feedback * fix import * samplers -> sampling --- src/diffusers/models/downsampling.py | 318 ++++++++++++ src/diffusers/models/resnet.py | 714 +-------------------------- src/diffusers/models/upsampling.py | 426 ++++++++++++++++ 3 files changed, 759 insertions(+), 699 deletions(-) create mode 100644 src/diffusers/models/downsampling.py create mode 100644 src/diffusers/models/upsampling.py diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py new file mode 100644 index 000000000000..d39bae22e831 --- /dev/null +++ b/src/diffusers/models/downsampling.py @@ -0,0 +1,318 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import USE_PEFT_BACKEND +from .lora import LoRACompatibleConv +from .upsampling import upfirdn2d_native + + +class Downsample1D(nn.Module): + """A 1D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 1D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + assert inputs.shape[1] == self.channels + return self.conv(inputs) + + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + + if use_conv: + conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + if not USE_PEFT_BACKEND: + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__( + self, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d( + self, + hidden_states: torch.FloatTensor, + weight: Optional[torch.FloatTensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.FloatTensor: + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.FloatTensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + pad_value = (kernel.shape[0] - factor) + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + + return output + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + if self.use_conv: + downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return hidden_states + + +# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead +class KDownsample2D(nn.Module): + r"""A 2D K-downsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv2d(inputs, weight, stride=2) + + +def downsample_2d( + hidden_states: torch.FloatTensor, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, +) -> torch.FloatTensor: + r"""Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states (`torch.FloatTensor`) + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 970d2be05b7a..bbfb71ca3fbf 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,562 +23,23 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention_processor import SpatialNorm +from .downsampling import ( # noqa + Downsample1D, + Downsample2D, + FirDownsample2D, + KDownsample2D, + downsample_2d, +) from .lora import LoRACompatibleConv, LoRACompatibleLinear from .normalization import AdaGroupNorm - - -class Upsample1D(nn.Module): - """A 1D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - name (`str`, default `conv`): - name of the upsampling 1D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: Optional[int] = None, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - assert inputs.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(inputs) - - outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") - - if self.use_conv: - outputs = self.conv(outputs) - - return outputs - - -class Downsample1D(nn.Module): - """A 1D downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - padding (`int`, default `1`): - padding for the convolution. - name (`str`, default `conv`): - name of the downsampling 1D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - assert inputs.shape[1] == self.channels - return self.conv(inputs) - - -class Upsample2D(nn.Module): - """A 2D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - name (`str`, default `conv`): - name of the upsampling 2D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: Optional[int] = None, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward( - self, - hidden_states: torch.FloatTensor, - output_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.Conv2d_0(hidden_states, scale) - else: - hidden_states = self.Conv2d_0(hidden_states) - - return hidden_states - - -class Downsample2D(nn.Module): - """A 2D downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - padding (`int`, default `1`): - padding for the convolution. - name (`str`, default `conv`): - name of the downsampling 2D layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - name: str = "conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - - if use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - if not USE_PEFT_BACKEND: - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class FirUpsample2D(nn.Module): - """A 2D FIR upsampling layer with an optional convolution. - - Parameters: - channels (`int`, optional): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - fir_kernel (`tuple`, default `(1, 3, 3, 1)`): - kernel for the FIR filter. - """ - - def __init__( - self, - channels: Optional[int] = None, - out_channels: Optional[int] = None, - use_conv: bool = False, - fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d( - self, - hidden_states: torch.FloatTensor, - weight: Optional[torch.FloatTensor] = None, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, - ) -> torch.FloatTensor: - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight (`torch.FloatTensor`, *optional*): - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to nearest-neighbor upsampling. - factor (`int`, *optional*): Integer upsampling factor (default: 2). - gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - pad_value = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - output_shape = ( - (hidden_states.shape[2] - 1) * factor + convH, - (hidden_states.shape[3] - 1) * factor + convW, - ) - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - - # Transpose weights. - weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - - inverse_conv = F.conv_transpose2d( - hidden_states, - weight, - stride=stride, - output_padding=output_padding, - padding=0, - ) - - output = upfirdn2d_native( - inverse_conv, - torch.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - - return output - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - if self.use_conv: - height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - """A 2D FIR downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - fir_kernel (`tuple`, default `(1, 3, 3, 1)`): - kernel for the FIR filter. - """ - - def __init__( - self, - channels: Optional[int] = None, - out_channels: Optional[int] = None, - use_conv: bool = False, - fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d( - self, - hidden_states: torch.FloatTensor, - weight: Optional[torch.FloatTensor] = None, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, - ) -> torch.FloatTensor: - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight (`torch.FloatTensor`, *optional*): - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to average pooling. - factor (`int`, *optional*, default to `2`): - Integer downsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude. - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same - datatype as `x`. - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = (kernel.shape[0] - factor) + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - - return output - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - if self.use_conv: - downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return hidden_states - - -# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead -class KDownsample2D(nn.Module): - r"""A 2D K-downsampling layer. - - Parameters: - pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. - """ - - def __init__(self, pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) - weight = inputs.new_zeros( - [ - inputs.shape[1], - inputs.shape[1], - self.kernel.shape[0], - self.kernel.shape[1], - ] - ) - indices = torch.arange(inputs.shape[1], device=inputs.device) - kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) - weight[indices, indices] = kernel - return F.conv2d(inputs, weight, stride=2) - - -class KUpsample2D(nn.Module): - r"""A 2D K-upsampling layer. - - Parameters: - pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. - """ - - def __init__(self, pad_mode: str = "reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = inputs.new_zeros( - [ - inputs.shape[1], - inputs.shape[1], - self.kernel.shape[0], - self.kernel.shape[1], - ] - ) - indices = torch.arange(inputs.shape[1], device=inputs.device) - kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) - weight[indices, indices] = kernel - return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) +from .upsampling import ( # noqa + FirUpsample2D, + KUpsample2D, + Upsample1D, + Upsample2D, + upfirdn2d_native, + upsample_2d, +) class ResnetBlock2D(nn.Module): @@ -894,151 +355,6 @@ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return out + self.residual_conv(inputs) -def upsample_2d( - hidden_states: torch.FloatTensor, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, -) -> torch.FloatTensor: - r"""Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states (`torch.FloatTensor`): - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to nearest-neighbor upsampling. - factor (`int`, *optional*, default to `2`): - Integer upsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude (default: 1.0). - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -def downsample_2d( - hidden_states: torch.FloatTensor, - kernel: Optional[torch.FloatTensor] = None, - factor: int = 2, - gain: float = 1, -) -> torch.FloatTensor: - r"""Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states (`torch.FloatTensor`) - Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel (`torch.FloatTensor`, *optional*): - FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which - corresponds to average pooling. - factor (`int`, *optional*, default to `2`): - Integer downsampling factor. - gain (`float`, *optional*, default to `1.0`): - Scaling factor for signal magnitude. - - Returns: - output (`torch.FloatTensor`): - Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - return output - - -def upfirdn2d_native( - tensor: torch.Tensor, - kernel: torch.Tensor, - up: int = 1, - down: int = 1, - pad: Tuple[int, int] = (0, 0), -) -> torch.Tensor: - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(tensor.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) - - class TemporalConvLayer(nn.Module): """ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py new file mode 100644 index 000000000000..542a5d9d1eb0 --- /dev/null +++ b/src/diffusers/models/upsampling.py @@ -0,0 +1,426 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import USE_PEFT_BACKEND +from .lora import LoRACompatibleConv + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 1D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = conv_cls(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) + else: + if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: + hidden_states = self.Conv2d_0(hidden_states, scale) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class FirUpsample2D(nn.Module): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`, optional): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__( + self, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d( + self, + hidden_states: torch.FloatTensor, + weight: Optional[torch.FloatTensor] = None, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.FloatTensor: + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.FloatTensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*): Integer upsampling factor (default: 2). + gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convH, + (hidden_states.shape[3] - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + inverse_conv = F.conv_transpose2d( + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, + ) + + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + if self.use_conv: + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class KUpsample2D(nn.Module): + r"""A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) + + +def upfirdn2d_native( + tensor: torch.Tensor, + kernel: torch.Tensor, + up: int = 1, + down: int = 1, + pad: Tuple[int, int] = (0, 0), +) -> torch.Tensor: + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(tensor.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +def upsample_2d( + hidden_states: torch.FloatTensor, + kernel: Optional[torch.FloatTensor] = None, + factor: int = 2, + gain: float = 1, +) -> torch.FloatTensor: + r"""Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*, default to `2`): + Integer upsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output From 0532cece973310e75de56b5b26b15831d6b5cd0d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Dec 2023 09:03:17 +0530 Subject: [PATCH 12/42] Bump transformers from 4.34.0 to 4.36.0 in /examples/research_projects/realfill (#6255) Bump transformers in /examples/research_projects/realfill Bumps [transformers](https://github.com/huggingface/transformers) from 4.34.0 to 4.36.0. - [Release notes](https://github.com/huggingface/transformers/releases) - [Commits](https://github.com/huggingface/transformers/compare/v4.34.0...v4.36.0) --- updated-dependencies: - dependency-name: transformers dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- examples/research_projects/realfill/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt index 3827f0852a20..5d69d8456324 100644 --- a/examples/research_projects/realfill/requirements.txt +++ b/examples/research_projects/realfill/requirements.txt @@ -1,6 +1,6 @@ diffusers==0.20.1 accelerate==0.23.0 -transformers==4.34.0 +transformers==4.36.0 peft==0.5.0 torch==2.0.1 torchvision>=0.16 From 6ca9c4af05d7263ef26f277330487fc7c9af0878 Mon Sep 17 00:00:00 2001 From: lvzi <39146704+lvzii@users.noreply.github.com> Date: Thu, 21 Dec 2023 11:39:26 +0800 Subject: [PATCH 13/42] fix: unscale fp16 gradient problem & potential error (#6086) (#6231) Co-authored-by: Sayak Paul --- .../text_to_image/train_text_to_image_lora_sdxl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index a8a41b150523..4bcc441e04ec 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -640,6 +640,17 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1187,6 +1198,9 @@ def compute_time_ids(original_size, crops_coords_top_left): torch.cuda.empty_cache() # Final inference + # Make sure vae.dtype is consistent with the unet.dtype + if args.mixed_precision == "fp16": + vae.to(weight_dtype) # Load previous pipeline pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, From 6269045c5b6579df80e681cbebe875196f157d70 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 12:26:36 +0530 Subject: [PATCH 14/42] [Refactor] move diffedit out of stable_diffusion (#6260) * move diffedit out of stable_diffuson * fix: import * style * fix: import --- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/stable_diffusion/__init__.py | 5 -- .../stable_diffusion_diffedit/__init__.py | 48 +++++++++++++++++++ .../pipeline_stable_diffusion_diffedit.py | 4 +- 4 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e7d34b623711..e760355ff754 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -181,7 +181,6 @@ "CLIPImageProjection", "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", @@ -209,6 +208,7 @@ "StableDiffusionXLPipeline", ] ) + _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", @@ -422,7 +422,6 @@ CLIPImageProjection, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, StableDiffusionImageVariationPipeline, @@ -438,6 +437,7 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) + from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index dbd79ec1f367..085b46befff6 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -67,20 +67,17 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) _dummy_objects.update( { "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, } ) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] - _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] try: if not ( @@ -181,14 +178,12 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) else: from .pipeline_stable_diffusion_depth2img import ( StableDiffusionDepth2ImgPipeline, ) - from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline try: if not ( diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py new file mode 100644 index 000000000000..e2145edb96c6 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py rename to src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 81d936be62b4..d0d132555e69 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -40,8 +40,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From bcecfbc873f9041bb5f57495640bfa5e6dc74976 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 21 Dec 2023 12:35:09 +0530 Subject: [PATCH 15/42] move attend and excite out of stable_diffusion --- src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/stable_diffusion/__init__.py | 4 ---- .../__init__.py | 4 ++-- .../pipeline_stable_diffusion_attend_and_excite.py | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) rename src/diffusers/pipelines/{stable_diffusion_diffedit => stable_diffusion_attend_and_excite}/__init__.py (84%) rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e760355ff754..72e72f177268 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -179,7 +179,6 @@ _import_structure["stable_diffusion"].extend( [ "CLIPImageProjection", - "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENPipeline", @@ -209,6 +208,7 @@ ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", @@ -420,7 +420,6 @@ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CLIPImageProjection, - StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, @@ -437,6 +436,7 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) + from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_xl import ( diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 085b46befff6..40e3680010bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -32,7 +32,6 @@ _import_structure["clip_image_project_model"] = ["CLIPImageProjection"] _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] - _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] @@ -136,9 +135,6 @@ StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) - from .pipeline_stable_diffusion_attend_and_excite import ( - StableDiffusionAttendAndExcitePipeline, - ) from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline from .pipeline_stable_diffusion_gligen_text_image import ( StableDiffusionGLIGENTextImagePipeline, diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py similarity index 84% rename from src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py rename to src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py index e2145edb96c6..cce556fceb23 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline else: import sys diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py rename to src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py index 78023f544ecf..401c45c23bea 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py @@ -37,8 +37,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) From c5ff469d0ea7161c6166d4bad9741b60725baf3f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 21 Dec 2023 12:35:58 +0530 Subject: [PATCH 16/42] Revert "move attend and excite out of stable_diffusion" This reverts commit bcecfbc873f9041bb5f57495640bfa5e6dc74976. --- src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/stable_diffusion/__init__.py | 4 ++++ .../pipeline_stable_diffusion_attend_and_excite.py | 4 ++-- .../__init__.py | 4 ++-- 4 files changed, 10 insertions(+), 6 deletions(-) rename src/diffusers/pipelines/{stable_diffusion_attend_and_excite => stable_diffusion}/pipeline_stable_diffusion_attend_and_excite.py (99%) rename src/diffusers/pipelines/{stable_diffusion_attend_and_excite => stable_diffusion_diffedit}/__init__.py (84%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 72e72f177268..e760355ff754 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -179,6 +179,7 @@ _import_structure["stable_diffusion"].extend( [ "CLIPImageProjection", + "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENPipeline", @@ -208,7 +209,6 @@ ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] - _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", @@ -420,6 +420,7 @@ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CLIPImageProjection, + StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, @@ -436,7 +437,6 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) - from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_xl import ( diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 40e3680010bb..085b46befff6 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -32,6 +32,7 @@ _import_structure["clip_image_project_model"] = ["CLIPImageProjection"] _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] @@ -135,6 +136,9 @@ StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) + from .pipeline_stable_diffusion_attend_and_excite import ( + StableDiffusionAttendAndExcitePipeline, + ) from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline from .pipeline_stable_diffusion_gligen_text_image import ( StableDiffusionGLIGENTextImagePipeline, diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py rename to src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 401c45c23bea..78023f544ecf 100644 --- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -37,8 +37,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py similarity index 84% rename from src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py rename to src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py index cce556fceb23..e2145edb96c6 100644 --- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline else: import sys From 35a969d297cba69110d175ee79c59312b9f49e1e Mon Sep 17 00:00:00 2001 From: YShow <66633207+Yimi81@users.noreply.github.com> Date: Thu, 21 Dec 2023 16:47:52 +0800 Subject: [PATCH 17/42] [Training] remove depcreated method from lora scripts again (#6266) * remove depcreated method from lora scripts * check code quality --- .../train_text_to_image_lora_sdxl.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 4bcc441e04ec..be17c13c2885 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -22,7 +22,6 @@ import random import shutil from pathlib import Path -from typing import Dict import datasets import numpy as np @@ -436,22 +435,6 @@ def parse_args(input_args=None): } -def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: - """ - Returns: - a state dict containing just the attention processor parameters. - """ - attn_processors = unet.attn_processors - - attn_processors_state_dict = {} - - for attn_processor_key, attn_processor in attn_processors.items(): - for parameter_key, parameter in attn_processor.state_dict().items(): - attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter - - return attn_processors_state_dict - - def tokenize_prompt(tokenizer, prompt): text_inputs = tokenizer( prompt, From bffadde1265976d3bc5e7e93178119c65de4de28 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 15:24:24 +0530 Subject: [PATCH 18/42] [Refactor] move k diffusion out of stable_diffusion (#6267) move k diffusion out of stable_diffusion --- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/stable_diffusion/__init__.py | 30 ---------- .../stable_diffusion_k_diffusion/__init__.py | 60 +++++++++++++++++++ .../pipeline_stable_diffusion_k_diffusion.py | 2 +- 4 files changed, 63 insertions(+), 33 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e760355ff754..115de2a8fc7a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -268,7 +268,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) else: - _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) + _import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"] try: if not is_flax_available(): raise OptionalDependencyNotAvailable() @@ -498,7 +498,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * else: - from .stable_diffusion import StableDiffusionKDiffusionPipeline + from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline try: if not is_flax_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 085b46befff6..c8dac8046548 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -79,22 +79,7 @@ else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] -try: - if not ( - is_torch_available() - and is_transformers_available() - and is_k_diffusion_available() - and is_k_diffusion_version(">=", "0.0.12") - ): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import ( - dummy_torch_and_transformers_and_k_diffusion_objects, - ) - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) -else: - _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"] try: if not (is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() @@ -185,21 +170,6 @@ StableDiffusionDepth2ImgPipeline, ) - try: - if not ( - is_torch_available() - and is_transformers_available() - and is_k_diffusion_available() - and is_k_diffusion_version(">=", "0.0.12") - ): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * - else: - from .pipeline_stable_diffusion_k_diffusion import ( - StableDiffusionKDiffusionPipeline, - ) - try: if not (is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py new file mode 100644 index 000000000000..6c4bd0047f02 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_k_diffusion_available, + is_k_diffusion_version, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not ( + is_transformers_available() + and is_torch_available() + and is_k_diffusion_available() + and is_k_diffusion_version(">=", "0.0.12") + ): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not ( + is_transformers_available() + and is_torch_available() + and is_k_diffusion_available() + and is_k_diffusion_version(">=", "0.0.12") + ): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * + else: + from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py rename to src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py index 5c472fad98ef..53e5a34a3b33 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -27,7 +27,7 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput +from ..stable_diffusion import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 2c34c7d6ddba1776e9131a052893dfeb2c48be82 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 15:26:52 +0530 Subject: [PATCH 19/42] [Refactor] move gligen out of stable diffusion. (#6265) * move gligen out of stable diffusion. * fix: import * fix import module --- src/diffusers/pipelines/__init__.py | 10 ++-- .../pipelines/stable_diffusion/__init__.py | 6 --- .../stable_diffusion_gligen/__init__.py | 50 +++++++++++++++++++ .../pipeline_stable_diffusion_gligen.py | 4 +- ...line_stable_diffusion_gligen_text_image.py | 6 +-- 5 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_gligen/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py (99%) rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 115de2a8fc7a..b10102e0909f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -181,9 +181,6 @@ "CLIPImageProjection", "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", "StableDiffusionImageVariationPipeline", "StableDiffusionImg2ImgPipeline", "StableDiffusionInpaintPipeline", @@ -199,6 +196,10 @@ ] ) _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_gligen"] = [ + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + ] _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"] _import_structure["stable_diffusion_xl"].extend( [ @@ -422,8 +423,6 @@ CLIPImageProjection, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, - StableDiffusionGLIGENPipeline, - StableDiffusionGLIGENTextImagePipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, @@ -438,6 +437,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index c8dac8046548..c161df0d5bc3 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -33,8 +33,6 @@ _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] - _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] - _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"] _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] @@ -124,10 +122,6 @@ from .pipeline_stable_diffusion_attend_and_excite import ( StableDiffusionAttendAndExcitePipeline, ) - from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline - from .pipeline_stable_diffusion_gligen_text_image import ( - StableDiffusionGLIGENTextImagePipeline, - ) from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_instruct_pix2pix import ( diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py b/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py new file mode 100644 index 000000000000..147980cbf9e5 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] + _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline + from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py rename to src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py index b85f40a54579..91d7357fd352 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -36,8 +36,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py rename to src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py index 405097248e2a..2c172ce46e45 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -35,9 +35,9 @@ from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .clip_image_project_model import CLIPImageProjection -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.clip_image_project_model import CLIPImageProjection +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 9ea6ac1b07bc351943a6a0c0011e8ccb4ebb9b96 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 16:09:49 +0530 Subject: [PATCH 20/42] [Refactor] move sag out of `stable_diffusion` (#6264) move sag out of . --- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/stable_diffusion/__init__.py | 2 - .../stable_diffusion_sag/__init__.py | 48 +++++++++++++++++++ .../pipeline_stable_diffusion_sag.py | 4 +- 4 files changed, 52 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_sag/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_sag}/pipeline_stable_diffusion_sag.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b10102e0909f..6744054adcaa 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -189,13 +189,13 @@ "StableDiffusionLDM3DPipeline", "StableDiffusionPanoramaPipeline", "StableDiffusionPipeline", - "StableDiffusionSAGPipeline", "StableDiffusionUpscalePipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", ] ) _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] _import_structure["stable_diffusion_gligen"] = [ "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", @@ -431,7 +431,6 @@ StableDiffusionLDM3DPipeline, StableDiffusionPanoramaPipeline, StableDiffusionPipeline, - StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, @@ -439,6 +438,7 @@ from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index c161df0d5bc3..ab920834306a 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -42,7 +42,6 @@ _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] - _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"] _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"] @@ -132,7 +131,6 @@ ) from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline - from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_sag/__init__.py b/src/diffusers/pipelines/stable_diffusion_sag/__init__.py new file mode 100644 index 000000000000..378e0e57817f --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_sag/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py rename to src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py index 792a3c40b33d..36a0a956c15b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -34,8 +34,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 43979c2890134fb66e913f7fd5d65b1213ce8c72 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 21 Dec 2023 11:50:05 +0100 Subject: [PATCH 21/42] TST Fix LoRA test that fails with PEFT >= 0.7.0 (#6216) See #6185 for context. Co-authored-by: Sayak Paul --- tests/lora/test_lora_layers_peft.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 6d3ac8b4592a..1d8c6977440c 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import importlib import os import tempfile import time @@ -24,6 +25,7 @@ import torch.nn.functional as F from huggingface_hub import hf_hub_download from huggingface_hub.repocard import RepoCard +from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( @@ -1983,10 +1985,26 @@ def test_sdxl_1_0_fuse_unfuse_all(self): fused_te_2_state_dict = pipe.text_encoder_2.state_dict() unet_state_dict = pipe.unet.state_dict() + peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0") + + def remap_key(key, sd): + # some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly + if (key in sd) or (not peft_ge_070): + return key + + # instead of linear.weight, we now have linear.base_layer.weight, etc. + if key.endswith(".weight"): + key = key[:-7] + ".base_layer.weight" + elif key.endswith(".bias"): + key = key[:-5] + ".base_layer.bias" + return key + for key, value in text_encoder_1_sd.items(): + key = remap_key(key, fused_te_state_dict) self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) for key, value in text_encoder_2_sd.items(): + key = remap_key(key, fused_te_2_state_dict) self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) for key, value in unet_state_dict.items(): From 325f6c53edf10a7b3f4804d4b38e89f95873d3c2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 16:49:32 +0530 Subject: [PATCH 22/42] [Refactor] move attend and excite out of `stable_diffusion`. (#6261) * move attend and excite out. * fix: import * fix diffedit --- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/stable_diffusion/__init__.py | 5 +- .../__init__.py | 48 +++++++++++++++++++ ...line_stable_diffusion_attend_and_excite.py | 4 +- 4 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6744054adcaa..8256a01045ee 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -179,7 +179,6 @@ _import_structure["stable_diffusion"].extend( [ "CLIPImageProjection", - "StableDiffusionAttendAndExcitePipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionImageVariationPipeline", "StableDiffusionImg2ImgPipeline", @@ -194,6 +193,7 @@ "StableUnCLIPPipeline", ] ) + _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] _import_structure["stable_diffusion_gligen"] = [ @@ -421,7 +421,6 @@ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CLIPImageProjection, - StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, @@ -435,6 +434,7 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) + from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index ab920834306a..7f72e74307f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -33,6 +33,8 @@ _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] + _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"] _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] @@ -118,9 +120,6 @@ StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) - from .pipeline_stable_diffusion_attend_and_excite import ( - StableDiffusionAttendAndExcitePipeline, - ) from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_instruct_pix2pix import ( diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py new file mode 100644 index 000000000000..cce556fceb23 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py rename to src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py index 78023f544ecf..401c45c23bea 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py @@ -37,8 +37,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) From 9c7cc360114a0602ba14c54a2f439bff32097653 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 18:17:05 +0530 Subject: [PATCH 23/42] [Refactor] move panorama out of `stable_diffusion` (#6262) * move panorama out. * fix: diffedit * fix: import. * fix: impirt --- src/diffusers/pipelines/__init__.py | 6 +-- .../pipelines/stable_diffusion/__init__.py | 2 - .../stable_diffusion_panorama/__init__.py | 48 +++++++++++++++++++ .../pipeline_stable_diffusion_panorama.py | 4 +- 4 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_panorama/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py (99%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8256a01045ee..d05d13ef45d4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -185,12 +185,11 @@ "StableDiffusionInpaintPipeline", "StableDiffusionInstructPix2PixPipeline", "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionPanoramaPipeline", "StableDiffusionPipeline", "StableDiffusionUpscalePipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", + "StableDiffusionLDM3DPipeline", ] ) _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] @@ -210,6 +209,7 @@ ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + _import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", @@ -428,7 +428,6 @@ StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, - StableDiffusionPanoramaPipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, @@ -437,6 +436,7 @@ from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline + from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_xl import ( diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 7f72e74307f8..28fb96c3309b 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -42,7 +42,6 @@ _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"] _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] - _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"] @@ -129,7 +128,6 @@ StableDiffusionLatentUpscalePipeline, ) from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline - from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py b/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py new file mode 100644 index 000000000000..f7572db7236c --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py rename to src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index bcc063499459..f0ef4b9f88f3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -33,8 +33,8 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From ab0459f2b7685d54b8b4ea1578eeda6ddece0913 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 18:17:28 +0530 Subject: [PATCH 24/42] [Deprecated pipelines] remove pix2pix zero from init (#6268) remove pix2pix zero from init --- src/diffusers/pipelines/stable_diffusion/__init__.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 28fb96c3309b..6cb2744e6653 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -65,18 +65,15 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, - StableDiffusionPix2PixZeroPipeline, ) _dummy_objects.update( { "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, } ) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] - _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] try: if not (is_transformers_available() and is_onnx_available()): @@ -150,10 +147,7 @@ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import ( - StableDiffusionDepth2ImgPipeline, - StableDiffusionPix2PixZeroPipeline, - ) + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline else: from .pipeline_stable_diffusion_depth2img import ( StableDiffusionDepth2ImgPipeline, From 5b186b712837104b40d095b26ed6a2ec61246cb4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 21 Dec 2023 18:59:55 +0530 Subject: [PATCH 25/42] [Refactor] move ldm3d out of stable_diffusion. (#6263) ldm3d. --- .../stable_diffusion/ldm3d_diffusion.md | 4 +- src/diffusers/pipelines/__init__.py | 3 +- .../pipelines/stable_diffusion/__init__.py | 2 - .../stable_diffusion_ldm3d/__init__.py | 48 +++++++++++++++++++ .../pipeline_stable_diffusion_ldm3d.py | 2 +- 5 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py rename src/diffusers/pipelines/{stable_diffusion => stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py (99%) diff --git a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md index b2dc7b735717..45900b3f11f2 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md +++ b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md @@ -31,14 +31,14 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea ## StableDiffusionLDM3DPipeline -[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline +[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline - all - __call__ ## LDM3DPipelineOutput -[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput +[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput - all - __call__ diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d05d13ef45d4..92839e596978 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -209,6 +209,7 @@ ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + _import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] _import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] _import_structure["t2i_adapter"] = [ "StableDiffusionAdapterPipeline", @@ -427,7 +428,6 @@ StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, - StableDiffusionLDM3DPipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, @@ -436,6 +436,7 @@ from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline + from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_sag import StableDiffusionSAGPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 6cb2744e6653..0eda32d333b9 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -40,7 +40,6 @@ _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"] _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"] - _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] @@ -124,7 +123,6 @@ from .pipeline_stable_diffusion_latent_upscale import ( StableDiffusionLatentUpscalePipeline, ) - from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py new file mode 100644 index 000000000000..dae2affddd1f --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py similarity index 99% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py rename to src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index ee9335a2bb01..699bd10041d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -37,7 +37,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 4039815276215ea0ff8ce7ac6670dc9dbe08f817 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 21 Dec 2023 11:40:55 -0800 Subject: [PATCH 26/42] open muse (#5437) amused rename Update docs/source/en/api/pipelines/amused.md Co-authored-by: Patrick von Platen AdaLayerNormContinuous default values custom micro conditioning micro conditioning docs put lookup from codebook in constructor fix conversion script remove manual fused flash attn kernel add training script temp remove training script add dummy gradient checkpointing func clarify temperatures is an instance variable by setting it remove additional SkipFF block args hardcode norm args rename tests folder fix paths and samples fix tests add training script training readme lora saving and loading non-lora saving/loading some readme fixes guards Update docs/source/en/api/pipelines/amused.md Co-authored-by: Suraj Patil Update examples/amused/README.md Co-authored-by: Suraj Patil Update examples/amused/train_amused.py Co-authored-by: Suraj Patil vae upcasting add fp16 integration tests use tuple for micro cond copyrights remove casts delegate to torch.nn.LayerNorm move temperature to pipeline call upsampling/downsampling changes --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/amused.md | 30 + examples/amused/README.md | 326 ++++++ examples/amused/train_amused.py | 972 ++++++++++++++++++ scripts/convert_amused.py | 523 ++++++++++ src/diffusers/__init__.py | 10 + src/diffusers/loaders/lora.py | 95 +- src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention.py | 139 ++- src/diffusers/models/autoencoders/vae.py | 4 + src/diffusers/models/downsampling.py | 22 +- src/diffusers/models/embeddings.py | 5 +- src/diffusers/models/normalization.py | 106 ++ src/diffusers/models/upsampling.py | 40 +- src/diffusers/models/uvit_2d.py | 471 +++++++++ src/diffusers/models/vq_model.py | 9 +- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/amused/__init__.py | 62 ++ .../pipelines/amused/pipeline_amused.py | 328 ++++++ .../amused/pipeline_amused_img2img.py | 347 +++++++ .../amused/pipeline_amused_inpaint.py | 378 +++++++ src/diffusers/schedulers/__init__.py | 2 + src/diffusers/schedulers/scheduling_amused.py | 162 +++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 45 + tests/pipelines/amused/__init__.py | 0 tests/pipelines/amused/test_amused.py | 181 ++++ tests/pipelines/amused/test_amused_img2img.py | 239 +++++ tests/pipelines/amused/test_amused_inpaint.py | 277 +++++ tests/pipelines/test_pipelines_common.py | 4 +- 30 files changed, 4789 insertions(+), 24 deletions(-) create mode 100644 docs/source/en/api/pipelines/amused.md create mode 100644 examples/amused/README.md create mode 100644 examples/amused/train_amused.py create mode 100644 scripts/convert_amused.py create mode 100644 src/diffusers/models/uvit_2d.py create mode 100644 src/diffusers/pipelines/amused/__init__.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused_img2img.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused_inpaint.py create mode 100644 src/diffusers/schedulers/scheduling_amused.py create mode 100644 tests/pipelines/amused/__init__.py create mode 100644 tests/pipelines/amused/test_amused.py create mode 100644 tests/pipelines/amused/test_amused_img2img.py create mode 100644 tests/pipelines/amused/test_amused_inpaint.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 62588bf4abb8..3e9e83e6512e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -244,6 +244,8 @@ - sections: - local: api/pipelines/overview title: Overview + - local: api/pipelines/amused + title: aMUSEd - local: api/pipelines/animatediff title: AnimateDiff - local: api/pipelines/attend_and_excite diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md new file mode 100644 index 000000000000..cb8693802173 --- /dev/null +++ b/docs/source/en/api/pipelines/amused.md @@ -0,0 +1,30 @@ + + +# aMUSEd + +Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once. + +Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes. + +| Model | Params | +|-------|--------| +| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M | +| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M | + +## AmusedPipeline + +[[autodoc]] AmusedPipeline + - __call__ + - all + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention \ No newline at end of file diff --git a/examples/amused/README.md b/examples/amused/README.md new file mode 100644 index 000000000000..517c2d382f8e --- /dev/null +++ b/examples/amused/README.md @@ -0,0 +1,326 @@ +## Amused training + +Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipies are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates. + +All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size). + +### Finetuning the 256 checkpoint + +These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset. + +Example results: + +![noun1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun1.png) ![noun2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun2.png) ![noun3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun3.png) + + +#### Full finetuning + +Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 19.7 GB | +| 4 | 2 | 8 | 18.3 GB | +| 1 | 8 | 8 | 17.9 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 1e-4 \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + 8 bit adam + +Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. If you want to train for longer, you will have to up the batch size and lower the learning rate. + +Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 16 | 1 | 16 | 20.1 GB | +| 8 | 2 | 16 | 15.6 GB | +| 1 | 16 | 16 | 10.7 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 2e-5 \ + --use_8bit_adam \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + lora + +Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 16 | 1 | 16 | 14.1 GB | +| 8 | 2 | 16 | 10.1 GB | +| 1 | 16 | 16 | 6.5 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 8e-4 \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +### Finetuning the 512 checkpoint + +These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset. + +Example results: + +![minecraft1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft1.png) ![minecraft2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft2.png) ![minecraft3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft3.png) + +#### Full finetuning + +Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 24.2 GB | +| 4 | 2 | 8 | 19.7 GB | +| 1 | 8 | 8 | 16.99 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 8e-5 \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + 8 bit adam + +Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 21.2 GB | +| 4 | 2 | 8 | 13.3 GB | +| 1 | 8 | 8 | 9.9 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 5e-6 \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + lora + +Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 12.7 GB | +| 4 | 2 | 8 | 9.0 GB | +| 1 | 8 | 8 | 5.6 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 1e-4 \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +### Styledrop + +[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage to generate human picked additional training samples. The additional training samples can be used to augment the initial images. Our examples exclude the optional additional image selection stage and instead we just finetune on a single image. + +This is our example style image: +![example](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png) + +Download it to your local directory with +```sh +wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png +``` + +#### 256 + +Example results: + +![glowing_256_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_1.png) ![glowing_256_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_2.png) ![glowing_256_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_3.png) + +Learning rate: 4e-4, Gives decent results in 1500-2000 steps + +Memory used: 6.5 GB + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --mixed_precision fp16 \ + --report_to wandb \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --train_batch_size 1 \ + --lr_scheduler constant \ + --learning_rate 4e-4 \ + --validation_prompts \ + 'A chihuahua walking on the street in [V] style' \ + 'A banana on the table in [V] style' \ + 'A church on the street in [V] style' \ + 'A tabby cat walking in the forest in [V] style' \ + --instance_data_image 'A mushroom in [V] style.png' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 100 \ + --resolution 256 +``` + +#### 512 + +Example results: + +![glowing_512_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_1.png) ![glowing_512_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_2.png) ![glowing_512_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_3.png) + +Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps + +Memory used: 5.6 GB + +``` +accelerate launch train_amused.py \ + --output_dir \ + --mixed_precision fp16 \ + --report_to wandb \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --train_batch_size 1 \ + --lr_scheduler constant \ + --learning_rate 1e-3 \ + --validation_prompts \ + 'A chihuahua walking on the street in [V] style' \ + 'A banana on the table in [V] style' \ + 'A church on the street in [V] style' \ + 'A tabby cat walking in the forest in [V] style' \ + --instance_data_image 'A mushroom in [V] style.png' \ + --max_train_steps 100000 \ + --checkpointing_steps 500 \ + --validation_steps 100 \ + --resolution 512 \ + --lora_alpha 1 +``` \ No newline at end of file diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py new file mode 100644 index 000000000000..7ae7088d66d8 --- /dev/null +++ b/examples/amused/train_amused.py @@ -0,0 +1,972 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import logging +import math +import os +import shutil +from contextlib import nullcontext +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import DataLoader, Dataset, default_collate +from torchvision import transforms +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +import diffusers.optimization +from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel +from diffusers.loaders import LoraLoaderMixin +from diffusers.utils import is_wandb_available + + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--instance_data_dataset", + type=str, + default=None, + required=False, + help="A Hugging Face dataset containing the training images", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_image", type=str, default=None, required=False, help="A single training image" + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--ema_decay", type=float, default=0.9999) + parser.add_argument("--ema_update_after_step", type=int, default=0) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument( + "--output_dir", + type=str, + default="muse_training", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--logging_steps", + type=int, + default=50, + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more details" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=0.0003, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--validation_prompts", type=str, nargs="*") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument("--split_vae_encode", type=int, required=False, default=None) + parser.add_argument("--min_masking_rate", type=float, default=0.0) + parser.add_argument("--cond_dropout_prob", type=float, default=0.0) + parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False) + parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa") + parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa") + parser.add_argument("--lora_r", default=16, type=int) + parser.add_argument("--lora_alpha", default=32, type=int) + parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+") + parser.add_argument("--text_encoder_lora_r", default=16, type=int) + parser.add_argument("--text_encoder_lora_alpha", default=32, type=int) + parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+") + parser.add_argument("--train_text_encoder", action="store_true") + parser.add_argument("--image_key", type=str, required=False) + parser.add_argument("--prompt_key", type=str, required=False) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument("--prompt_prefix", type=str, required=False, default=None) + + args = parser.parse_args() + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + num_datasources = sum( + [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]] + ) + + if num_datasources != 1: + raise ValueError( + "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`" + ) + + if args.instance_data_dir is not None: + if not os.path.exists(args.instance_data_dir): + raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}") + + if args.instance_data_image is not None: + if not os.path.exists(args.instance_data_image): + raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}") + + if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None): + raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`") + + return args + + +class InstanceDataRootDataset(Dataset): + def __init__( + self, + instance_data_root, + tokenizer, + size=512, + ): + self.size = size + self.tokenizer = tokenizer + self.instance_images_path = list(Path(instance_data_root).iterdir()) + + def __len__(self): + return len(self.instance_images_path) + + def __getitem__(self, index): + image_path = self.instance_images_path[index % len(self.instance_images_path)] + instance_image = Image.open(image_path) + rv = process_image(instance_image, self.size) + + prompt = os.path.splitext(os.path.basename(image_path))[0] + rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0] + return rv + + +class InstanceDataImageDataset(Dataset): + def __init__( + self, + instance_data_image, + train_batch_size, + size=512, + ): + self.value = process_image(Image.open(instance_data_image), size) + self.train_batch_size = train_batch_size + + def __len__(self): + # Needed so a full batch of the data can be returned. Otherwise will return + # batches of size 1 + return self.train_batch_size + + def __getitem__(self, index): + return self.value + + +class HuggingFaceDataset(Dataset): + def __init__( + self, + hf_dataset, + tokenizer, + image_key, + prompt_key, + prompt_prefix=None, + size=512, + ): + self.size = size + self.image_key = image_key + self.prompt_key = prompt_key + self.tokenizer = tokenizer + self.hf_dataset = hf_dataset + self.prompt_prefix = prompt_prefix + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, index): + item = self.hf_dataset[index] + + rv = process_image(item[self.image_key], self.size) + + prompt = item[self.prompt_key] + + if self.prompt_prefix is not None: + prompt = self.prompt_prefix + prompt + + rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0] + + return rv + + +def process_image(image, size): + image = exif_transpose(image) + + if not image.mode == "RGB": + image = image.convert("RGB") + + orig_height = image.height + orig_width = image.width + + image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image) + + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size)) + image = transforms.functional.crop(image, c_top, c_left, size, size) + + image = transforms.ToTensor()(image) + + micro_conds = torch.tensor( + [orig_width, orig_height, c_top, c_left, 6.0], + ) + + return {"image": image, "micro_conds": micro_conds} + + +def tokenize_prompt(tokenizer, prompt): + return tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=77, + return_tensors="pt", + ).input_ids + + +def encode_prompt(text_encoder, input_ids): + outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True) + encoder_hidden_states = outputs.hidden_states[-2] + cond_embeds = outputs[0] + return encoder_hidden_states, cond_embeds + + +def main(args): + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_main_process: + accelerator.init_trackers("amused", config=vars(copy.deepcopy(args))) + + if args.seed is not None: + set_seed(args.seed) + + # TODO - will have to fix loading if training text encoder + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant + ) + vq_model = VQModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant + ) + + if args.train_text_encoder: + if args.text_encoder_use_lora: + lora_config = LoraConfig( + r=args.text_encoder_lora_r, + lora_alpha=args.text_encoder_lora_alpha, + target_modules=args.text_encoder_lora_target_modules, + ) + text_encoder.add_adapter(lora_config) + text_encoder.train() + text_encoder.requires_grad_(True) + else: + text_encoder.eval() + text_encoder.requires_grad_(False) + + vq_model.requires_grad_(False) + + model = UVit2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + + if args.use_lora: + lora_config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules=args.lora_target_modules, + ) + model.add_adapter(lora_config) + + model.train() + + if args.gradient_checkpointing: + model.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.use_ema: + ema = EMAModel( + model.parameters(), + decay=args.ema_decay, + update_after_step=args.ema_update_after_step, + model_cls=UVit2DModel, + model_config=model.config, + ) + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model_ in models: + if isinstance(model_, type(accelerator.unwrap_model(model))): + if args.use_lora: + transformer_lora_layers_to_save = get_peft_model_state_dict(model_) + else: + model_.save_pretrained(os.path.join(output_dir, "transformer")) + elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))): + if args.text_encoder_use_lora: + text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_) + else: + model_.save_pretrained(os.path.join(output_dir, "text_encoder")) + else: + raise ValueError(f"unexpected save model: {model_.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None: + LoraLoaderMixin.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + if args.use_ema: + ema.save_pretrained(os.path.join(output_dir, "ema_model")) + + def load_model_hook(models, input_dir): + transformer = None + text_encoder_ = None + + while len(models) > 0: + model_ = models.pop() + + if isinstance(model_, type(accelerator.unwrap_model(model))): + if args.use_lora: + transformer = model_ + else: + load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer")) + model_.load_state_dict(load_model.state_dict()) + del load_model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + if args.text_encoder_use_lora: + text_encoder_ = model_ + else: + load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder")) + model_.load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer is not None or text_encoder_ is not None: + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ + ) + LoraLoaderMixin.load_lora_into_transformer( + lora_state_dict, network_alphas=network_alphas, transformer=transformer + ) + + if args.use_ema: + load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel) + ema.load_state_dict(load_from.state_dict()) + del load_from + + accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.adam_weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + if args.train_text_encoder: + optimizer_grouped_parameters.append( + {"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay} + ) + + optimizer = optimizer_cls( + optimizer_grouped_parameters, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + if args.instance_data_dir is not None: + dataset = InstanceDataRootDataset( + instance_data_root=args.instance_data_dir, + tokenizer=tokenizer, + size=args.resolution, + ) + elif args.instance_data_image is not None: + dataset = InstanceDataImageDataset( + instance_data_image=args.instance_data_image, + train_batch_size=args.train_batch_size, + size=args.resolution, + ) + elif args.instance_data_dataset is not None: + dataset = HuggingFaceDataset( + hf_dataset=load_dataset(args.instance_data_dataset, split="train"), + tokenizer=tokenizer, + image_key=args.image_key, + prompt_key=args.prompt_key, + prompt_prefix=args.prompt_prefix, + size=args.resolution, + ) + else: + assert False + + train_dataloader = DataLoader( + dataset, + batch_size=args.train_batch_size, + shuffle=True, + num_workers=args.dataloader_num_workers, + collate_fn=default_collate, + ) + train_dataloader.num_batches = len(train_dataloader) + + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + + logger.info("Preparing model, optimizer and dataloaders") + + if args.train_text_encoder: + model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare( + model, optimizer, lr_scheduler, train_dataloader, text_encoder + ) + else: + model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare( + model, optimizer, lr_scheduler, train_dataloader + ) + + train_dataloader.num_batches = len(train_dataloader) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if not args.train_text_encoder: + text_encoder.to(device=accelerator.device, dtype=weight_dtype) + + vq_model.to(device=accelerator.device) + + if args.use_ema: + ema.to(accelerator.device) + + with nullcontext() if args.train_text_encoder else torch.no_grad(): + empty_embeds, empty_clip_embeds = encode_prompt( + text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True) + ) + + # There is a single image, we can just pre-encode the single prompt + if args.instance_data_image is not None: + prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0] + encoder_hidden_states, cond_embeds = encode_prompt( + text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True) + ) + encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1) + cond_embeds = cond_embeds.repeat(args.train_batch_size, 1) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs. + # Note: We are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num training steps = {args.max_train_steps}") + logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + + resume_from_checkpoint = args.resume_from_checkpoint + if resume_from_checkpoint: + if resume_from_checkpoint == "latest": + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + if len(dirs) > 0: + resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1]) + else: + resume_from_checkpoint = None + + if resume_from_checkpoint is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + else: + accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}") + + if resume_from_checkpoint is None: + global_step = 0 + first_epoch = 0 + else: + accelerator.load_state(resume_from_checkpoint) + global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + + # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + for epoch in range(first_epoch, num_train_epochs): + for batch in train_dataloader: + with torch.no_grad(): + micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True) + pixel_values = batch["image"].to(accelerator.device, non_blocking=True) + + batch_size = pixel_values.shape[0] + + split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size + num_splits = math.ceil(batch_size / split_batch_size) + image_tokens = [] + for i in range(num_splits): + start_idx = i * split_batch_size + end_idx = min((i + 1) * split_batch_size, batch_size) + bs = pixel_values.shape[0] + image_tokens.append( + vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape( + bs, -1 + ) + ) + image_tokens = torch.cat(image_tokens, dim=0) + + batch_size, seq_len = image_tokens.shape + + timesteps = torch.rand(batch_size, device=image_tokens.device) + mask_prob = torch.cos(timesteps * math.pi * 0.5) + mask_prob = mask_prob.clip(args.min_masking_rate) + + num_token_masked = (seq_len * mask_prob).round().clamp(min=1) + batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1) + mask = batch_randperm < num_token_masked.unsqueeze(-1) + + mask_id = accelerator.unwrap_model(model).config.vocab_size - 1 + input_ids = torch.where(mask, mask_id, image_tokens) + labels = torch.where(mask, image_tokens, -100) + + if args.cond_dropout_prob > 0.0: + assert encoder_hidden_states is not None + + batch_size = encoder_hidden_states.shape[0] + + mask = ( + torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1) + < args.cond_dropout_prob + ) + + empty_embeds_ = empty_embeds.expand(batch_size, -1, -1) + encoder_hidden_states = torch.where( + (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_ + ) + + empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1) + cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_) + + bs = input_ids.shape[0] + vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1) + resolution = args.resolution // vae_scale_factor + input_ids = input_ids.reshape(bs, resolution, resolution) + + if "prompt_input_ids" in batch: + with nullcontext() if args.train_text_encoder else torch.no_grad(): + encoder_hidden_states, cond_embeds = encode_prompt( + text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True) + ) + + # Train Step + with accelerator.accumulate(model): + codebook_size = accelerator.unwrap_model(model).config.codebook_size + + logits = ( + model( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + micro_conds=micro_conds, + pooled_text_emb=cond_embeds, + ) + .reshape(bs, codebook_size, -1) + .permute(0, 2, 1) + .reshape(-1, codebook_size) + ) + + loss = F.cross_entropy( + logits, + labels.view(-1), + ignore_index=-100, + reduction="mean", + ) + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean() + + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema.step(model.parameters()) + + if (global_step + 1) % args.logging_steps == 0: + logs = { + "step_loss": avg_loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss: {avg_loss.item():0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + if (global_step + 1) % args.checkpointing_steps == 0: + save_checkpoint(args, accelerator, global_step + 1) + + if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process: + if args.use_ema: + ema.store(model.parameters()) + ema.copy_to(model.parameters()) + + with torch.no_grad(): + logger.info("Generating images...") + + model.eval() + + if args.train_text_encoder: + text_encoder.eval() + + scheduler = AmusedScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + variant=args.variant, + ) + + pipe = AmusedPipeline( + transformer=accelerator.unwrap_model(model), + tokenizer=tokenizer, + text_encoder=text_encoder, + vqvae=vq_model, + scheduler=scheduler, + ) + + pil_images = pipe(prompt=args.validation_prompts).images + wandb_images = [ + wandb.Image(image, caption=args.validation_prompts[i]) + for i, image in enumerate(pil_images) + ] + + wandb.log({"generated_images": wandb_images}, step=global_step + 1) + + model.train() + + if args.train_text_encoder: + text_encoder.train() + + if args.use_ema: + ema.restore(model.parameters()) + + global_step += 1 + + # Stop training if max steps is reached + if global_step >= args.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(args, accelerator, global_step) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + if args.use_ema: + ema.copy_to(model.parameters()) + model.save_pretrained(args.output_dir) + + accelerator.end_training() + + +def save_checkpoint(args, accelerator, global_step): + output_dir = args.output_dir + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and args.checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + +if __name__ == "__main__": + main(parse_args()) diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py new file mode 100644 index 000000000000..fdddbef7cd65 --- /dev/null +++ b/scripts/convert_amused.py @@ -0,0 +1,523 @@ +import inspect +import os +from argparse import ArgumentParser + +import numpy as np +import torch +from muse import MaskGiTUViT, VQGANModel +from muse import PipelineMuse as OldPipelineMuse +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import VQModel +from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.uvit_2d import UVit2DModel +from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline +from diffusers.schedulers import AmusedScheduler + + +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) + +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(True) + +# Enable CUDNN deterministic mode +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +torch.backends.cuda.matmul.allow_tf32 = False + +device = "cuda" + + +def main(): + args = ArgumentParser() + args.add_argument("--model_256", action="store_true") + args.add_argument("--write_to", type=str, required=False, default=None) + args.add_argument("--transformer_path", type=str, required=False, default=None) + args = args.parse_args() + + transformer_path = args.transformer_path + subfolder = "transformer" + + if transformer_path is None: + if args.model_256: + transformer_path = "openMUSE/muse-256" + else: + transformer_path = ( + "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/" + ) + subfolder = None + + old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder) + + old_transformer.to(device) + + old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae") + old_vae.to(device) + + vqvae = make_vqvae(old_vae) + + tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + + text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + text_encoder.to(device) + + transformer = make_transformer(old_transformer, args.model_256) + + scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id) + + new_pipe = AmusedPipeline( + vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler + ) + + old_pipe = OldPipelineMuse( + vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer + ) + old_pipe.to(device) + + if args.model_256: + transformer_seq_len = 256 + orig_size = (256, 256) + else: + transformer_seq_len = 1024 + orig_size = (512, 512) + + old_out = old_pipe( + "dog", + generator=torch.Generator(device).manual_seed(0), + transformer_seq_len=transformer_seq_len, + orig_size=orig_size, + timesteps=12, + )[0] + + new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0] + + old_out = np.array(old_out) + new_out = np.array(new_out) + + diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64)) + + # assert diff diff.sum() == 0 + print("skipping pipeline full equivalence check") + + print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}") + + if args.model_256: + assert diff.max() <= 3 + assert diff.sum() / diff.size < 0.7 + else: + assert diff.max() <= 1 + assert diff.sum() / diff.size < 0.4 + + if args.write_to is not None: + new_pipe.save_pretrained(args.write_to) + + +def make_transformer(old_transformer, model_256): + args = dict(old_transformer.config) + force_down_up_sample = args["force_down_up_sample"] + + signature = inspect.signature(UVit2DModel.__init__) + + args_ = { + "downsample": force_down_up_sample, + "upsample": force_down_up_sample, + "block_out_channels": args["block_out_channels"][0], + "sample_size": 16 if model_256 else 32, + } + + for s in list(signature.parameters.keys()): + if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]: + continue + + args_[s] = args[s] + + new_transformer = UVit2DModel(**args_) + new_transformer.to(device) + + new_transformer.set_attn_processor(AttnProcessor()) + + state_dict = old_transformer.state_dict() + + state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight") + state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight") + + for i in range(22): + state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.attn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight" + ) + + wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight") + wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight") + proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0) + state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight + + state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight") + + if force_down_up_sample: + state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight") + state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight") + + state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight") + state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight") + + state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight") + + for i in range(3): + state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + for key in list(state_dict.keys()): + if key.startswith("up_blocks.0"): + key_ = "up_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + if key.startswith("down_blocks.0"): + key_ = "down_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + new_transformer.load_state_dict(state_dict) + + input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device) + encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device) + cond_embeds = torch.randn((1, 768), device=old_transformer.device) + micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device) + + old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds) + old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2) + + new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds) + + # NOTE: these differences are solely due to using the geglu block that has a single linear layer of + # double output dimension instead of two different linear layers + max_diff = (old_out - new_out).abs().max() + total_diff = (old_out - new_out).abs().sum() + print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}") + assert max_diff < 0.01 + assert total_diff < 1500 + + return new_transformer + + +def make_vqvae(old_vae): + new_vae = VQModel( + act_fn="silu", + block_out_channels=[128, 256, 256, 512, 768], + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=64, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=8192, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + new_vae.to(device) + + # fmt: off + + new_state_dict = {} + + old_state_dict = old_vae.state_dict() + + new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight") + new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias") + + convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0") + convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1") + convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2") + convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3") + convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4") + + new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight") + new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias") + new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight") + new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias") + new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight") + new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias") + new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight") + new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias") + new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight") + new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias") + new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight") + new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias") + new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight") + new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias") + new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight") + new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias") + new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight") + new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias") + new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight") + new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias") + new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight") + new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias") + new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight") + new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight") + new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias") + new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight") + new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias") + new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight") + new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias") + new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight") + new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias") + new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight") + new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias") + new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight") + new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias") + new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight") + new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias") + new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight") + new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias") + new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight") + new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias") + new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight") + new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias") + + convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4") + convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3") + convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2") + convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1") + convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0") + + new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight") + new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias") + new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight") + new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias") + + # fmt: on + + assert len(old_state_dict.keys()) == 0 + + new_vae.load_state_dict(new_state_dict) + + input = torch.randn((1, 3, 512, 512), device=device) + input = input.clamp(-1, 1) + + old_encoder_output = old_vae.quant_conv(old_vae.encoder(input)) + new_encoder_output = new_vae.quant_conv(new_vae.encoder(input)) + assert (old_encoder_output == new_encoder_output).all() + + old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output)) + new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output)) + + # assert (old_decoder_output == new_decoder_output).all() + print("kipping vae decoder equivalence check") + print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}") + + old_output = old_vae(input)[0] + new_output = new_vae(input)[0] + + # assert (old_output == new_output).all() + print("skipping full vae equivalence check") + print(f"vae full diff { (old_output - new_output).float().abs().sum()}") + + return new_vae + + +def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to): + # fmt: off + + new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias") + + if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias") + + new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias") + + if f"{prefix_from}.downsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight") + new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias") + + if f"{prefix_from}.upsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight") + new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias") + + if f"{prefix_from}.block.2.norm1.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias") + + # fmt: on + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c43000e27b82..10c5b0f46565 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -95,6 +95,7 @@ "UNet3DConditionModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", + "UVit2DModel", "VQModel", ] ) @@ -131,6 +132,7 @@ ) _import_structure["schedulers"].extend( [ + "AmusedScheduler", "CMStochasticIterativeScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", @@ -202,6 +204,9 @@ [ "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", + "AmusedImg2ImgPipeline", + "AmusedInpaintPipeline", + "AmusedPipeline", "AnimateDiffPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", @@ -472,6 +477,7 @@ UNet3DConditionModel, UNetMotionModel, UNetSpatioTemporalConditionModel, + UVit2DModel, VQModel, ) from .optimization import ( @@ -506,6 +512,7 @@ ScoreSdeVePipeline, ) from .schedulers import ( + AmusedScheduler, CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, @@ -560,6 +567,9 @@ from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, AnimateDiffPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index c1c3a260ec11..fc50c52e412b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -59,6 +59,7 @@ TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" @@ -74,6 +75,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME + transformer_name = TRANSFORMER_NAME num_fused_loras = 0 def load_lora_weights( @@ -661,6 +663,89 @@ def load_lora_into_text_encoder( _pipeline.enable_sequential_cpu_offload() # Unsafe code /> + @classmethod + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] + network_alphas = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + if len(state_dict.keys()) > 0: + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. @@ -786,6 +871,7 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -820,8 +906,10 @@ def pack_weights(layers, prefix): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} return layers_state_dict - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.") + if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`." + ) if unet_lora_layers: state_dict.update(pack_weights(unet_lora_layers, "unet")) @@ -829,6 +917,9 @@ def pack_weights(layers, prefix): if text_encoder_lora_layers: state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + if transformer_lora_layers: + state_dict.update(pack_weights(transformer_lora_layers, "transformer")) + # Save the model cls.write_lora_layers( state_dict=state_dict, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7487bbf2f98e..6e7fe72bc949 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -47,6 +47,7 @@ _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["uvit_2d"] = ["UVit2DModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): @@ -81,6 +82,7 @@ from .unet_kandinsky3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .uvit_2d import UVit2DModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 08faaaf3e5bf..a34d7421b4f9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -14,6 +14,7 @@ from typing import Any, Dict, Optional import torch +import torch.nn.functional as F from torch import nn from ..utils import USE_PEFT_BACKEND @@ -22,7 +23,7 @@ from .attention_processor import Attention from .embeddings import SinusoidalPositionalEmbedding from .lora import LoRACompatibleLinear -from .normalization import AdaLayerNorm, AdaLayerNormZero +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm def _chunked_feed_forward( @@ -148,6 +149,11 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -156,6 +162,7 @@ def __init__( self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( @@ -179,6 +186,15 @@ def __init__( self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) @@ -190,6 +206,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + out_bias=attention_out_bias, ) # 2. Cross-Attn @@ -197,11 +214,20 @@ def __init__( # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - ) + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, @@ -210,20 +236,32 @@ def __init__( dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, + out_bias=attention_out_bias, ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. Feed-forward - if not self.use_ada_layer_norm_single: - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, ) # 4. Fuser @@ -252,6 +290,7 @@ def forward( timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention @@ -265,6 +304,8 @@ def forward( ) elif self.use_layer_norm: norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) elif self.use_ada_layer_norm_single: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) @@ -314,6 +355,8 @@ def forward( # For PixArt norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) else: raise ValueError("Incorrect norm") @@ -329,7 +372,9 @@ def forward( hidden_states = attn_output + hidden_states # 4. Feed-forward - if not self.use_ada_layer_norm_single: + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: @@ -490,6 +535,78 @@ def forward( return hidden_states +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + class FeedForward(nn.Module): r""" A feed-forward layer. @@ -512,10 +629,12 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + inner_dim=None, bias: bool = True, ): super().__init__() - inner_dim = int(dim * mult) + if inner_dim is None: + inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 9ed0232e6983..3f1643bc50ef 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -77,6 +77,7 @@ def __init__( norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, + mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block @@ -124,6 +125,7 @@ def __init__( attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, + add_attention=mid_block_add_attention, ) # out @@ -213,6 +215,7 @@ def __init__( norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial + mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block @@ -240,6 +243,7 @@ def __init__( attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, + add_attention=mid_block_add_attention, ) # up diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index d39bae22e831..ecab1fffe2f0 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .lora import LoRACompatibleConv +from .normalization import RMSNorm from .upsampling import upfirdn2d_native @@ -89,6 +90,11 @@ def __init__( out_channels: Optional[int] = None, padding: int = 1, name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, ): super().__init__() self.channels = channels @@ -99,8 +105,19 @@ def __init__( self.name = name conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + if use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = conv_cls( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -117,6 +134,9 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index db68591bdb44..7e98f77baf26 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -197,11 +197,12 @@ def __init__( out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, + sample_proj_bias=True, ): super().__init__() linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - self.linear_1 = linear_cls(in_channels, time_embed_dim) + self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) @@ -214,7 +215,7 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out) + self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias) if post_act_fn is None: self.post_act = None diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 25af4d853b86..7f6e2c145435 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from ..utils import is_torch_version from .activations import get_activation from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings @@ -146,3 +148,107 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift return x + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +class GlobalResponseNorm(nn.Module): + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * nx) + self.beta + x diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 542a5d9d1eb0..1e4e61201059 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .lora import LoRACompatibleConv +from .normalization import RMSNorm class Upsample1D(nn.Module): @@ -95,6 +96,13 @@ def __init__( use_conv_transpose: bool = False, out_channels: Optional[int] = None, name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, ): super().__init__() self.channels = channels @@ -102,13 +110,29 @@ def __init__( self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name + self.interpolate = interpolate conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + conv = None if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) elif use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, padding=1) + if kernel_size is None: + kernel_size = 3 + conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -124,6 +148,9 @@ def forward( ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if self.use_conv_transpose: return self.conv(hidden_states) @@ -140,10 +167,11 @@ def forward( # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + if self.interpolate: + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # If the input is bfloat16, we cast back to bfloat16 if dtype == torch.bfloat16: diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py new file mode 100644 index 000000000000..14dd8aee8e89 --- /dev/null +++ b/src/diffusers/models/uvit_2d.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from .attention import BasicTransformerBlock, SkipFFTransformerBlock +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TimestepEmbedding, get_timestep_embedding +from .modeling_utils import ModelMixin +from .normalization import GlobalResponseNorm, RMSNorm +from .resnet import Downsample2D, Upsample2D + + +class UVit2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # global config + hidden_size: int = 1024, + use_bias: bool = False, + hidden_dropout: float = 0.0, + # conditioning dimensions + cond_embed_dim: int = 768, + micro_cond_encode_dim: int = 256, + micro_cond_embed_dim: int = 1280, + encoder_hidden_size: int = 768, + # num tokens + vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded + codebook_size: int = 8192, + # `UVit2DConvEmbed` + in_channels: int = 768, + block_out_channels: int = 768, + num_res_blocks: int = 3, + downsample: bool = False, + upsample: bool = False, + block_num_heads: int = 12, + # `TransformerLayer` + num_hidden_layers: int = 22, + num_attention_heads: int = 16, + # `Attention` + attention_dropout: float = 0.0, + # `FeedForward` + intermediate_size: int = 2816, + # `Norm` + layer_norm_eps: float = 1e-6, + ln_elementwise_affine: bool = True, + sample_size: int = 64, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) + self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + + self.embed = UVit2DConvEmbed( + in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias + ) + + self.cond_embed = TimestepEmbedding( + micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias + ) + + self.down_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample, + False, + ) + + self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) + self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias) + + self.transformer_layers = nn.ModuleList( + [ + BasicTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=hidden_size // num_attention_heads, + dropout=hidden_dropout, + cross_attention_dim=hidden_size, + attention_bias=use_bias, + norm_type="ada_norm_continuous", + ada_norm_continous_conditioning_embedding_dim=hidden_size, + norm_elementwise_affine=ln_elementwise_affine, + norm_eps=layer_norm_eps, + ada_norm_bias=use_bias, + ff_inner_dim=intermediate_size, + ff_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_hidden_layers) + ] + ) + + self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias) + + self.up_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample=False, + upsample=upsample, + ) + + self.mlm_layer = ConvMlmLayer( + block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + pass + + def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): + encoder_hidden_states = self.encoder_proj(encoder_hidden_states) + encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) + + micro_cond_embeds = get_timestep_embedding( + micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)) + + pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1) + pooled_text_emb = pooled_text_emb.to(dtype=self.dtype) + pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) + + hidden_states = self.embed(input_ids) + + hidden_states = self.down_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) + + hidden_states = self.project_to_hidden_norm(hidden_states) + hidden_states = self.project_to_hidden(hidden_states) + + for layer in self.transformer_layers: + if self.training and self.gradient_checkpointing: + + def layer_(*args): + return checkpoint(layer, *args) + + else: + layer_ = layer + + hidden_states = layer_( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, + ) + + hidden_states = self.project_from_hidden_norm(hidden_states) + hidden_states = self.project_from_hidden(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + + hidden_states = self.up_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + logits = self.mlm_layer(hidden_states) + + return logits + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + +class UVit2DConvEmbed(nn.Module): + def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): + super().__init__() + self.embeddings = nn.Embedding(vocab_size, in_channels) + self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) + self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) + + def forward(self, input_ids): + embeddings = self.embeddings(input_ids) + embeddings = self.layer_norm(embeddings) + embeddings = embeddings.permute(0, 3, 1, 2) + embeddings = self.conv(embeddings) + return embeddings + + +class UVitBlock(nn.Module): + def __init__( + self, + channels, + num_res_blocks: int, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample: bool, + upsample: bool, + ): + super().__init__() + + if downsample: + self.downsample = Downsample2D( + channels, + use_conv=True, + padding=0, + name="Conv2d_0", + kernel_size=2, + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + ) + else: + self.downsample = None + + self.res_blocks = nn.ModuleList( + [ + ConvNextBlock( + channels, + layer_norm_eps, + ln_elementwise_affine, + use_bias, + hidden_dropout, + hidden_size, + ) + for i in range(num_res_blocks) + ] + ) + + self.attention_blocks = nn.ModuleList( + [ + SkipFFTransformerBlock( + channels, + block_num_heads, + channels // block_num_heads, + hidden_size, + use_bias, + attention_dropout, + channels, + attention_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_res_blocks) + ] + ) + + if upsample: + self.upsample = Upsample2D( + channels, + use_conv_transpose=True, + kernel_size=2, + padding=0, + name="conv", + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + interpolate=False, + ) + else: + self.upsample = None + + def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): + if self.downsample is not None: + x = self.downsample(x) + + for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): + x = res_block(x, pooled_text_emb) + + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, height * width).permute(0, 2, 1) + x = attention_block( + x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs + ) + x = x.permute(0, 2, 1).view(batch_size, channels, height, width) + + if self.upsample is not None: + x = self.upsample(x) + + return x + + +class ConvNextBlock(nn.Module): + def __init__( + self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 + ): + super().__init__() + self.depthwise = nn.Conv2d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=use_bias, + ) + self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) + self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) + self.channelwise_act = nn.GELU() + self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) + self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) + self.channelwise_dropout = nn.Dropout(hidden_dropout) + self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) + + def forward(self, x, cond_embeds): + x_res = x + + x = self.depthwise(x) + + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + + x = self.channelwise_linear_1(x) + x = self.channelwise_act(x) + x = self.channelwise_norm(x) + x = self.channelwise_linear_2(x) + x = self.channelwise_dropout(x) + + x = x.permute(0, 3, 1, 2) + + x = x + x_res + + scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) + x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] + + return x + + +class ConvMlmLayer(nn.Module): + def __init__( + self, + block_out_channels: int, + in_channels: int, + use_bias: bool, + ln_elementwise_affine: bool, + layer_norm_eps: float, + codebook_size: int, + ): + super().__init__() + self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) + self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) + self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + logits = self.conv2(hidden_states) + return logits diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index bfe62ec863b3..5695d7258f2e 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -88,6 +88,9 @@ def __init__( vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, ): super().__init__() @@ -101,6 +104,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, + mid_block_add_attention=mid_block_add_attention, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels @@ -119,6 +123,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, ) @apply_forward_hook @@ -133,11 +138,13 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOut @apply_forward_hook def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: quant, _, _ = self.quantize(h) + elif self.config.lookup_from_codebook: + quant = self.quantize.get_codebook_entry(h, shape) else: quant = h quant2 = self.post_quant_conv(quant) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 92839e596978..3bf67dfc1cdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -108,6 +108,7 @@ "VersatileDiffusionTextToImagePipeline", ] ) + _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = ["AnimateDiffPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -342,6 +343,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: + from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import AnimateDiffPipeline from .audioldm import AudioLDMPipeline from .audioldm2 import ( diff --git a/src/diffusers/pipelines/amused/__init__.py b/src/diffusers/pipelines/amused/__init__.py new file mode 100644 index 000000000000..3c4d07a426b5 --- /dev/null +++ b/src/diffusers/pipelines/amused/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + ) + + _dummy_objects.update( + { + "AmusedPipeline": AmusedPipeline, + "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline, + "AmusedInpaintPipeline": AmusedInpaintPipeline, + } + ) +else: + _import_structure["pipeline_amused"] = ["AmusedPipeline"] + _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"] + _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedPipeline, + ) + else: + from .pipeline_amused import AmusedPipeline + from .pipeline_amused_img2img import AmusedImg2ImgPipeline + from .pipeline_amused_inpaint import AmusedInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py new file mode 100644 index 000000000000..e93569c2302f --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused.py @@ -0,0 +1,328 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedPipeline + + >>> pipe = AmusedPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class AmusedPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.IntTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image + gneration. If not provided, the starting latents will be completely masked. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if height is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + + if width is None: + width = self.transformer.config.sample_size * self.vae_scale_factor + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if latents is None: + latents = torch.full( + shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device + ) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + + num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep in enumerate(self.scheduler.timesteps): + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py new file mode 100644 index 000000000000..694b7c2229f3 --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -0,0 +1,347 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedImg2ImgPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "winter mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> image = pipe(prompt, input_image).images[0] + ``` +""" + + +class AmusedImg2ImgPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + strength: float = 0.5, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.5): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + latents = self.scheduler.add_noise( + latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator + ) + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py new file mode 100644 index 000000000000..a4c5644c961c --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -0,0 +1,378 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedInpaintPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "fall mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> mask = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ... ) + ... .resize((512, 512)) + ... .convert("L") + ... ) + >>> pipe(prompt, input_image, mask).images[0].save("out.png") + ``` +""" + + +class AmusedInpaintPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=True, + ) + self.scheduler.register_to_config(masking_schedule="linear") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 1.0, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + + mask = self.mask_processor.preprocess( + mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) + latents[mask] = self.scheduler.config.mask_token_id + + starting_mask_ratio = mask.sum() / latents.numel() + + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + starting_mask_ratio=starting_mask_ratio, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 40c435dd5637..e908ba87acdd 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,6 +39,7 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -129,6 +130,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler + from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py new file mode 100644 index 000000000000..51fbe6a4dc7d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_amused.py @@ -0,0 +1,162 @@ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +def gumbel_noise(t, generator=None): + device = generator.device if generator is not None else t.device + noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device) + return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20)) + + +def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): + confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator) + sorted_confidence = torch.sort(confidence, dim=-1).values + cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) + masking = confidence < cut_off + return masking + + +@dataclass +class AmusedSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor = None + + +class AmusedScheduler(SchedulerMixin, ConfigMixin): + order = 1 + + temperatures: torch.Tensor + + @register_to_config + def __init__( + self, + mask_token_id: int, + masking_schedule: str = "cosine", + ): + self.temperatures = None + self.timesteps = None + + def set_timesteps( + self, + num_inference_steps: int, + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + device: Union[str, torch.device] = None, + ): + self.timesteps = torch.arange(num_inference_steps, device=device).flip(0) + + if isinstance(temperature, (tuple, list)): + self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device) + else: + self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, + starting_mask_ratio: int = 1, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[AmusedSchedulerOutput, Tuple]: + two_dim_input = sample.ndim == 3 and model_output.ndim == 4 + + if two_dim_input: + batch_size, codebook_size, height, width = model_output.shape + sample = sample.reshape(batch_size, height * width) + model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1) + + unknown_map = sample == self.config.mask_token_id + + probs = model_output.softmax(dim=-1) + + device = probs.device + probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU + if probs_.device.type == "cpu" and probs_.dtype != torch.float32: + probs_ = probs_.float() # multinomial is not implemented for cpu half precision + probs_ = probs_.reshape(-1, probs.size(-1)) + pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device) + pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1]) + pred_original_sample = torch.where(unknown_map, pred_original_sample, sample) + + if timestep == 0: + prev_sample = pred_original_sample + else: + seq_len = sample.shape[1] + step_idx = (self.timesteps == timestep).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_ratio = starting_mask_ratio * mask_ratio + + mask_len = (seq_len * mask_ratio).floor() + # do not mask more than amount previously masked + mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + # mask at least one + mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len) + + selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0] + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + + masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator) + + # Masks tokens with lower confidence. + prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample) + + if two_dim_input: + prev_sample = prev_sample.reshape(batch_size, height, width) + pred_original_sample = pred_original_sample.reshape(batch_size, height, width) + + if not return_dict: + return (prev_sample, pred_original_sample) + + return AmusedSchedulerOutput(prev_sample, pred_original_sample) + + def add_noise(self, sample, timesteps, generator=None): + step_idx = (self.timesteps == timesteps).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_indices = ( + torch.rand( + sample.shape, device=generator.device if generator is not None else sample.device, generator=generator + ).to(sample.device) + < mask_ratio + ) + + masked_sample = sample.clone() + + masked_sample[mask_indices] = self.config.mask_token_id + + return masked_sample diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 263bcaea5a8d..5bd2f493ce08 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -317,6 +317,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UVit2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] @@ -660,6 +675,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AmusedScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 89fa03e57287..ae6c6c916065 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AmusedImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AnimateDiffPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/amused/__init__.py b/tests/pipelines/amused/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py new file mode 100644 index 000000000000..38159cf2ac15 --- /dev/null +++ b/tests/pipelines/amused/test_amused.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedPipeline + params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "height": 4, + "width": 4, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.4011, 0.3992, 0.3790, 0.3856, 0.3772, 0.3711, 0.3919, 0.3850, 0.3625]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 + + def test_amused_256_fp16(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", variant="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158]) + assert np.abs(image_slice - expected_slice).max() < 7e-3 + + def test_amused_512(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9960, 0.9960, 0.9946, 0.9980, 0.9947, 0.9932, 0.9960, 0.9961, 0.9947]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 + + def test_amused_512_fp16(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-512", variant="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9983, 1.0, 1.0, 1.0, 1.0, 0.9989, 0.9994, 0.9976, 0.9977]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py new file mode 100644 index 000000000000..dcd29ae88e5b --- /dev/null +++ b/tests/pipelines/amused/test_amused_img2img.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - { + "latents", + } + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "image": image, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedImg2ImgPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.9990, 0.9954, 1.0]) + + assert np.abs(image_slice - expected_slice).max() < 1e-2 + + def test_amused_256_fp16(self): + pipe = AmusedImg2ImgPipeline.from_pretrained( + "huggingface/amused-256", torch_dtype=torch.float16, variant="fp16" + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9980, 0.9980, 0.9940, 0.9944, 0.9960, 0.9908, 1.0, 1.0, 0.9986]) + + assert np.abs(image_slice - expected_slice).max() < 1e-2 + + def test_amused_512(self): + pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1344, 0.0985, 0.0, 0.1194, 0.1809, 0.0765, 0.0854, 0.1371, 0.0933]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_512_fp16(self): + pipe = AmusedImg2ImgPipeline.from_pretrained( + "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1536, 0.1767, 0.0227, 0.1079, 0.2400, 0.1427, 0.1511, 0.1564, 0.1542]) + assert np.abs(image_slice - expected_slice).max() < 0.1 diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py new file mode 100644 index 000000000000..014485d7b9e4 --- /dev/null +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - { + "latents", + } + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device) + mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device) + mask_image[0, 0, 0, 0] = 0 + mask_image[0, 0, 0, 1] = 0 + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "image": image, + "mask_image": mask_image, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedInpaintPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((256, 256)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_256_fp16(self): + pipe = AmusedInpaintPipeline.from_pretrained( + "huggingface/amused-256", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((256, 256)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0735, 0.0749, 0.0650, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_512(self): + pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((512, 512)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0]) + assert np.abs(image_slice - expected_slice).max() < 0.05 + + def test_amused_512_fp16(self): + pipe = AmusedInpaintPipeline.from_pretrained( + "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((512, 512)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0025, 0.0]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index cac5ee442ae6..ed2920cb0c73 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -437,7 +437,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]): self._test_inference_batch_consistent(batch_sizes=batch_sizes) def _test_inference_batch_consistent( - self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"] + self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True ): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -472,7 +472,7 @@ def _test_inference_batch_consistent( else: batched_input[name] = batch_size * [value] - if "generator" in inputs: + if batch_generator and "generator" in inputs: batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)] if "batch_size" in inputs: From c022e52923c897d0c16ebcaa9feecfbf1dfbec66 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 13:35:21 +0530 Subject: [PATCH 27/42] Remove ONNX inpaint legacy (#6269) update Co-authored-by: Sayak Paul --- ...st_onnx_stable_diffusion_inpaint_legacy.py | 97 ------------------- 1 file changed, 97 deletions(-) delete mode 100644 tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py deleted file mode 100644 index 235aa32f7338..000000000000 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py +++ /dev/null @@ -1,97 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np - -from diffusers import OnnxStableDiffusionInpaintPipelineLegacy -from diffusers.utils.testing_utils import ( - is_onnx_available, - load_image, - load_numpy, - nightly, - require_onnxruntime, - require_torch_gpu, -) - - -if is_onnx_available(): - import onnxruntime as ort - - -@nightly -@require_onnxruntime -@require_torch_gpu -class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase): - @property - def gpu_provider(self): - return ( - "CUDAExecutionProvider", - { - "gpu_mem_limit": "15000000000", # 15GB - "arena_extend_strategy": "kSameAsRequested", - }, - ) - - @property - def gpu_options(self): - options = ort.SessionOptions() - options.enable_mem_pattern = False - return options - - def test_inference(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ) - expected_image = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy" - ) - - # using the PNDM scheduler by default - pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained( - "CompVis/stable-diffusion-v1-4", - revision="onnx", - safety_checker=None, - feature_extractor=None, - provider=self.gpu_provider, - sess_options=self.gpu_options, - ) - pipe.set_progress_bar_config(disable=None) - - prompt = "A red cat sitting on a park bench" - - generator = np.random.RandomState(0) - output = pipe( - prompt=prompt, - image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - num_inference_steps=15, - generator=generator, - output_type="np", - ) - - image = output.images[0] - - assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 From 59d1caa2385fae1761232e22ed6bcf4f9a492bf7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 13:35:52 +0530 Subject: [PATCH 28/42] Remove peft tests from old lora backend tests (#6273) update --- tests/lora/test_lora_layers_peft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 1d8c6977440c..f6cd2a714ae2 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -1397,7 +1397,7 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): @slow @require_torch_gpu -class LoraIntegrationTests(unittest.TestCase): +class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): def tearDown(self): import gc @@ -1650,7 +1650,7 @@ def test_load_unload_load_kohya_lora(self): @slow @require_torch_gpu -class LoraSDXLIntegrationTests(unittest.TestCase): +class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): def tearDown(self): import gc From 7fe47596af40ea900318e6a2a773a00ff3f3a115 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 22 Dec 2023 09:37:30 +0100 Subject: [PATCH 29/42] Allow diffusers to load with Flax, w/o PyTorch (#6272) --- src/diffusers/utils/torch_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 00bc75f41be3..d0d02fb92e72 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) -def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor: +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). This version of the method comes from here: @@ -121,8 +121,8 @@ def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tens def apply_freeu( - resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs +) -> Tuple["torch.Tensor", "torch.Tensor"]: """Applies the FreeU mechanism as introduced in https: //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. From 3369bc810a09a52521bbf8cc1ec77df3a8c682a8 Mon Sep 17 00:00:00 2001 From: Bingxin Ke <45253439+markkua@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:11:46 +0100 Subject: [PATCH 30/42] [Community Pipeline] Add Marigold Monocular Depth Estimation (#6249) * [Community Pipeline] Add Marigold Monocular Depth Estimation - add single-file pipeline - update README * fix format - add one blank line * format script with ruff * use direct image link in example code --------- Co-authored-by: Sayak Paul --- examples/community/README.md | 48 ++ .../community/marigold_depth_estimation.py | 602 ++++++++++++++++++ 2 files changed, 650 insertions(+) create mode 100644 examples/community/marigold_depth_estimation.py diff --git a/examples/community/README.md b/examples/community/README.md index 7af6d1d7eb02..c3aa1ecf3d64 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -8,6 +8,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Example | Description | Code Example | Colab | Author | |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| +| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) | | LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) | | CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | | One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | @@ -61,6 +62,53 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo ## Example usages +### Marigold Depth Estimation + +Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers). + +![Marigold Teaser](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg) + +This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments: + +```python +import numpy as np +from PIL import Image +from diffusers import DiffusionPipeline +from diffusers.utils import load_image + +pipe = DiffusionPipeline.from_pretrained( + "Bingxin/Marigold", + custom_pipeline="marigold_depth_estimation" + # torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float). +) + +pipe.to("cuda") + +img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg" +image: Image.Image = load_image(img_path_or_url) + +pipeline_output = pipe( + image, # Input image. + # denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10. + # ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10. + # processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768. + # match_input_res=True, # (optional) Resize depth prediction to match input resolution. + # batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0. + # color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". + # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress. +) + +depth: np.ndarray = pipeline_output.depth_np # Predicted depth map +depth_colored: Image.Image = pipeline_output.depth_colored # Colorized prediction + +# Save as uint16 PNG +depth_uint16 = (depth * 65535.0).astype(np.uint16) +Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16") + +# Save colorized depth map +depth_colored.save("./depth_colored.png") +``` + ### LLM-grounded Diffusion LMD and LMD+ greatly improves the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., uses SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py new file mode 100644 index 000000000000..31da842112fb --- /dev/null +++ b/examples/community/marigold_depth_estimation.py @@ -0,0 +1,602 @@ +# Copyright 2023 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import math +from typing import Dict, Union + +import matplotlib +import numpy as np +import torch +from PIL import Image +from scipy.optimize import minimize +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.utils import BaseOutput, check_min_version + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.20.1.dev0") + + +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + depth_np (`np.ndarray`): + Predicted depth map, with depth values in the range of [0, 1]. + depth_colored (`PIL.Image.Image`): + Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. + uncertainty (`None` or `np.ndarray`): + Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. + """ + + depth_np: np.ndarray + depth_colored: Image.Image + uncertainty: Union[None, np.ndarray] + + +class MarigoldPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps + to and from latent representations. + scheduler (`DDIMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + """ + + rgb_latent_scale_factor = 0.18215 + depth_latent_scale_factor = 0.18215 + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + self.empty_text_embed = None + + @torch.no_grad() + def __call__( + self, + input_image: Image, + denoising_steps: int = 10, + ensemble_size: int = 10, + processing_res: int = 768, + match_input_res: bool = True, + batch_size: int = 0, + color_map: str = "Spectral", + show_progress_bar: bool = True, + ensemble_kwargs: Dict = None, + ) -> MarigoldDepthOutput: + """ + Function invoked when calling the pipeline. + + Args: + input_image (`Image`): + Input RGB (or gray-scale) image. + processing_res (`int`, *optional*, defaults to `768`): + Maximum resolution of processing. + If set to 0: will not resize at all. + match_input_res (`bool`, *optional*, defaults to `True`): + Resize depth prediction to match input resolution. + Only valid if `limit_input_res` is not None. + denoising_steps (`int`, *optional*, defaults to `10`): + Number of diffusion denoising steps (DDIM) during inference. + ensemble_size (`int`, *optional*, defaults to `10`): + Number of predictions to be ensembled. + batch_size (`int`, *optional*, defaults to `0`): + Inference batch size, no bigger than `num_ensemble`. + If set to 0, the script will automatically decide the proper batch size. + show_progress_bar (`bool`, *optional*, defaults to `True`): + Display a progress bar of diffusion denoising. + color_map (`str`, *optional*, defaults to `"Spectral"`): + Colormap used to colorize the depth map. + ensemble_kwargs (`dict`, *optional*, defaults to `None`): + Arguments for detailed ensembling settings. + Returns: + `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: + - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1] + - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1] + - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) + coming from ensembling. None if `ensemble_size = 1` + """ + + device = self.device + input_size = input_image.size + + if not match_input_res: + assert processing_res is not None, "Value error: `resize_output_back` is only valid with " + assert processing_res >= 0 + assert denoising_steps >= 1 + assert ensemble_size >= 1 + + # ----------------- Image Preprocess ----------------- + # Resize image + if processing_res > 0: + input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res) + # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel + input_image = input_image.convert("RGB") + image = np.asarray(input_image) + + # Normalize rgb values + rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W] + rgb_norm = rgb / 255.0 + rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) + rgb_norm = rgb_norm.to(device) + assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 + + # ----------------- Predicting depth ----------------- + # Batch repeated input image + duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) + single_rgb_dataset = TensorDataset(duplicated_rgb) + if batch_size > 0: + _bs = batch_size + else: + _bs = self._find_batch_size( + ensemble_size=ensemble_size, + input_res=max(rgb_norm.shape[1:]), + dtype=self.dtype, + ) + + single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False) + + # Predict depth maps (batched) + depth_pred_ls = [] + if show_progress_bar: + iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False) + else: + iterable = single_rgb_loader + for batch in iterable: + (batched_img,) = batch + depth_pred_raw = self.single_infer( + rgb_in=batched_img, + num_inference_steps=denoising_steps, + show_pbar=show_progress_bar, + ) + depth_pred_ls.append(depth_pred_raw.detach().clone()) + depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() + torch.cuda.empty_cache() # clear vram cache for ensembling + + # ----------------- Test-time ensembling ----------------- + if ensemble_size > 1: + depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {})) + else: + depth_pred = depth_preds + pred_uncert = None + + # ----------------- Post processing ----------------- + # Scale prediction to [0, 1] + min_d = torch.min(depth_pred) + max_d = torch.max(depth_pred) + depth_pred = (depth_pred - min_d) / (max_d - min_d) + + # Convert to numpy + depth_pred = depth_pred.cpu().numpy().astype(np.float32) + + # Resize back to original resolution + if match_input_res: + pred_img = Image.fromarray(depth_pred) + pred_img = pred_img.resize(input_size) + depth_pred = np.asarray(pred_img) + + # Clip output range + depth_pred = depth_pred.clip(0, 1) + + # Colorize + depth_colored = self.colorize_depth_maps( + depth_pred, 0, 1, cmap=color_map + ).squeeze() # [3, H, W], value in (0, 1) + depth_colored = (depth_colored * 255).astype(np.uint8) + depth_colored_hwc = self.chw2hwc(depth_colored) + depth_colored_img = Image.fromarray(depth_colored_hwc) + return MarigoldDepthOutput( + depth_np=depth_pred, + depth_colored=depth_colored_img, + uncertainty=pred_uncert, + ) + + def _encode_empty_text(self): + """ + Encode text embedding for empty prompt. + """ + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) + self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) + + @torch.no_grad() + def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor: + """ + Perform an individual depth prediction without ensembling. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image. + num_inference_steps (`int`): + Number of diffusion denoisign steps (DDIM) during inference. + show_pbar (`bool`): + Display a progress bar of diffusion denoising. + Returns: + `torch.Tensor`: Predicted depth map. + """ + device = rgb_in.device + + # Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # [T] + + # Encode image + rgb_latent = self._encode_rgb(rgb_in) + + # Initial depth map (noise) + depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w] + + # Batched empty text embedding + if self.empty_text_embed is None: + self._encode_empty_text() + batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024] + + # Denoising loop + if show_pbar: + iterable = tqdm( + enumerate(timesteps), + total=len(timesteps), + leave=False, + desc=" " * 4 + "Diffusion denoising", + ) + else: + iterable = enumerate(timesteps) + + for i, t in iterable: + unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important + + # predict the noise residual + noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w] + + # compute the previous noisy sample x_t -> x_t-1 + depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample + torch.cuda.empty_cache() + depth = self._decode_depth(depth_latent) + + # clip prediction + depth = torch.clip(depth, -1.0, 1.0) + # shift to [0, 1] + depth = (depth + 1.0) / 2.0 + + return depth + + def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: + """ + Encode RGB image into latent. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image to be encoded. + + Returns: + `torch.Tensor`: Image latent. + """ + # encode + h = self.vae.encoder(rgb_in) + moments = self.vae.quant_conv(h) + mean, logvar = torch.chunk(moments, 2, dim=1) + # scale latent + rgb_latent = mean * self.rgb_latent_scale_factor + return rgb_latent + + def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: + """ + Decode depth latent into depth map. + + Args: + depth_latent (`torch.Tensor`): + Depth latent to be decoded. + + Returns: + `torch.Tensor`: Decoded depth map. + """ + # scale latent + depth_latent = depth_latent / self.depth_latent_scale_factor + # decode + z = self.vae.post_quant_conv(depth_latent) + stacked = self.vae.decoder(z) + # mean of output channels + depth_mean = stacked.mean(dim=1, keepdim=True) + return depth_mean + + @staticmethod + def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: + """ + Resize image to limit maximum edge length while keeping aspect ratio. + + Args: + img (`Image.Image`): + Image to be resized. + max_edge_resolution (`int`): + Maximum edge length (pixel). + + Returns: + `Image.Image`: Resized image. + """ + original_width, original_height = img.size + downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height) + + new_width = int(original_width * downscale_factor) + new_height = int(original_height * downscale_factor) + + resized_img = img.resize((new_width, new_height)) + return resized_img + + @staticmethod + def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None): + """ + Colorize depth maps. + """ + assert len(depth_map.shape) >= 2, "Invalid dimension" + + if isinstance(depth_map, torch.Tensor): + depth = depth_map.detach().clone().squeeze().numpy() + elif isinstance(depth_map, np.ndarray): + depth = depth_map.copy().squeeze() + # reshape to [ (B,) H, W ] + if depth.ndim < 3: + depth = depth[np.newaxis, :, :] + + # colorize + cm = matplotlib.colormaps[cmap] + depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) + img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 + img_colored_np = np.rollaxis(img_colored_np, 3, 1) + + if valid_mask is not None: + if isinstance(depth_map, torch.Tensor): + valid_mask = valid_mask.detach().numpy() + valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] + if valid_mask.ndim < 3: + valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] + else: + valid_mask = valid_mask[:, np.newaxis, :, :] + valid_mask = np.repeat(valid_mask, 3, axis=1) + img_colored_np[~valid_mask] = 0 + + if isinstance(depth_map, torch.Tensor): + img_colored = torch.from_numpy(img_colored_np).float() + elif isinstance(depth_map, np.ndarray): + img_colored = img_colored_np + + return img_colored + + @staticmethod + def chw2hwc(chw): + assert 3 == len(chw.shape) + if isinstance(chw, torch.Tensor): + hwc = torch.permute(chw, (1, 2, 0)) + elif isinstance(chw, np.ndarray): + hwc = np.moveaxis(chw, 0, -1) + return hwc + + @staticmethod + def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: + """ + Automatically search for suitable operating batch size. + + Args: + ensemble_size (`int`): + Number of predictions to be ensembled. + input_res (`int`): + Operating resolution of the input image. + + Returns: + `int`: Operating batch size. + """ + # Search table for suggested max. inference batch size + bs_search_table = [ + # tested on A100-PCIE-80GB + {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, + {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, + # tested on A100-PCIE-40GB + {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, + {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, + {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, + {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, + # tested on RTX3090, RTX4090 + {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, + {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, + {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, + {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, + {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, + {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, + # tested on GTX1080Ti + {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, + {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, + {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, + {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, + {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, + ] + + if not torch.cuda.is_available(): + return 1 + + total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 + filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] + for settings in sorted( + filtered_bs_search_table, + key=lambda k: (k["res"], -k["total_vram"]), + ): + if input_res <= settings["res"] and total_vram >= settings["total_vram"]: + bs = settings["bs"] + if bs > ensemble_size: + bs = ensemble_size + elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: + bs = math.ceil(ensemble_size / 2) + return bs + + return 1 + + @staticmethod + def ensemble_depths( + input_images: torch.Tensor, + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + reduction: str = "median", + max_res: int = None, + ): + """ + To ensemble multiple affine-invariant depth images (up to scale and shift), + by aligning estimating the scale and shift + """ + + def inter_distances(tensors: torch.Tensor): + """ + To calculate the distance between each two depth maps. + """ + distances = [] + for i, j in torch.combinations(torch.arange(tensors.shape[0])): + arr1 = tensors[i : i + 1] + arr2 = tensors[j : j + 1] + distances.append(arr1 - arr2) + dist = torch.concatenate(distances, dim=0) + return dist + + device = input_images.device + dtype = input_images.dtype + np_dtype = np.float32 + + original_input = input_images.clone() + n_img = input_images.shape[0] + ori_shape = input_images.shape + + if max_res is not None: + scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") + input_images = downscaler(torch.from_numpy(input_images)).numpy() + + # init guess + _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) + t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) + x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) + + input_images = input_images.to(device) + + # objective function + def closure(x): + l = len(x) + s = x[: int(l / 2)] + t = x[int(l / 2) :] + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + + transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) + dists = inter_distances(transformed_arrays) + sqrt_dist = torch.sqrt(torch.mean(dists**2)) + + if "mean" == reduction: + pred = torch.mean(transformed_arrays, dim=0) + elif "median" == reduction: + pred = torch.median(transformed_arrays, dim=0).values + else: + raise ValueError + + near_err = torch.sqrt((0 - torch.min(pred)) ** 2) + far_err = torch.sqrt((1 - torch.max(pred)) ** 2) + + err = sqrt_dist + (near_err + far_err) * regularizer_strength + err = err.detach().cpu().numpy().astype(np_dtype) + return err + + res = minimize( + closure, + x, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + x = res.x + l = len(x) + s = x[: int(l / 2)] + t = x[int(l / 2) :] + + # Prediction + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) + if "mean" == reduction: + aligned_images = torch.mean(transformed_arrays, dim=0) + std = torch.std(transformed_arrays, dim=0) + uncertainty = std + elif "median" == reduction: + aligned_images = torch.median(transformed_arrays, dim=0).values + # MAD (median absolute deviation) as uncertainty indicator + abs_dev = torch.abs(transformed_arrays - aligned_images) + mad = torch.median(abs_dev, dim=0).values + uncertainty = mad + else: + raise ValueError(f"Unknown reduction method: {reduction}") + + # Scale and shift to [0, 1] + _min = torch.min(aligned_images) + _max = torch.max(aligned_images) + aligned_images = (aligned_images - _min) / (_max - _min) + uncertainty /= _max - _min + + return aligned_images, uncertainty From df76a39e1bc1de5bec647ce56a7fe4d8d1b6a643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Fri, 22 Dec 2023 06:42:04 -0600 Subject: [PATCH 31/42] Fix Prodigy optimizer in SDXL Dreambooth script (#6290) * Fix ProdigyOPT in SDXL Dreambooth script * style * style --- .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9992292e30aa..8a3ac294fef2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1144,10 +1144,26 @@ def load_model_hook(models, input_dir): optimizer_class = prodigyopt.Prodigy + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, From 90b9479903dcf3b053dc2461d4d6266eed0c27ea Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 24 Dec 2023 09:59:41 +0530 Subject: [PATCH 32/42] [LoRA PEFT] fix LoRA loading so that correct alphas are parsed (#6225) * initialize alpha too. * add: test * remove config parsing * store rank * debug * remove faulty test --- examples/dreambooth/train_dreambooth_lora.py | 6 +++++- examples/dreambooth/train_dreambooth_lora_sdxl.py | 10 ++++++++-- examples/text_to_image/train_text_to_image_lora.py | 5 ++++- .../text_to_image/train_text_to_image_lora_sdxl.py | 10 ++++++++-- tests/lora/test_lora_layers_peft.py | 8 ++++++-- 5 files changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 55ef2bbeb8eb..67132d6d88df 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -827,6 +827,7 @@ def main(args): # now we will add new LoRA weights to the attention layers unet_lora_config = LoraConfig( r=args.rank, + lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], ) @@ -835,7 +836,10 @@ def main(args): # The text encoder comes from 🤗 transformers, we will also attach adapters to it. if args.train_text_encoder: text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder.add_adapter(text_lora_config) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 8a3ac294fef2..0f41ad47d1ac 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -978,7 +978,10 @@ def main(args): # now we will add new LoRA weights to the attention layers unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) unet.add_adapter(unet_lora_config) @@ -986,7 +989,10 @@ def main(args): # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c8efbddd0b44..d6d0dee0883c 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -452,7 +452,10 @@ def main(): param.requires_grad_(False) unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) # Move unet, vae and text_encoder to device and cast to weight_dtype diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index be17c13c2885..d95fcbbba033 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -609,7 +609,10 @@ def main(args): # now we will add new LoRA weights to the attention layers # Set correct lora layers unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) unet.add_adapter(unet_lora_config) @@ -618,7 +621,10 @@ def main(args): if args.train_text_encoder: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index f6cd2a714ae2..30125f64f6ac 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests: def get_dummy_components(self, scheduler_cls=None): scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler + rank = 4 torch.manual_seed(0) unet = UNet2DConditionModel(**self.unet_kwargs) @@ -125,11 +126,14 @@ def get_dummy_components(self, scheduler_cls=None): tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") text_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False + r=rank, + lora_alpha=rank, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + init_lora_weights=False, ) unet_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False ) unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) From fe574c8b29297f4b9a562f21a88e9de3e4fda856 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sun, 24 Dec 2023 14:31:48 +0530 Subject: [PATCH 33/42] LoRA Unfusion test fix (#6291) update Co-authored-by: Sayak Paul --- tests/lora/test_lora_layers_peft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 30125f64f6ac..180d45b6803e 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -1881,7 +1881,9 @@ def test_sdxl_1_0_lora_unfusion(self): ).images images_without_fusion = images.flatten() - self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3)) + max_diff = numpy_cosine_similarity_distance(images_with_fusion, images_without_fusion) + assert max_diff < 1e-4 + release_memory(pipe) def test_sdxl_1_0_lora_unfusion_effectivity(self): From 7c05b975b79df39875959494020e4b5eedd2c4c8 Mon Sep 17 00:00:00 2001 From: Celestial Phineas <17267055+celestialphineas@users.noreply.github.com> Date: Sun, 24 Dec 2023 17:02:24 +0800 Subject: [PATCH 34/42] Fix typos in the `ValueError` for a nested image list as `StableDiffusionControlNetPipeline` input. (#6286) Fixed typos in the `ValueError` for a nested image list as input. --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index d7168bec8259..6bdc281ef8bf 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -633,7 +633,7 @@ def check_inputs( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") + raise ValueError("A single batch of multiple conditionings is not supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -659,7 +659,7 @@ def check_inputs( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") + raise ValueError("A single batch of multiple conditionings is not supported at the moment.") elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): From 2d43094ffc9b1ee377651c6c8a358c81f0c96005 Mon Sep 17 00:00:00 2001 From: mwkldeveloper Date: Sun, 24 Dec 2023 17:04:35 +0800 Subject: [PATCH 35/42] fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same in train_text_to_image_lora.py (#6259) * fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same * format source code * format code * remove the autocast blocks within the pipeline * add autocast blocks to pipeline caller in train_text_to_image_lora.py --- .../text_to_image/train_text_to_image_lora.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index d6d0dee0883c..2efbaf298d2e 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -847,10 +847,11 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -916,8 +917,11 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if len(images) != 0: From 008d9818a25bb667532cba1611093ccce1902b25 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 25 Dec 2023 10:45:14 +0530 Subject: [PATCH 36/42] fix: t2i apdater paper link (#6314) --- docs/source/en/training/t2i_adapters.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/t2i_adapters.md b/docs/source/en/training/t2i_adapters.md index 0f65ad8ed31d..03f4537cb28d 100644 --- a/docs/source/en/training/t2i_adapters.md +++ b/docs/source/en/training/t2i_adapters.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # T2I-Adapter -[T2I-Adapter]((https://hf.co/papers/2302.08453)) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it. +[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it. The T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model. From 89459a5d561b9c0bf1316f1be955154275d9d24a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 25 Dec 2023 11:26:45 +0530 Subject: [PATCH 37/42] fix: lora peft dummy components (#6308) * fix: lora peft dummy components * fix: dummy components --- tests/lora/test_lora_layers_peft.py | 68 +++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 180d45b6803e..38e55b9ed7b4 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -115,9 +115,12 @@ def get_dummy_components(self, scheduler_cls=None): torch.manual_seed(0) unet = UNet2DConditionModel(**self.unet_kwargs) + scheduler = scheduler_cls(**self.scheduler_kwargs) + torch.manual_seed(0) vae = AutoencoderKL(**self.vae_kwargs) + text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") @@ -1402,6 +1405,35 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): @slow @require_torch_gpu class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): + pipeline_class = StableDiffusionPipeline + scheduler_cls = DDIMScheduler + scheduler_kwargs = { + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "clip_sample": False, + "set_alpha_to_one": False, + "steps_offset": 1, + } + unet_kwargs = { + "block_out_channels": (32, 64), + "layers_per_block": 2, + "sample_size": 32, + "in_channels": 4, + "out_channels": 4, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "cross_attention_dim": 32, + } + vae_kwargs = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + def tearDown(self): import gc @@ -1655,6 +1687,42 @@ def test_load_unload_load_kohya_lora(self): @slow @require_torch_gpu class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): + has_two_text_encoders = True + pipeline_class = StableDiffusionXLPipeline + scheduler_cls = EulerDiscreteScheduler + scheduler_kwargs = { + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "timestep_spacing": "leading", + "steps_offset": 1, + } + unet_kwargs = { + "block_out_channels": (32, 64), + "layers_per_block": 2, + "sample_size": 32, + "in_channels": 4, + "out_channels": 4, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "attention_head_dim": (2, 4), + "use_linear_projection": True, + "addition_embed_type": "text_time", + "addition_time_embed_dim": 8, + "transformer_layers_per_block": (1, 2), + "projection_class_embeddings_input_dim": 80, # 6 * 8 + 32 + "cross_attention_dim": 64, + } + vae_kwargs = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + "sample_size": 128, + } + def tearDown(self): import gc From f4b0b26f7e4ea1d47e0ab83721ca3487d36fa093 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 25 Dec 2023 19:50:48 +0530 Subject: [PATCH 38/42] [Tests] Speed up example tests (#6319) * remove validation args from textual onverson tests * reduce number of train steps in textual inversion tests * fix: directories. * debig * fix: directories. * remove validation tests from textual onversion * try reducing the time of test_text_to_image_checkpointing_use_ema * fix: directories * speed up test_text_to_image_checkpointing * speed up test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * fix * speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * set checkpoints_total_limit to 2. * test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints speed up * speed up test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * debug * fix: directories. * speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit * speed up: test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_controlnet_sdxl * speed up dreambooth tests * speed up test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit * speed up # checkpoint-2 should have been deleted * speed up examples/text_to_image/test_text_to_image.py::TextToImage::test_text_to_image_checkpointing_checkpoints_total_limit * additional speed ups * style --- examples/controlnet/test_controlnet.py | 17 ++-- .../custom_diffusion/test_custom_diffusion.py | 20 ++--- examples/dreambooth/test_dreambooth.py | 27 +++--- examples/dreambooth/test_dreambooth_lora.py | 29 +++---- .../instruct_pix2pix/test_instruct_pix2pix.py | 14 ++-- examples/text_to_image/test_text_to_image.py | 83 +++++++++---------- .../text_to_image/test_text_to_image_lora.py | 61 ++++++-------- .../test_textual_inversion.py | 18 ++-- .../test_unconditional.py | 12 +-- 9 files changed, 117 insertions(+), 164 deletions(-) diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py index e62d095adaa2..e1adafe6be6f 100644 --- a/examples/controlnet/test_controlnet.py +++ b/examples/controlnet/test_controlnet.py @@ -65,7 +65,7 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check --train_batch_size=1 --gradient_accumulation_steps=1 --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet - --max_train_steps=9 + --max_train_steps=6 --checkpointing_steps=2 """.split() @@ -73,7 +73,7 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6"}, ) resume_run_args = f""" @@ -85,18 +85,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check --train_batch_size=1 --gradient_accumulation_steps=1 --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet - --max_train_steps=11 + --max_train_steps=8 --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-6 + --checkpoints_total_limit=2 """.split() run_command(self._launch_args + resume_run_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) class ControlNetSDXL(ExamplesTestsAccelerate): @@ -111,7 +108,7 @@ def test_controlnet_sdxl(self): --train_batch_size=1 --gradient_accumulation_steps=1 --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl - --max_train_steps=9 + --max_train_steps=4 --checkpointing_steps=2 """.split() diff --git a/examples/custom_diffusion/test_custom_diffusion.py b/examples/custom_diffusion/test_custom_diffusion.py index 78f24c5172d6..da4355d5ac25 100644 --- a/examples/custom_diffusion/test_custom_diffusion.py +++ b/examples/custom_diffusion/test_custom_diffusion.py @@ -76,10 +76,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): run_command(self._launch_args + test_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -93,7 +90,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple --train_batch_size=1 --modifier_token= --dataloader_num_workers=0 - --max_train_steps=9 + --max_train_steps=4 --checkpointing_steps=2 --no_safe_serialization """.split() @@ -102,7 +99,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4"}, ) resume_run_args = f""" @@ -115,16 +112,13 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple --train_batch_size=1 --modifier_token= --dataloader_num_workers=0 - --max_train_steps=11 + --max_train_steps=8 --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 --no_safe_serialization """.split() run_command(self._launch_args + resume_run_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/dreambooth/test_dreambooth.py b/examples/dreambooth/test_dreambooth.py index 0c6c2a062325..ce2f3215bc71 100644 --- a/examples/dreambooth/test_dreambooth.py +++ b/examples/dreambooth/test_dreambooth.py @@ -89,7 +89,7 @@ def test_dreambooth_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 5, checkpointing_steps == 2 + # max_train_steps == 4, checkpointing_steps == 2 # Should create checkpoints at steps 2, 4 initial_run_args = f""" @@ -100,7 +100,7 @@ def test_dreambooth_checkpointing(self): --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 5 + --max_train_steps 4 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -114,7 +114,7 @@ def test_dreambooth_checkpointing(self): # check can run the original fully trained output pipeline pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(instance_prompt, num_inference_steps=2) + pipe(instance_prompt, num_inference_steps=1) # check checkpoint directories exist self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) @@ -123,7 +123,7 @@ def test_dreambooth_checkpointing(self): # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) - pipe(instance_prompt, num_inference_steps=2) + pipe(instance_prompt, num_inference_steps=1) # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) @@ -138,7 +138,7 @@ def test_dreambooth_checkpointing(self): --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -153,7 +153,7 @@ def test_dreambooth_checkpointing(self): # check can run new fully trained pipeline pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(instance_prompt, num_inference_steps=2) + pipe(instance_prompt, num_inference_steps=1) # check old checkpoints do not exist self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) @@ -196,7 +196,7 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check --resolution=64 --train_batch_size=1 --gradient_accumulation_steps=1 - --max_train_steps=9 + --max_train_steps=4 --checkpointing_steps=2 """.split() @@ -204,7 +204,7 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4"}, ) resume_run_args = f""" @@ -216,15 +216,12 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check --resolution=64 --train_batch_size=1 --gradient_accumulation_steps=1 - --max_train_steps=11 + --max_train_steps=8 --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 """.split() run_command(self._launch_args + resume_run_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/dreambooth/test_dreambooth_lora.py b/examples/dreambooth/test_dreambooth_lora.py index fc43269f732e..496ce22f814e 100644 --- a/examples/dreambooth/test_dreambooth_lora.py +++ b/examples/dreambooth/test_dreambooth_lora.py @@ -135,16 +135,13 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_ --resolution=64 --train_batch_size=1 --gradient_accumulation_steps=1 - --max_train_steps=9 + --max_train_steps=4 --checkpointing_steps=2 """.split() run_command(self._launch_args + test_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) resume_run_args = f""" examples/dreambooth/train_dreambooth_lora.py @@ -155,18 +152,15 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_ --resolution=64 --train_batch_size=1 --gradient_accumulation_steps=1 - --max_train_steps=11 + --max_train_steps=8 --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 """.split() run_command(self._launch_args + resume_run_args) - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) def test_dreambooth_lora_if_model(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -328,7 +322,7 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --checkpointing_steps=2 --checkpoints_total_limit=2 --learning_rate 5.0e-04 @@ -342,14 +336,11 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): pipe = DiffusionPipeline.from_pretrained(pipeline_path) pipe.load_lora_weights(tmpdir) - pipe("a prompt", num_inference_steps=2) + pipe("a prompt", num_inference_steps=1) # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self): pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" diff --git a/examples/instruct_pix2pix/test_instruct_pix2pix.py b/examples/instruct_pix2pix/test_instruct_pix2pix.py index c4d7500723fa..b30baf8b1b02 100644 --- a/examples/instruct_pix2pix/test_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/test_instruct_pix2pix.py @@ -40,7 +40,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self): --resolution=64 --random_flip --train_batch_size=1 - --max_train_steps=7 + --max_train_steps=6 --checkpointing_steps=2 --checkpoints_total_limit=2 --output_dir {tmpdir} @@ -63,7 +63,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple --resolution=64 --random_flip --train_batch_size=1 - --max_train_steps=9 + --max_train_steps=4 --checkpointing_steps=2 --output_dir {tmpdir} --seed=0 @@ -74,7 +74,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4"}, ) resume_run_args = f""" @@ -84,12 +84,12 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple --resolution=64 --random_flip --train_batch_size=1 - --max_train_steps=11 + --max_train_steps=8 --checkpointing_steps=2 --output_dir {tmpdir} --seed=0 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 """.split() run_command(self._launch_args + resume_run_args) @@ -97,5 +97,5 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + {"checkpoint-6", "checkpoint-8"}, ) diff --git a/examples/text_to_image/test_text_to_image.py b/examples/text_to_image/test_text_to_image.py index 308a038b5533..814c13cf486e 100644 --- a/examples/text_to_image/test_text_to_image.py +++ b/examples/text_to_image/test_text_to_image.py @@ -64,7 +64,7 @@ def test_text_to_image_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 5, checkpointing_steps == 2 + # max_train_steps == 4, checkpointing_steps == 2 # Should create checkpoints at steps 2, 4 initial_run_args = f""" @@ -76,7 +76,7 @@ def test_text_to_image_checkpointing(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 5 + --max_train_steps 4 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -89,7 +89,7 @@ def test_text_to_image_checkpointing(self): run_command(self._launch_args + initial_run_args) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist self.assertEqual( @@ -100,12 +100,12 @@ def test_text_to_image_checkpointing(self): # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) - # Run training script for 7 total steps resuming from checkpoint 4 + # Run training script for 2 total steps resuming from checkpoint 4 resume_run_args = f""" examples/text_to_image/train_text_to_image.py @@ -116,13 +116,13 @@ def test_text_to_image_checkpointing(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 2 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} - --checkpointing_steps=2 + --checkpointing_steps=1 --resume_from_checkpoint=checkpoint-4 --seed=0 """.split() @@ -131,16 +131,13 @@ def test_text_to_image_checkpointing(self): # check can run new fully trained pipeline pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - { - # no checkpoint-2 -> check old checkpoints do not exist - # check new checkpoints exist - "checkpoint-4", - "checkpoint-6", - }, + {"checkpoint-4", "checkpoint-5"}, ) def test_text_to_image_checkpointing_use_ema(self): @@ -149,7 +146,7 @@ def test_text_to_image_checkpointing_use_ema(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 5, checkpointing_steps == 2 + # max_train_steps == 4, checkpointing_steps == 2 # Should create checkpoints at steps 2, 4 initial_run_args = f""" @@ -161,7 +158,7 @@ def test_text_to_image_checkpointing_use_ema(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 5 + --max_train_steps 4 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -186,12 +183,12 @@ def test_text_to_image_checkpointing_use_ema(self): # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) - # Run training script for 7 total steps resuming from checkpoint 4 + # Run training script for 2 total steps resuming from checkpoint 4 resume_run_args = f""" examples/text_to_image/train_text_to_image.py @@ -202,13 +199,13 @@ def test_text_to_image_checkpointing_use_ema(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 2 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} - --checkpointing_steps=2 + --checkpointing_steps=1 --resume_from_checkpoint=checkpoint-4 --use_ema --seed=0 @@ -218,16 +215,13 @@ def test_text_to_image_checkpointing_use_ema(self): # check can run new fully trained pipeline pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - { - # no checkpoint-2 -> check old checkpoints do not exist - # check new checkpoints exist - "checkpoint-4", - "checkpoint-6", - }, + {"checkpoint-4", "checkpoint-5"}, ) def test_text_to_image_checkpointing_checkpoints_total_limit(self): @@ -236,7 +230,7 @@ def test_text_to_image_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 # Should create checkpoints at steps 2, 4, 6 # with checkpoint at step 2 deleted @@ -249,7 +243,7 @@ def test_text_to_image_checkpointing_checkpoints_total_limit(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -263,14 +257,11 @@ def test_text_to_image_checkpointing_checkpoints_total_limit(self): run_command(self._launch_args + initial_run_args) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" @@ -278,8 +269,8 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 9, checkpointing_steps == 2 - # Should create checkpoints at steps 2, 4, 6, 8 + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 initial_run_args = f""" examples/text_to_image/train_text_to_image.py @@ -290,7 +281,7 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 9 + --max_train_steps 4 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -303,15 +294,15 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch run_command(self._launch_args + initial_run_args) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4"}, ) - # resume and we should try to checkpoint at 10, where we'll have to remove + # resume and we should try to checkpoint at 6, where we'll have to remove # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint resume_run_args = f""" @@ -323,27 +314,27 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 11 + --max_train_steps 8 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 --seed=0 """.split() run_command(self._launch_args + resume_run_args) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + {"checkpoint-6", "checkpoint-8"}, ) diff --git a/examples/text_to_image/test_text_to_image_lora.py b/examples/text_to_image/test_text_to_image_lora.py index 83cbb78b2dc6..4daee834d0e6 100644 --- a/examples/text_to_image/test_text_to_image_lora.py +++ b/examples/text_to_image/test_text_to_image_lora.py @@ -41,7 +41,7 @@ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 # Should create checkpoints at steps 2, 4, 6 # with checkpoint at step 2 deleted @@ -52,7 +52,7 @@ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self): --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -66,14 +66,11 @@ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self): pipe = DiffusionPipeline.from_pretrained(pipeline_path) pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" @@ -81,7 +78,7 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 # Should create checkpoints at steps 2, 4, 6 # with checkpoint at step 2 deleted @@ -94,7 +91,7 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -112,14 +109,11 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None ) pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" @@ -127,8 +121,8 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 9, checkpointing_steps == 2 - # Should create checkpoints at steps 2, 4, 6, 8 + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 initial_run_args = f""" examples/text_to_image/train_text_to_image_lora.py @@ -139,7 +133,7 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 9 + --max_train_steps 4 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -156,15 +150,15 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None ) pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + {"checkpoint-2", "checkpoint-4"}, ) - # resume and we should try to checkpoint at 10, where we'll have to remove + # resume and we should try to checkpoint at 6, where we'll have to remove # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint resume_run_args = f""" @@ -176,15 +170,15 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip --random_flip --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 11 + --max_train_steps 8 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 --seed=0 --num_validation_images=0 """.split() @@ -195,12 +189,12 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None ) pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + {"checkpoint-6", "checkpoint-8"}, ) @@ -272,7 +266,7 @@ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_li with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 # Should create checkpoints at steps 2, 4, 6 # with checkpoint at step 2 deleted @@ -283,7 +277,7 @@ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_li --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 7 + --max_train_steps 6 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -298,11 +292,8 @@ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_li pipe = DiffusionPipeline.from_pretrained(pipeline_path) pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) + pipe(prompt, num_inference_steps=1) # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) diff --git a/examples/textual_inversion/test_textual_inversion.py b/examples/textual_inversion/test_textual_inversion.py index a5d7bcb65dd3..ba9cabd9aafe 100644 --- a/examples/textual_inversion/test_textual_inversion.py +++ b/examples/textual_inversion/test_textual_inversion.py @@ -40,8 +40,6 @@ def test_textual_inversion(self): --learnable_property object --placeholder_token --initializer_token a - --validation_prompt - --validation_steps 1 --save_steps 1 --num_vectors 2 --resolution 64 @@ -68,8 +66,6 @@ def test_textual_inversion_checkpointing(self): --learnable_property object --placeholder_token --initializer_token a - --validation_prompt - --validation_steps 1 --save_steps 1 --num_vectors 2 --resolution 64 @@ -102,14 +98,12 @@ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multipl --learnable_property object --placeholder_token --initializer_token a - --validation_prompt - --validation_steps 1 --save_steps 1 --num_vectors 2 --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 3 + --max_train_steps 2 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant @@ -123,7 +117,7 @@ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multipl # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-1", "checkpoint-2", "checkpoint-3"}, + {"checkpoint-1", "checkpoint-2"}, ) resume_run_args = f""" @@ -133,21 +127,19 @@ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multipl --learnable_property object --placeholder_token --initializer_token a - --validation_prompt - --validation_steps 1 --save_steps 1 --num_vectors 2 --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 - --max_train_steps 4 + --max_train_steps 2 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} --checkpointing_steps=1 - --resume_from_checkpoint=checkpoint-3 + --resume_from_checkpoint=checkpoint-2 --checkpoints_total_limit=2 """.split() @@ -156,5 +148,5 @@ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multipl # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-3", "checkpoint-4"}, + {"checkpoint-2", "checkpoint-3"}, ) diff --git a/examples/unconditional_image_generation/test_unconditional.py b/examples/unconditional_image_generation/test_unconditional.py index b7e19abe9f6e..49e11f33d4e1 100644 --- a/examples/unconditional_image_generation/test_unconditional.py +++ b/examples/unconditional_image_generation/test_unconditional.py @@ -90,10 +90,10 @@ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_ch --train_batch_size 1 --num_epochs 1 --gradient_accumulation_steps 1 - --ddpm_num_inference_steps 2 + --ddpm_num_inference_steps 1 --learning_rate 1e-3 --lr_warmup_steps 5 - --checkpointing_steps=1 + --checkpointing_steps=2 """.split() run_command(self._launch_args + initial_run_args) @@ -101,7 +101,7 @@ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_ch # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6"}, ) resume_run_args = f""" @@ -113,12 +113,12 @@ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_ch --train_batch_size 1 --num_epochs 2 --gradient_accumulation_steps 1 - --ddpm_num_inference_steps 2 + --ddpm_num_inference_steps 1 --learning_rate 1e-3 --lr_warmup_steps 5 --resume_from_checkpoint=checkpoint-6 --checkpointing_steps=2 - --checkpoints_total_limit=3 + --checkpoints_total_limit=2 """.split() run_command(self._launch_args + resume_run_args) @@ -126,5 +126,5 @@ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_ch # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, + {"checkpoint-10", "checkpoint-12"}, ) From 84c403aedb967d68785e6e5ac359746f97a483cb Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Tue, 26 Dec 2023 00:46:57 +0900 Subject: [PATCH 39/42] fix: cannot set guidance_scale (#6326) fix: set guidance_scale --- examples/community/stable_diffusion_tensorrt_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py index e6e5e9db71d0..a391daf1062d 100755 --- a/examples/community/stable_diffusion_tensorrt_img2img.py +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -1004,7 +1004,7 @@ def __call__( """ self.generator = generator self.denoising_steps = num_inference_steps - self.guidance_scale = guidance_scale + self._guidance_scale = guidance_scale # Pre-compute latent input scales and linear multistep coefficients self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) From a3d31e3a3eed1465dd0eafef641a256118618d32 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 25 Dec 2023 07:59:20 -0800 Subject: [PATCH 40/42] Change LCM-LoRA README Script Example Learning Rates to 1e-4 (#6304) Change README LCM-LoRA example learning rates to 1e-4. --- examples/consistency_distillation/README.md | 2 +- examples/consistency_distillation/README_sdxl.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/README.md b/examples/consistency_distillation/README.md index d1c874147173..b8e88c741e2f 100644 --- a/examples/consistency_distillation/README.md +++ b/examples/consistency_distillation/README.md @@ -94,7 +94,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \ --mixed_precision=fp16 \ --resolution=512 \ --lora_rank=64 \ - --learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \ + --learning_rate=1e-4 --loss_type="huber" --adam_weight_decay=0.0 \ --max_train_steps=1000 \ --max_train_samples=4000000 \ --dataloader_num_workers=8 \ diff --git a/examples/consistency_distillation/README_sdxl.md b/examples/consistency_distillation/README_sdxl.md index 4d2177669a90..16d32bcc571e 100644 --- a/examples/consistency_distillation/README_sdxl.md +++ b/examples/consistency_distillation/README_sdxl.md @@ -96,7 +96,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \ --mixed_precision=fp16 \ --resolution=1024 \ --lora_rank=64 \ - --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \ + --learning_rate=1e-4 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \ --max_train_steps=1000 \ --max_train_samples=4000000 \ --dataloader_num_workers=8 \ From e0d8c910e95cba86d66e7410711c018949c3a2d3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Dec 2023 11:39:28 +0100 Subject: [PATCH 41/42] [Peft] fix saving / loading when unet is not "unet" (#6046) * [Peft] fix saving / loading when unet is not "unet" * Update src/diffusers/loaders/lora.py Co-authored-by: Sayak Paul * undo stablediffusion-xl changes * use unet_name to get unet for lora helpers * use unet_name --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/ip_adapter.py | 6 ++-- src/diffusers/loaders/lora.py | 46 ++++++++++++++++++----------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 158bde436374..3df0492380e5 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -149,9 +149,11 @@ def load_ip_adapter( self.feature_extractor = CLIPImageProcessor() # load ip-adapter into unet - self.unet._load_ip_adapter_weights(state_dict) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet._load_ip_adapter_weights(state_dict) def set_ip_adapter_scale(self, scale): - for attn_processor in self.unet.attn_processors.values(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): attn_processor.scale = scale diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index fc50c52e412b..2ceff743daca 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -912,10 +912,10 @@ def pack_weights(layers, prefix): ) if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) + state_dict.update(pack_weights(unet_lora_layers, cls.unet_name)) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) if transformer_lora_layers: state_dict.update(pack_weights(transformer_lora_layers, "transformer")) @@ -975,6 +975,8 @@ def unload_lora_weights(self): >>> ... ``` """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if not USE_PEFT_BACKEND: if version.parse(__version__) > version.parse("0.23"): logger.warn( @@ -982,13 +984,13 @@ def unload_lora_weights(self): "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." ) - for _, module in self.unet.named_modules(): + for _, module in unet.named_modules(): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) else: - recurse_remove_peft_layers(self.unet) - if hasattr(self.unet, "peft_config"): - del self.unet.peft_config + recurse_remove_peft_layers(unet) + if hasattr(unet, "peft_config"): + del unet.peft_config # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() @@ -1027,7 +1029,8 @@ def fuse_lora( ) if fuse_unet: - self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) if USE_PEFT_BACKEND: from peft.tuners.tuners_utils import BaseTunerLayer @@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: if not USE_PEFT_BACKEND: - self.unet.unfuse_lora() + unet.unfuse_lora() else: from peft.tuners.tuners_utils import BaseTunerLayer - for module in self.unet.modules(): + for module in unet.modules(): if isinstance(module, BaseTunerLayer): module.unmerge() @@ -1202,8 +1206,9 @@ def set_adapters( adapter_names: Union[List[str], str], adapter_weights: Optional[List[float]] = None, ): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - self.unet.set_adapters(adapter_names, adapter_weights) + unet.set_adapters(adapter_names, adapter_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): @@ -1216,7 +1221,8 @@ def disable_lora(self): raise ValueError("PEFT backend is required for this method.") # Disable unet adapters - self.unet.disable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.disable_lora() # Disable text encoder adapters if hasattr(self, "text_encoder"): @@ -1229,7 +1235,8 @@ def enable_lora(self): raise ValueError("PEFT backend is required for this method.") # Enable unet adapters - self.unet.enable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.enable_lora() # Enable text encoder adapters if hasattr(self, "text_encoder"): @@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): adapter_names = [adapter_names] # Delete unet adapters - self.unet.delete_adapters(adapter_names) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.delete_adapters(adapter_names) for adapter_name in adapter_names: # Delete text encoder adapters @@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]: from peft.tuners.tuners_utils import BaseTunerLayer active_adapters = [] - - for module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for module in unet.modules(): if isinstance(module, BaseTunerLayer): active_adapters = module.active_adapters break @@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]: if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) - if hasattr(self, "unet") and hasattr(self.unet, "peft_config"): - set_adapters["unet"] = list(self.unet.peft_config.keys()) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"): + set_adapters[self.unet_name] = list(self.unet.peft_config.keys()) return set_adapters @@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, from peft.tuners.tuners_utils import BaseTunerLayer # Handle the UNET - for unet_module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for unet_module in unet.modules(): if isinstance(unet_module, BaseTunerLayer): for adapter_name in adapter_names: unet_module.lora_A[adapter_name].to(device) From 35b81fffaea20cca3e870a834cecef7e52a7d1d9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Dec 2023 11:40:04 +0100 Subject: [PATCH 42/42] [Wuerstchen] fix fp16 training and correct lora args (#6245) fix fp16 training Co-authored-by: Sayak Paul --- .../text_to_image/train_text_to_image_lora_prior.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index 1e67f05abe7a..f1f6b3215201 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -527,9 +527,17 @@ def deepspeed_zero_init_disabled_context_manager(): # lora attn processor prior_lora_config = LoraConfig( - r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"] + r=args.rank, + lora_alpha=args.rank, + target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], ) + # Add adapter and make sure the trainable params are in float32. prior.add_adapter(prior_lora_config) + if args.mixed_precision == "fp16": + for param in prior.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir):