diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0c05f0ef7ffa..e0aea2ce1493 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -160,6 +160,8 @@ title: xFormers - local: optimization/tome title: Token merging + - local: optimization/deepcache + title: DeepCache title: General optimizations - sections: - local: using-diffusers/stable_diffusion_jax_how_to @@ -210,6 +212,8 @@ title: Textual Inversion - local: api/loaders/unet title: UNet + - local: api/loaders/peft + title: PEFT title: Loaders - sections: - local: api/models/overview diff --git a/docs/source/en/api/loaders/peft.md b/docs/source/en/api/loaders/peft.md new file mode 100644 index 000000000000..1aff389d3142 --- /dev/null +++ b/docs/source/en/api/loaders/peft.md @@ -0,0 +1,25 @@ + + +# PEFT + +Diffusers supports working with adapters (such as [LoRA](../../using-diffusers/loading_adapters)) via the [`peft` library](https://huggingface.co/docs/peft/index). We provide a `PeftAdapterMixin` class to handle this for modeling classes in Diffusers (such as [`UNet2DConditionModel`]). + + + +Refer to [this doc](../../tutorials/using_peft_for_inference.md) to get an overview of how to work with `peft` in Diffusers for inference. + + + +## PeftAdapterMixin + +[[autodoc]] loaders.peft.PeftAdapterMixin \ No newline at end of file diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md index 122ecb4d99e6..9a7b81069401 100644 --- a/docs/source/en/api/pipelines/amused.md +++ b/docs/source/en/api/pipelines/amused.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # aMUSEd -aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2302.05543) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen. +aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen. Amused is a lightweight text to image model based off of the [MUSE](https://arxiv.org/abs/2301.00704) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once. diff --git a/docs/source/en/optimization/deepcache.md b/docs/source/en/optimization/deepcache.md new file mode 100644 index 000000000000..3ad513357def --- /dev/null +++ b/docs/source/en/optimization/deepcache.md @@ -0,0 +1,62 @@ + + +# DeepCache +[DeepCache](https://huggingface.co/papers/2312.00858) accelerates [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`] by strategically caching and reusing high-level features while efficiently updating low-level features by taking advantage of the U-Net architecture. + +Start by installing [DeepCache](https://github.com/horseee/DeepCache): +```bash +pip install DeepCache +``` + +Then load and enable the [`DeepCacheSDHelper`](https://github.com/horseee/DeepCache#usage): + +```diff + import torch + from diffusers import StableDiffusionPipeline + pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda") + ++ from DeepCache import DeepCacheSDHelper ++ helper = DeepCacheSDHelper(pipe=pipe) ++ helper.set_params( ++ cache_interval=3, ++ cache_branch_id=0, ++ ) ++ helper.enable() + + image = pipe("a photo of an astronaut on a moon").images[0] +``` + +The `set_params` method accepts two arguments: `cache_interval` and `cache_branch_id`. 
`cache_interval` controls the frequency of feature caching, specified as the number of steps between each cache operation. `cache_branch_id` identifies which branch of the network (ordered from the shallowest to the deepest layer) performs the caching. +Opting for a lower `cache_branch_id` or a larger `cache_interval` can lead to faster inference at the expense of reduced image quality (ablation experiments for these two hyperparameters can be found in the [paper](https://arxiv.org/abs/2312.00858)). Once those arguments are set, use the `enable` or `disable` methods to activate or deactivate the `DeepCacheSDHelper`. +
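As a quick, hands-on way to see this trade-off, the sketch below (not part of the original DeepCache guide) times the same prompt with caching disabled and with two different `cache_interval` values. It only relies on the `set_params`, `enable`, and `disable` calls shown above, and assumes the helper can be enabled and disabled repeatedly, as the description above suggests.

```python
import time

import torch
from DeepCache import DeepCacheSDHelper
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
helper = DeepCacheSDHelper(pipe=pipe)

def timed_generation(prompt="a photo of an astronaut on a moon"):
    # synchronize so the measured time covers the whole denoising loop
    torch.cuda.synchronize()
    start = time.perf_counter()
    image = pipe(prompt).images[0]
    torch.cuda.synchronize()
    return image, time.perf_counter() - start

_, baseline = timed_generation()  # helper not enabled yet: original pipeline

for interval in (3, 5):
    helper.set_params(cache_interval=interval, cache_branch_id=0)
    helper.enable()
    image, latency = timed_generation()
    helper.disable()  # restore the original forward pass before the next run
    print(f"cache_interval={interval}: {latency:.2f}s (baseline {baseline:.2f}s)")
```

Inspect the generated images alongside the timings to decide which configuration gives an acceptable quality/latency trade-off for your use case.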
+ +
+ 
You can find more generated samples (original pipeline vs DeepCache) and the corresponding inference latency in the [WandB report](https://wandb.ai/horseee/DeepCache/runs/jwlsqqgt?workspace=user-horseee). The prompts are randomly selected from the [MS-COCO 2017](https://cocodataset.org/#home) dataset. + +## Benchmark + +We tested how much faster DeepCache accelerates [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) with 50 inference steps on an NVIDIA RTX A5000, using different configurations for resolution, batch size, cache interval (I), and cache branch (B). + +| **Resolution** | **Batch size** | **Original** | **DeepCache(I=3, B=0)** | **DeepCache(I=5, B=0)** | **DeepCache(I=5, B=1)** | +|----------------|----------------|--------------|-------------------------|-------------------------|-------------------------| +| 512| 8| 15.96| 6.88(2.32x)| 5.03(3.18x)| 7.27(2.20x)| +| | 4| 8.39| 3.60(2.33x)| 2.62(3.21x)| 3.75(2.24x)| +| | 1| 2.61| 1.12(2.33x)| 0.81(3.24x)| 1.11(2.35x)| +| 768| 8| 43.58| 18.99(2.29x)| 13.96(3.12x)| 21.27(2.05x)| +| | 4| 22.24| 9.67(2.30x)| 7.10(3.13x)| 10.74(2.07x)| +| | 1| 6.33| 2.72(2.33x)| 1.97(3.21x)| 2.98(2.12x)| +| 1024| 8| 101.95| 45.57(2.24x)| 33.72(3.02x)| 53.00(1.92x)| +| | 4| 49.25| 21.86(2.25x)| 16.19(3.04x)| 25.78(1.91x)| +| | 1| 13.83| 6.07(2.28x)| 4.43(3.12x)| 7.15(1.93x)| diff --git a/docs/source/en/tutorials/fast_diffusion.md b/docs/source/en/tutorials/fast_diffusion.md index f258e35f43d7..cc6db4af4259 100644 --- a/docs/source/en/tutorials/fast_diffusion.md +++ b/docs/source/en/tutorials/fast_diffusion.md @@ -14,15 +14,15 @@ specific language governing permissions and limitations under the License. Diffusion models are known to be slower than their counterparts, GANs, because of the iterative and sequential reverse diffusion process. Recent works try to address this limitation with: -* progressive timestep distillation (such as [LCM LoRA](../using-diffusers/inference_with_lcm_lora.md)) +* progressive timestep distillation (such as [LCM LoRA](../using-diffusers/inference_with_lcm_lora)) * model compression (such as [SSD-1B](https://huggingface.co/segmind/SSD-1B)) * reusing adjacent features of the denoiser (such as [DeepCache](https://github.com/horseee/DeepCache)) -In this tutorial, we focus on leveraging the power of PyTorch 2 to accelerate the inference latency of text-to-image diffusion pipeline, instead. We will use [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl.md) as a case study, but the techniques we will discuss should extend to other text-to-image diffusion pipelines. +In this tutorial, we instead focus on leveraging the power of PyTorch 2 to reduce the inference latency of a text-to-image diffusion pipeline. We will use [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl) as a case study, but the techniques we will discuss should extend to other text-to-image diffusion pipelines. ## Setup -Make sure you're on the latest version of `diffusers`: +Make sure you're on the latest version of `diffusers`: ```bash pip install -U diffusers @@ -42,7 +42,7 @@ _This tutorial doesn't present the benchmarking code and focuses on how to perfo ## Baseline -Let's start with a baseline. Disable the use of a reduced precision and [`scaled_dot_product_attention`](../optimization/torch2.0.md): +Let's start with a baseline. 
Disable the use of a reduced precision and [`scaled_dot_product_attention`](../optimization/torch2.0): ```python from diffusers import StableDiffusionXLPipeline @@ -104,11 +104,11 @@ _(We later ran the experiments in float16 and found out that the recent versions * The benefits of using the bfloat16 numerical precision as compared to float16 are hardware-dependent. Modern generations of GPUs tend to favor bfloat16. * Furthermore, in our experiments, we found bfloat16 to be much more resilient when used with quantization in comparison to float16. -We have a [dedicated guide](../optimization/fp16.md) for running inference in a reduced precision. +We have a [dedicated guide](../optimization/fp16) for running inference in a reduced precision. ## Running attention efficiently -Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0.md), we can run them efficiently. +Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0), we can run them efficiently. ```python from diffusers import StableDiffusionXLPipeline @@ -200,7 +200,7 @@ It provides a minor boost from 2.54 seconds to 2.52 seconds. -Support for `fuse_qkv_projections()` is limited and experimental. As such, it's not available for many non-SD pipelines such as [Kandinsky](../using-diffusers/kandinsky.md). You can refer to [this PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to support this kind of computation. +Support for `fuse_qkv_projections()` is limited and experimental. As such, it's not available for many non-SD pipelines such as [Kandinsky](../using-diffusers/kandinsky). You can refer to [this PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to support this kind of computation. diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md index 8b9beb0b2fd1..fa86a0e1bb3a 100644 --- a/docs/source/en/using-diffusers/svd.md +++ b/docs/source/en/using-diffusers/svd.md @@ -53,11 +53,6 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0] export_to_video(frames, "generated.mp4", fps=7) ``` - - | **Source Image** | **Video** | |:------------:|:-----:| | ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png) | ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket.gif) | @@ -86,7 +81,7 @@ You can achieve a 20-25% speed-up at the expense of slightly increased memory by Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement: - enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore. - enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size -- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. 
**Note that**, in addition to leading to a small slowdown, this method also slightly leads to video quality deterioration. You can enable them as follows: diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 58df031c3f83..ddd8114ae4f2 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -70,7 +70,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -1316,13 +1316,15 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1335,6 +1337,8 @@ def save_model_hook(models, weights, output_dir): text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors") def load_model_hook(models, input_dir): unet_ = None diff --git a/examples/community/README.md b/examples/community/README.md index c3aa1ecf3d64..3baab2025880 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -21,6 +21,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | | Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | | Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) | +| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. 
| [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) | | Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) | | Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) | | Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) | @@ -54,6 +55,8 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) | | AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) | | DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) | +| Null-Text Inversion Pipeline | Implements [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) | + To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -738,6 +741,49 @@ grid = image_grid(images, rows=2, cols=2) This example produces the following images: ![image](https://user-images.githubusercontent.com/4313860/198328706-295824a4-9856-4ce5-8e66-278ceb42fd29.png) +### GlueGen Stable Diffusion Pipeline +GlueGen is a minimal adapter that allows alignment between any encoder (a text encoder for a different language, Multilingual Roberta, AudioClip) and the CLIP text encoder used in the standard Stable Diffusion model. This method enables easy language adaptation of available English Stable Diffusion checkpoints without the need for an image captioning dataset or long training hours. 
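To make the adapter idea concrete, here is a minimal, illustrative sketch (an assumption for illustration only, not the actual GlueGen architecture — the community pipeline below relies on the deeper `TranslatorNoLN` module defined in `gluegen.py`): a small feed-forward "translator" maps token embeddings from a non-CLIP encoder such as XLM-RoBERTa (dimension 1024, matching the `dim=1024` argument in the usage example below) into the 768-dimensional embedding space the Stable Diffusion UNet expects.

```python
import torch
import torch.nn as nn


class ToyTranslator(nn.Module):
    """Hypothetical stand-in for the GlueGen adapter: projects encoder states
    from the source embedding space into the CLIP/UNet embedding space."""

    def __init__(self, dim_in=1024, dim_out=768, mult=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_in, dim_in * mult),
            nn.GELU(),
            nn.Linear(dim_in * mult, dim_out),
        )

    def forward(self, x):  # x: (batch, num_tokens, dim_in)
        return self.net(x)


# e.g. XLM-RoBERTa-large hidden states for a 77-token prompt in another language
roberta_embeds = torch.randn(1, 77, 1024)
clip_like_embeds = ToyTranslator()(roberta_embeds)
print(clip_like_embeds.shape)  # torch.Size([1, 77, 768])
```

The pretrained GlueGen checkpoints play exactly this role, which is why the pipeline below can swap in a multilingual text encoder while keeping the original Stable Diffusion UNet and VAE.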
+ +Make sure you downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French (there are also pre-trained weights for Chinese, Italian, Japanese, Spanish or train your own) at [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main) + +```python +from PIL import Image + +import torch + +from transformers import AutoModel, AutoTokenizer + +from diffusers import DiffusionPipeline + +if __name__ == "__main__": + device = "cuda" + + lm_model_id = "xlm-roberta-large" + token_max_length = 77 + + text_encoder = AutoModel.from_pretrained(lm_model_id) + tokenizer = AutoTokenizer.from_pretrained(lm_model_id, model_max_length=token_max_length, use_fast=False) + + tensor_norm = torch.Tensor([[43.8203],[28.3668],[27.9345],[28.0084],[28.2958],[28.2576],[28.3373],[28.2695],[28.4097],[28.2790],[28.2825],[28.2807],[28.2775],[28.2708],[28.2682],[28.2624],[28.2589],[28.2611],[28.2616],[28.2639],[28.2613],[28.2566],[28.2615],[28.2665],[28.2799],[28.2885],[28.2852],[28.2863],[28.2780],[28.2818],[28.2764],[28.2532],[28.2412],[28.2336],[28.2514],[28.2734],[28.2763],[28.2977],[28.2971],[28.2948],[28.2818],[28.2676],[28.2831],[28.2890],[28.2979],[28.2999],[28.3117],[28.3363],[28.3554],[28.3626],[28.3589],[28.3597],[28.3543],[28.3660],[28.3731],[28.3717],[28.3812],[28.3753],[28.3810],[28.3777],[28.3693],[28.3713],[28.3670],[28.3691],[28.3679],[28.3624],[28.3703],[28.3703],[28.3720],[28.3594],[28.3576],[28.3562],[28.3438],[28.3376],[28.3389],[28.3433],[28.3191]]) + + pipeline = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + text_encoder=text_encoder, + tokenizer=tokenizer, + custom_pipeline="gluegen" + ).to(device) + pipeline.load_language_adapter("gluenet_French_clip_overnorm_over3_noln.ckpt", num_token=token_max_length, dim=1024, dim_out=768, tensor_norm=tensor_norm) + + prompt = "une voiture sur la plage" + + generator = torch.Generator(device=device).manual_seed(42) + image = pipeline(prompt, generator=generator).images[0] + image.save("gluegen_output_fr.png") +``` +Which will produce: + +![output_image](https://github.com/rootonchair/diffusers/assets/23548268/db43ffb6-8667-47c1-8872-26f85dc0a57f) + ### Image to Image Inpainting Stable Diffusion Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument. @@ -3102,3 +3148,42 @@ output_image = PIL.Image.fromarray(output) output_image.save("./output.png") ``` + +### Null-Text Inversion pipeline + +This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline. +* Reference paper + ```@article{hertz2022prompt, + title={Prompt-to-prompt image editing with cross attention control}, + author={Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel}, + booktitle={arXiv preprint arXiv:2208.01626}, + year={2022} + ```} + +```py +from diffusers.schedulers import DDIMScheduler +from examples.community.pipeline_null_text_inversion import NullTextPipeline +import torch + +# Load the pipeline +device = "cuda" +# Provide invert_prompt and the image for null-text optimization. +invert_prompt = "A lying cat" +input_image = "siamese.jpg" +steps = 50 + +# Provide prompt used for generation. Same if reconstruction +prompt = "A lying cat" +# or different if editing. 
+prompt = "A lying dog" + +#Float32 is essential to a well optimization +model_path = "runwayml/stable-diffusion-v1-5" +scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear") +pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, torch_dtype=torch.float32).to(device) + +#Saves the inverted_latent to save time +inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps) +pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg") + +``` \ No newline at end of file diff --git a/examples/community/gluegen.py b/examples/community/gluegen.py new file mode 100644 index 000000000000..ecfe91eb9483 --- /dev/null +++ b/examples/community/gluegen.py @@ -0,0 +1,865 @@ +import inspect +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor + +from diffusers import DiffusionPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TranslatorBase(nn.Module): + def __init__(self, num_tok, dim, dim_out, mult=2): + super().__init__() + + self.dim_in = dim + self.dim_out = dim_out + + self.net_tok = nn.Sequential( + nn.Linear(num_tok, int(num_tok * mult)), + nn.LayerNorm(int(num_tok * mult)), + nn.GELU(), + nn.Linear(int(num_tok * mult), int(num_tok * mult)), + nn.LayerNorm(int(num_tok * mult)), + nn.GELU(), + nn.Linear(int(num_tok * mult), num_tok), + nn.LayerNorm(num_tok), + ) + + self.net_sen = nn.Sequential( + nn.Linear(dim, int(dim * mult)), + nn.LayerNorm(int(dim * mult)), + nn.GELU(), + nn.Linear(int(dim * mult), int(dim * mult)), + nn.LayerNorm(int(dim * mult)), + nn.GELU(), + nn.Linear(int(dim * mult), dim_out), + nn.LayerNorm(dim_out), + ) + + def forward(self, x): + if self.dim_in == self.dim_out: + indentity_0 = x + x = self.net_sen(x) + x += indentity_0 + x = x.transpose(1, 2) + + indentity_1 = x + x = self.net_tok(x) + x += indentity_1 + x = x.transpose(1, 2) + else: + x = self.net_sen(x) + x = x.transpose(1, 2) + + x = self.net_tok(x) + x = x.transpose(1, 2) + return x + + +class TranslatorBaseNoLN(nn.Module): + def __init__(self, num_tok, dim, dim_out, mult=2): + super().__init__() + + self.dim_in = dim + self.dim_out = dim_out + + self.net_tok = nn.Sequential( + nn.Linear(num_tok, int(num_tok * mult)), + nn.GELU(), + nn.Linear(int(num_tok * mult), int(num_tok * mult)), + nn.GELU(), + nn.Linear(int(num_tok * mult), num_tok), + ) + + self.net_sen = nn.Sequential( + nn.Linear(dim, int(dim * mult)), + nn.GELU(), + nn.Linear(int(dim * mult), int(dim * mult)), + nn.GELU(), + nn.Linear(int(dim * mult), dim_out), + ) + + def forward(self, x): + if self.dim_in == self.dim_out: + 
indentity_0 = x + x = self.net_sen(x) + x += indentity_0 + x = x.transpose(1, 2) + + indentity_1 = x + x = self.net_tok(x) + x += indentity_1 + x = x.transpose(1, 2) + else: + x = self.net_sen(x) + x = x.transpose(1, 2) + + x = self.net_tok(x) + x = x.transpose(1, 2) + return x + + +class TranslatorNoLN(nn.Module): + def __init__(self, num_tok, dim, dim_out, mult=2, depth=5): + super().__init__() + + self.blocks = nn.ModuleList([TranslatorBase(num_tok, dim, dim, mult=2) for d in range(depth)]) + self.gelu = nn.GELU() + + self.tail = TranslatorBaseNoLN(num_tok, dim, dim_out, mult=2) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x + x = self.gelu(x) + + x = self.tail(x) + return x + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class GlueGenStableDiffusionPipeline(DiffusionPipeline, LoraLoaderMixin): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + language_adapter: TranslatorNoLN = None, + tensor_norm: torch.FloatTensor = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + language_adapter=language_adapter, + tensor_norm=tensor_norm, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def load_language_adapter( + self, + model_path: str, + num_token: int, + dim: int, + dim_out: int, + tensor_norm: torch.FloatTensor, + mult: int = 2, + depth: int = 5, + ): + device = self._execution_device + self.tensor_norm = tensor_norm.to(device) + self.language_adapter = TranslatorNoLN(num_tok=num_token, dim=dim, dim_out=dim_out, mult=mult, depth=depth).to( + device + ) + self.language_adapter.load_state_dict(torch.load(model_path)) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _adapt_language(self, prompt_embeds: torch.FloatTensor): + prompt_embeds = prompt_embeds / 3 + prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2) + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + elif self.language_adapter is not None: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. 
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + # Run prompt language adapter + if self.language_adapter is not None: + prompt_embeds = self._adapt_language(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + # Run negative prompt language adapter + if self.language_adapter is not None: + negative_prompt_embeds = self._adapt_language(negative_prompt_embeds) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index 31da842112fb..05fba7daa1a9 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -40,7 +40,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.20.1.dev0") +check_min_version("0.26.0.dev0") class MarigoldDepthOutput(BaseOutput): diff --git a/examples/community/pipeline_null_text_inversion.py b/examples/community/pipeline_null_text_inversion.py new file mode 100644 index 000000000000..7e27b4647bc9 --- /dev/null +++ b/examples/community/pipeline_null_text_inversion.py @@ -0,0 +1,260 @@ +import inspect +import os + +import numpy as np +import torch +import torch.nn.functional as nnf +from PIL import Image +from torch.optim.adam import Adam +from tqdm import tqdm + +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. 
Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class NullTextPipeline(StableDiffusionPipeline): + def get_noise_pred(self, latents, t, context): + latents_input = torch.cat([latents] * 2) + guidance_scale = 7.5 + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"] + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + latents = self.prev_step(noise_pred, t, latents) + return latents + + def get_noise_pred_single(self, latents, t, context): + noise_pred = self.unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + @torch.no_grad() + def image2latent(self, image_path): + image = Image.open(image_path).convert("RGB") + image = np.array(image) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0).to(self.device) + latents = self.vae.encode(image)["latent_dist"].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def latent2image(self, latents): + latents = 1 / 0.18215 * latents.detach() + image = self.vae.decode(latents)["sample"].detach() + image = self.processor.postprocess(image, output_type="pil")[0] + return image + + def prev_step(self, model_output, timestep, sample): + prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output + prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + return prev_sample + + def next_step(self, model_output, timestep, sample): + timestep, next_timestep = ( + 
min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999), + timestep, + ) + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction + return next_sample + + def null_optimization(self, latents, context, num_inner_steps, epsilon): + uncond_embeddings, cond_embeddings = context.chunk(2) + uncond_embeddings_list = [] + latent_cur = latents[-1] + bar = tqdm(total=num_inner_steps * self.num_inference_steps) + for i in range(self.num_inference_steps): + uncond_embeddings = uncond_embeddings.clone().detach() + uncond_embeddings.requires_grad = True + optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0)) + latent_prev = latents[len(latents) - i - 2] + t = self.scheduler.timesteps[i] + with torch.no_grad(): + noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) + for j in range(num_inner_steps): + noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) + noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond) + latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) + loss = nnf.mse_loss(latents_prev_rec, latent_prev) + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_item = loss.item() + bar.update() + if loss_item < epsilon + i * 2e-5: + break + for j in range(j + 1, num_inner_steps): + bar.update() + uncond_embeddings_list.append(uncond_embeddings[:1].detach()) + with torch.no_grad(): + context = torch.cat([uncond_embeddings, cond_embeddings]) + latent_cur = self.get_noise_pred(latent_cur, t, context) + bar.close() + return uncond_embeddings_list + + @torch.no_grad() + def ddim_inversion_loop(self, latent, context): + self.scheduler.set_timesteps(self.num_inference_steps) + _, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + with torch.no_grad(): + for i in range(0, self.num_inference_steps): + t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1] + noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"] + latent = self.next_step(noise_pred, t, latent) + all_latent.append(latent) + return all_latent + + def get_context(self, prompt): + uncond_input = self.tokenizer( + [""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + text_input = self.tokenizer( + [prompt], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + return context + + def invert( + self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50 + ): + self.num_inference_steps = num_inference_steps + context = self.get_context(prompt) + latent = self.image2latent(image_path) + ddim_latents = self.ddim_inversion_loop(latent, context) + if os.path.exists(image_path + ".pt"): + uncond_embeddings = torch.load(image_path + ".pt") + else: + uncond_embeddings 
= self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon) + uncond_embeddings = torch.stack(uncond_embeddings, 0) + torch.save(uncond_embeddings, image_path + ".pt") + return ddim_latents[-1], uncond_embeddings + + @torch.no_grad() + def __call__( + self, + prompt, + uncond_embeddings, + inverted_latent, + num_inference_steps: int = 50, + timesteps=None, + guidance_scale=7.5, + negative_prompt=None, + num_images_per_prompt=1, + generator=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + output_type="pil", + ): + self._guidance_scale = guidance_scale + # 0. Default height and width to unet + height = self.unet.config.sample_size * self.vae_scale_factor + width = self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hook + callback_steps = None + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + # 2. Define call parameter + device = self._execution_device + # 3. Encode input prompt + prompt_embeds, _ = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + latents = inverted_latent + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])["sample"] + noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)["sample"] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + progress_bar.update() + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + else: + image = latents + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=[True] * image.shape[0] + ) + # Offload all models + self.maybe_free_model_hooks() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index c85e2c462b04..d3d0d947ecc7 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -71,7 +72,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
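The community pipeline added above (`pipeline_null_text_inversion.py`) exposes two entry points: `invert`, which DDIM-inverts an input image and optimizes the null-text (unconditional) embeddings against it, and `__call__`, which regenerates the image from the inverted latent using those embeddings. As a rough usage sketch — assuming the file is reachable via `custom_pipeline`, that a `DDIMScheduler` is set (the inversion math relies on `alphas_cumprod`/`final_alpha_cumprod`), and with a placeholder image path and prompt:

```python
import torch
from diffusers import DDIMScheduler, DiffusionPipeline

# Load the community pipeline by file name; float32 keeps the null-text optimization stable.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="pipeline_null_text_inversion",
    torch_dtype=torch.float32,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a photo of a cat sitting on a bench"  # placeholder prompt describing input.png

# 1. DDIM-invert the image and optimize the null-text embeddings
#    (the optimized embeddings are cached next to the image as input.png.pt).
inverted_latent, uncond_embeddings = pipe.invert("input.png", prompt, num_inference_steps=50)

# 2. Reconstruct (or edit, by changing the prompt) from the inverted latent.
image = pipe(prompt, uncond_embeddings, inverted_latent, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("null_text_reconstruction.png")
```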
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -165,6 +166,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -174,10 +176,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -353,8 +357,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -572,6 +577,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -710,6 +724,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -915,9 +973,10 @@ def main(args): ) # 8. 
Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -932,7 +991,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) @@ -1051,6 +1115,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1162,10 +1227,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1181,9 +1246,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 95e7b2dbaa27..a63654bc998b 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -51,6 +51,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -59,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
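A note on the new `--lora_alpha` and `--lora_dropout` flags wired into `LoraConfig` above: in PEFT the LoRA update is applied with an effective scale of `lora_alpha / r`, so the defaults (`lora_alpha=64`, `lora_rank=64`) give a scale of 1.0, i.e. no extra scaling, exactly as the help text states; other values shrink or amplify the learned delta. A small illustration (values are illustrative, not recommendations):

```python
from peft import LoraConfig


def effective_lora_scale(rank: int, alpha: int) -> float:
    # PEFT applies the low-rank update as (alpha / r) * (B @ A) on top of the frozen weight.
    return alpha / rank


print(effective_lora_scale(64, 64))  # 1.0 -> the "no scaling" default described in the help text
print(effective_lora_scale(64, 32))  # 0.5 -> the learned update is halved

# The CLI flags map directly onto the config built in the hunk above:
lora_config = LoraConfig(r=64, lora_alpha=64, lora_dropout=0.0, target_modules=["to_q", "to_k", "to_v"])
```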
-check_min_version("0.24.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -193,8 +194,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -396,6 +398,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -534,6 +545,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -776,10 +831,10 @@ def main(args): text_encoder_two.to(accelerator.device, dtype=weight_dtype) # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -794,7 +849,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet.add_adapter(lora_config) @@ -929,7 +989,8 @@ def load_model_hook(models, input_dir): ) # Preprocessing the datasets. 
- train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + interpolation_mode = resolve_interpolation_mode(args.interpolation_type) + train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) @@ -1121,11 +1182,11 @@ def compute_time_ids(original_size, crops_coords_top_left): encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size pixel_values = pixel_values.to(dtype=vae.dtype) latents = [] - for i in range(0, pixel_values.shape[0], args.encode_batch_size): - latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1142,9 +1203,13 @@ def compute_time_ids(original_size, crops_coords_top_left): timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 75671c18c5e0..7e2a3604af55 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -62,6 +62,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -77,7 +78,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
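The other recurring change is the new `--interpolation_type` flag: instead of hard-coding `InterpolationMode.BILINEAR`, the scripts now call `resolve_interpolation_mode` from `diffusers.training_utils` to map the CLI string onto a `torchvision.transforms.InterpolationMode` member. Conceptually the helper behaves roughly like the simplified stand-in below (not the library implementation; the supported names are the ones listed in the flag's help text):

```python
from torchvision import transforms


def resolve_interpolation_mode_sketch(interpolation_type: str) -> transforms.InterpolationMode:
    # Turn a CLI string into a torchvision enum, rejecting anything torchvision does not support.
    supported = {
        "bilinear": transforms.InterpolationMode.BILINEAR,
        "bicubic": transforms.InterpolationMode.BICUBIC,
        "box": transforms.InterpolationMode.BOX,
        "nearest": transforms.InterpolationMode.NEAREST,
        "nearest_exact": transforms.InterpolationMode.NEAREST_EXACT,
        "hamming": transforms.InterpolationMode.HAMMING,
        "lanczos": transforms.InterpolationMode.LANCZOS,
    }
    if interpolation_type not in supported:
        raise ValueError(f"Unsupported interpolation type: {interpolation_type}")
    return supported[interpolation_type]


# e.g. --interpolation_type bicubic
train_resize = transforms.Resize(512, interpolation=resolve_interpolation_mode_sketch("bicubic"))
```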
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -171,6 +172,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -187,10 +189,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -340,8 +344,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -546,6 +551,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -690,6 +704,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -929,9 +987,10 @@ def main(args): ) # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. 
- lora_config = LoraConfig( - r=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -946,7 +1005,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) @@ -1090,6 +1154,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1214,10 +1279,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1234,9 +1299,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index b2085b7044ba..3a07b3cf34dc 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -60,6 +60,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -70,7 +71,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -147,6 +148,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -156,10 +158,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -330,8 +334,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -549,6 +554,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -690,6 +704,26 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." 
+ ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1034,6 +1068,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1145,10 +1180,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1164,9 +1199,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index ee86def673fa..5d2442a4e41a 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -76,7 +77,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -153,6 +154,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -169,10 +171,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -318,8 +322,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -568,6 +573,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -715,6 +729,26 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." 
+ ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1118,6 +1152,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1242,10 +1277,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1262,9 +1297,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2ba6565abf81..2e610a6fe404 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -56,7 +56,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index b3c09325fc4d..1a637395d850 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -59,7 +59,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 70d9c52eefea..e8d7607fdb4c 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 623ce3704af6..e48bd2586299 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -62,7 +62,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c37ee87d8899..f652b1e79bcc 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 6244fa1c1a9c..c2fbcb01fa60 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -35,7 +35,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index f0bc78da3f49..8d684b0ad27f 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -59,7 +59,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index aa6e7d21aa73..0b7bd64e9091 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -59,7 +59,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index ae0f216b97cd..78cb7bc2f9d0 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index e764aa63bd4d..de59cb1f0be9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -55,7 +55,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index a4017c85e1b5..0a3a130a27bc 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 90cf540c6425..ed06bff29403 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index b64986ecf5ae..18f1fd8c8bc0 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index a6855abcee75..e9278260fad3 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/research_projects/diffusion_dpo/README.md b/examples/research_projects/diffusion_dpo/README.md new file mode 100644 index 000000000000..c80d97ea2629 --- /dev/null +++ b/examples/research_projects/diffusion_dpo/README.md @@ -0,0 +1,66 @@ +# Diffusion Model Alignment Using Direct Preference Optimization + +This directory provides LoRA implementations of Diffusion DPO proposed in [DiffusionModel Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik. + +We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below: + +* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1) +* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1) + +> 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues. 
+ +## SD training command + +```bash +accelerate launch train_diffusion_dpo.py \ + --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \ + --output_dir="diffusion-dpo" \ + --mixed_precision="fp16" \ + --dataset_name=kashif/pickascore \ + --resolution=512 \ + --train_batch_size=16 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --rank=8 \ + --learning_rate=1e-5 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=10000 \ + --checkpointing_steps=2000 \ + --run_validation --validation_steps=200 \ + --seed="0" \ + --report_to="wandb" \ + --push_to_hub +``` + +## SDXL training command + +```bash +accelerate launch train_diffusion_dpo_sdxl.py \ + --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \ + --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \ + --output_dir="diffusion-sdxl-dpo" \ + --mixed_precision="fp16" \ + --dataset_name=kashif/pickascore \ + --train_batch_size=8 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --rank=8 \ + --learning_rate=1e-5 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=2000 \ + --checkpointing_steps=500 \ + --run_validation --validation_steps=50 \ + --seed="0" \ + --report_to="wandb" \ + --push_to_hub +``` + +## Acknowledgements + +This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/. diff --git a/examples/research_projects/diffusion_dpo/requirements.txt b/examples/research_projects/diffusion_dpo/requirements.txt new file mode 100644 index 000000000000..11a5b9df6dad --- /dev/null +++ b/examples/research_projects/diffusion_dpo/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +peft +wandb \ No newline at end of file diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py new file mode 100644 index 000000000000..06bdd6af6054 --- /dev/null +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -0,0 +1,953 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 bram-w, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import contextlib +import io +import logging +import math +import os +import shutil +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +import wandb +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +VALIDATION_PROMPTS = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", +] + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False): + logger.info(f"Running validation... 
\n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if not is_final_validation: + pipeline.unet = accelerator.unwrap_model(unet) + else: + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [] + context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast() + + for prompt in VALIDATION_PROMPTS: + with context: + image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + tracker_key: [ + wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images) + ] + } + ) + + # Also log images without the LoRA params for comparison. + if is_final_validation: + pipeline.disable_lora() + no_lora_images = [ + pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in no_lora_images]) + tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test_without_lora": [ + wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") + for i, image in enumerate(no_lora_images) + ] + } + ) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_split_name", + type=str, + default="validation", + help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--run_validation", + default=False, + action="store_true", + help="Whether to run validation inference in between training and also after training. 
Helps to track progress.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="diffusion-dpo-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + help="Batch size to use for VAE encoding of the images for efficient processing.", + ) + parser.add_argument( + "--no_hflip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--random_crop", + default=False, + action="store_true", + help=( + "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--beta_dpo", + type=int, + default=5000, + help="DPO KL Divergence penalty.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None: + raise ValueError("Must provide a `dataset_name`.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def tokenize_captions(tokenizer, examples): + max_length = tokenizer.model_max_length + captions = [] + for caption in examples["caption"]: + captions.append(caption) + + text_inputs = tokenizer( + captions, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt" + ) + + return text_inputs.input_ids + + +@torch.no_grad() +def encode_prompt(text_encoder, input_ids): + text_input_ids = input_ids.to(text_encoder.device) + attention_mask = None + + prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Set up LoRA. + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + # Add adapter and make sure the trainable params are in float32. + unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. 
Either they are just the unet attention processor (LoRA) layers,
+            # or both the unet and the text encoder attention layers.
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            LoraLoaderMixin.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=None,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = load_dataset( + args.dataset_name, + cache_dir=args.cache_dir, + split=args.dataset_split_name, + ) + + train_transforms = transforms.Compose( + [ + transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution), + transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + all_pixel_values = [] + for col_name in ["jpg_0", "jpg_1"]: + images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]] + pixel_values = [train_transforms(image) for image in images] + all_pixel_values.append(pixel_values) + + # Double on channel dim, jpg_y then jpg_w + im_tup_iterator = zip(*all_pixel_values) + combined_pixel_values = [] + for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]): + if label_0 == 0: + im_tup = im_tup[::-1] + combined_im = torch.cat(im_tup, dim=0) # no batch dim + combined_pixel_values.append(combined_im) + examples["pixel_values"] = combined_pixel_values + + examples["input_ids"] = tokenize_captions(tokenizer, examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = train_dataset.with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + final_dict = {"pixel_values": pixel_values} + final_dict["input_ids"] = torch.stack([example["input_ids"] for example in examples]) + return final_dict + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
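+    # (Editor's note, added for clarity) `accelerator.prepare` typically shards the dataloader across
+    # processes, so its per-process length may have shrunk; hence the recalculation below.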
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are automatically initialized on the main process.
+    if accelerator.is_main_process:
+        accelerator.init_trackers("diffusion-dpo-lora", config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process, + ) + + unet.train() + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w) + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1)) + + latents = [] + for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size): + latents.append( + vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample() + ) + latents = torch.cat(latents, dim=0) + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1) + + # Sample a random timestep for each image + bsz = latents.shape[0] // 2 + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long + ).repeat(2) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1) + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + encoder_hidden_states, + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Compute losses. + model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none") + model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape)))) + model_losses_w, model_losses_l = model_losses.chunk(2) + + # For logging + raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean()) + model_diff = model_losses_w - model_losses_l # These are both LBS (as is t) + + # Reference model predictions. + accelerator.unwrap_model(unet).disable_adapters() + with torch.no_grad(): + ref_preds = unet( + noisy_model_input, + timesteps, + encoder_hidden_states, + ).sample.detach() + ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none") + ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) + + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean() + + # Re-enable adapters. + accelerator.unwrap_model(unet).enable_adapters() + + # Final loss. 
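+                # (Editor's note, added for clarity) Written out, the loss computed below is the
+                # Diffusion-DPO objective
+                #   loss = -log(sigmoid(-0.5 * beta_dpo * ((model_w - model_l) - (ref_w - ref_l)))),
+                # where *_w / *_l are the per-sample denoising MSEs on the preferred / rejected images.
+                # Minimizing it pushes the LoRA model to improve on the preferred sample, relative to the
+                # frozen reference, more than on the rejected one.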
+ scale_term = -0.5 * args.beta_dpo + inside_term = scale_term * (model_diff - ref_diff) + loss = -1 * F.logsigmoid(inside_term.mean()) + + implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) + implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.run_validation and global_step % args.validation_steps == 0: + log_validation( + args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch + ) + + logs = { + "loss": loss.detach().item(), + "raw_model_loss": raw_model_loss.detach().item(), + "ref_loss": raw_ref_loss.detach().item(), + "implicit_acc": implicit_acc.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None + ) + + # Final validation? 
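+        # (Editor's note, assumption) `unet=None` is passed because, for the final run, `log_validation`
+        # is expected to load the LoRA weights just saved to `args.output_dir` into a fresh pipeline.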
+ if args.run_validation: + log_validation( + args, + unet=None, + accelerator=accelerator, + weight_dtype=weight_dtype, + epoch=epoch, + is_final_validation=True, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py new file mode 100644 index 000000000000..0d109f741fb7 --- /dev/null +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -0,0 +1,1073 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 bram-w, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import io +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +import wandb +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import StableDiffusionXLLoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +VALIDATION_PROMPTS = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", +] + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): + logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") + + if is_final_validation: + if args.mixed_precision == "fp16": + vae.to(weight_dtype) + + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if not is_final_validation: + pipeline.unet = accelerator.unwrap_model(unet) + else: + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [] + context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast() + + for prompt in VALIDATION_PROMPTS: + with context: + image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + tracker_key: [ + wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images) + ] + } + ) + + # Also log images without the LoRA params for comparison. 
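+    # (Editor's note, added for clarity) `disable_lora()` strips the effect of the loaded LoRA weights so
+    # the base model's outputs can be logged side by side under the "test_without_lora" key.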
+ if is_final_validation: + pipeline.disable_lora() + no_lora_images = [ + pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in no_lora_images]) + tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test_without_lora": [ + wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") + for i, image in enumerate(no_lora_images) + ] + } + ) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_split_name", + type=str, + default="validation", + help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--run_validation", + default=False, + action="store_true", + help="Whether to run validation inference in between training and also after training. Helps to track progress.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="diffusion-dpo-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + help="Batch size to use for VAE encoding of the images for efficient processing.", + ) + parser.add_argument( + "--no_hflip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--random_crop", + default=False, + action="store_true", + help=( + "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--beta_dpo", + type=int, + default=5000, + help="DPO KL Divergence penalty.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None: + raise ValueError("Must provide a `dataset_name`.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def tokenize_captions(tokenizers, examples): + captions = [] + for caption in examples["caption"]: + captions.append(caption) + + tokens_one = tokenizers[0]( + captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt" + ).input_ids + tokens_two = tokenizers[1]( + captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt" + ).input_ids + + return tokens_one, tokens_two + + +@torch.no_grad() +def encode_prompt(text_encoders, text_input_ids_list): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet and text_encoders to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # The VAE is always in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + + # Set up LoRA. + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + # Add adapter and make sure the trainable params are in float32. 
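+    # (Editor's note, added for clarity) Under fp16 mixed precision the frozen weights stay in half
+    # precision, but the trainable LoRA parameters are upcast to fp32 below so optimizer updates do not
+    # lose precision or underflow.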
+ unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + StableDiffusionXLLoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save) + + def load_model_hook(models, input_dir): + unet_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) + StableDiffusionXLLoraLoaderMixin.load_lora_into_unet( + lora_state_dict, network_alphas=network_alphas, unet=unet_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = load_dataset( + args.dataset_name, + cache_dir=args.cache_dir, + split=args.dataset_split_name, + ) + + # Preprocessing the datasets. + train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + + def preprocess_train(examples): + all_pixel_values = [] + images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]] + original_sizes = [(image.height, image.width) for image in images] + crop_top_lefts = [] + + for col_name in ["jpg_0", "jpg_1"]: + images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]] + if col_name == "jpg_1": + # Need to bring down the image to the same resolution. + # This seems like the simplest reasonable approach. + # "::-1" because PIL resize takes (width, height). + images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)] + pixel_values = [to_tensor(image) for image in images] + all_pixel_values.append(pixel_values) + + # Double on channel dim, jpg_y then jpg_w + im_tup_iterator = zip(*all_pixel_values) + combined_pixel_values = [] + for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]): + if label_0 == 0: + im_tup = im_tup[::-1] + + combined_im = torch.cat(im_tup, dim=0) # no batch dim + + # Resize. + combined_im = train_resize(combined_im) + + # Cropping. + if not args.random_crop: + y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0))) + x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0))) + combined_im = train_crop(combined_im) + else: + y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution)) + combined_im = crop(combined_im, y1, x1, h, w) + + # Flipping. 
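+            # (Editor's note, assumption) When a flip is applied, the recorded crop x-coordinate is
+            # mirrored so that the `crop_top_lefts` micro-conditioning roughly corresponds to the
+            # flipped image.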
+ if random.random() < 0.5: + x1 = combined_im.shape[2] - x1 + combined_im = train_flip(combined_im) + + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + combined_im = normalize(combined_im) + combined_pixel_values.append(combined_im) + + examples["pixel_values"] = combined_pixel_values + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples) + examples["input_ids_one"] = tokens_one + examples["input_ids_two"] = tokens_two + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = train_dataset.with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = [example["crop_top_lefts"] for example in examples] + input_ids_one = torch.stack([example["input_ids_one"] for example in examples]) + input_ids_two = torch.stack([example["input_ids_two"] for example in examples]) + + return { + "pixel_values": pixel_values, + "input_ids_one": input_ids_one, + "input_ids_two": input_ids_two, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args)) + + # Train! 
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process, + ) + + unet.train() + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w) + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1)) + + latents = [] + for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size): + latents.append( + vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample() + ) + latents = torch.cat(latents, dim=0) + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1) + + # Sample a random timestep for each image + bsz = latents.shape[0] // 2 + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long + ).repeat(2) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ).repeat(2, 1) + + # Get the text embedding for conditioning + prompt_embeds, pooled_prompt_embeds = encode_prompt( + [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]] + ) + prompt_embeds = prompt_embeds.repeat(2, 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1) + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}, + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Compute losses. + model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none") + model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape)))) + model_losses_w, model_losses_l = model_losses.chunk(2) + + # For logging + raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean()) + model_diff = model_losses_w - model_losses_l # These are both LBS (as is t) + + # Reference model predictions. 
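+                # (Editor's note, added for clarity) The frozen DPO "reference model" is obtained by simply
+                # disabling the LoRA adapters on the same UNet, so no second copy of the base model has to
+                # be kept in memory.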
+ accelerator.unwrap_model(unet).disable_adapters() + with torch.no_grad(): + ref_preds = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}, + ).sample + ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none") + ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) + + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean() + + # Re-enable adapters. + accelerator.unwrap_model(unet).enable_adapters() + + # Final loss. + scale_term = -0.5 * args.beta_dpo + inside_term = scale_term * (model_diff - ref_diff) + loss = -1 * F.logsigmoid(inside_term.mean()) + + implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) + implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.run_validation and global_step % args.validation_steps == 0: + log_validation( + args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch + ) + + logs = { + "loss": loss.detach().item(), + "raw_model_loss": raw_model_loss.detach().item(), + "ref_loss": raw_ref_loss.detach().item(), + "implicit_acc": implicit_acc.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + StableDiffusionXLLoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_state_dict, + text_encoder_lora_layers=None, + text_encoder_2_lora_layers=None, + ) + + # Final validation? 
+ if args.run_validation: + log_validation( + args, + unet=None, + vae=vae, + accelerator=accelerator, + weight_dtype=weight_dtype, + epoch=epoch, + is_final_validation=True, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index b1c554a7a2ed..e06fc227af80 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c5371c95469e..27d7bc54432f 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -53,7 +53,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index c0d2f639883c..2a29767bf75d 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -33,7 +33,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index f6575f7494ca..27dedc8f7fd1 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -49,7 +49,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 81dbd0f1f612..606a88f55b32 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -57,7 +57,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 5c024f4080ae..0bb57b1f3126 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -57,7 +57,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
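The remainder of this hunk, and the run of small diffs that follows, bumps the `check_min_version` guard in every example script to `0.26.0.dev0`. For illustration, the guard is just a fail-fast version check at the top of each script (sketch assuming the standard `diffusers.utils` helper):

```python
# Illustrative only: the guard raises if the installed diffusers release is older
# than the requested version, so dev-branch example scripts fail fast on a stale
# installation instead of erroring partway through training.
from diffusers.utils import check_min_version

check_min_version("0.26.0.dev0")
```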
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 50bcc992064d..5c32f5bb7019 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -79,7 +79,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__) @@ -341,7 +341,7 @@ def parse_args(): help=( "Whether to use mixed precision. Choose" "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." + "and Nvidia Ampere GPU or Intel Gen 4 Xeon (and later) ." ), ) parser.add_argument( diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index bce70223eb3b..97050bf5a682 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index a721c605fdea..d410c9981b6b 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -29,7 +29,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index f1f6b3215201..7a53125573fa 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -50,7 +50,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 1ae5092ad10c..df89b8afc843 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.25.0.dev0") +check_min_version("0.26.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/setup.py b/setup.py index c0bfbd6efae5..177c918d38ca 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.19.4", + "huggingface-hub>=0.20.2", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", @@ -251,7 +251,7 @@ def run(self): setup( name="diffusers", - version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.26.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="State-of-the-art diffusion in PyTorch and JAX.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 180b210953c1..3f0d02119b89 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.25.0.dev0" +__version__ = "0.26.0.dev0" from typing import TYPE_CHECKING diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 7891984b0c5d..03e8fe7a0a00 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -9,7 +9,7 @@ "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.19.4", + "huggingface-hub": "huggingface-hub>=0.20.2", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 45c8c97c76eb..d7855206a287 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate -from ..utils.import_utils import is_torch_available, is_transformers_available +from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available def text_encoder_lora_state_dict(text_encoder): @@ -64,6 +64,8 @@ def text_encoder_attn_modules(text_encoder): _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] +_import_structure["peft"] = ["PeftAdapterMixin"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): @@ -76,6 +78,8 @@ def text_encoder_attn_modules(text_encoder): from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin + + from .peft import PeftAdapterMixin else: import sys diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py new file mode 100644 index 000000000000..a7f1e1c938f0 --- /dev/null +++ b/src/diffusers/loaders/peft.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available + + +class PeftAdapterMixin: + """ + A class containing all functions for loading and using adapters weights that are supported in PEFT library. For + more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT + library: https://huggingface.co/docs/peft/index. + + + With this mixin, if the correct PEFT version is installed, it is possible to: + + - Attach new adapters in the model. + - Attach multiple adapters and iteratively activate / deactivate them. + - Activate / deactivate all adapters from the model. + - Get a list of the active adapters. + """ + + _hf_peft_config_loaded = False + + def add_adapter(self, adapter_config, adapter_name: str = "default") -> None: + r""" + Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned + to the adapter to follow the convention of the PEFT library. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT + [documentation](https://huggingface.co/docs/peft). + + Args: + adapter_config (`[~peft.PeftConfig]`): + The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt + methods. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + from peft import PeftConfig, inject_adapter_in_model + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." + ) + + # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is + # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here. + adapter_config.base_model_name_or_path = None + inject_adapter_in_model(adapter_config, self, adapter_name) + self.set_adapter(adapter_name) + + def set_adapter(self, adapter_name: Union[str, List[str]]) -> None: + """ + Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Args: + adapter_name (Union[str, List[str]])): + The list of adapters to set or the adapter name in case of single adapter. + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. 
Please load an adapter first.") + + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + missing = set(adapter_name) - set(self.peft_config) + if len(missing) > 0: + raise ValueError( + f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." + f" current loaded adapters are: {list(self.peft_config.keys())}" + ) + + from peft.tuners.tuners_utils import BaseTunerLayer + + _adapters_has_been_set = False + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + # Previous versions of PEFT does not support multi-adapter inference + elif not hasattr(module, "set_adapter") and len(adapter_name) != 1: + raise ValueError( + "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT." + " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`" + ) + else: + module.active_adapter = adapter_name + _adapters_has_been_set = True + + if not _adapters_has_been_set: + raise ValueError( + "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." + ) + + def disable_adapters(self) -> None: + r""" + Disable all adapters attached to the model and fallback to inference with the base model only. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + # support for older PEFT versions + module.disable_adapters = True + + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the + list of adapters to enable. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + # support for older PEFT versions + module.disable_adapters = False + + def active_adapters(self) -> List[str]: + """ + Gets the current list of active adapters of the model. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. 
Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + return module.active_adapter diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 546c5b20f937..445c3ca71caf 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,12 +32,10 @@ from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, - MIN_PEFT_VERSION, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, - check_peft_version, deprecate, is_accelerate_available, is_torch_version, @@ -197,7 +195,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None - _hf_peft_config_loaded = False def __init__(self): super().__init__() @@ -303,153 +300,6 @@ def disable_xformers_memory_efficient_attention(self) -> None: """ self.set_use_memory_efficient_attention_xformers(False) - def add_adapter(self, adapter_config, adapter_name: str = "default") -> None: - r""" - Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned - to the adapter to follow the convention of the PEFT library. - - If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT - [documentation](https://huggingface.co/docs/peft). - - Args: - adapter_config (`[~peft.PeftConfig]`): - The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt - methods. - adapter_name (`str`, *optional*, defaults to `"default"`): - The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. - """ - check_peft_version(min_version=MIN_PEFT_VERSION) - - from peft import PeftConfig, inject_adapter_in_model - - if not self._hf_peft_config_loaded: - self._hf_peft_config_loaded = True - elif adapter_name in self.peft_config: - raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") - - if not isinstance(adapter_config, PeftConfig): - raise ValueError( - f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." - ) - - # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is - # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here. - adapter_config.base_model_name_or_path = None - inject_adapter_in_model(adapter_config, self, adapter_name) - self.set_adapter(adapter_name) - - def set_adapter(self, adapter_name: Union[str, List[str]]) -> None: - """ - Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters. - - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Args: - adapter_name (Union[str, List[str]])): - The list of adapters to set or the adapter name in case of single adapter. - """ - check_peft_version(min_version=MIN_PEFT_VERSION) - - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. 
Please load an adapter first.") - - if isinstance(adapter_name, str): - adapter_name = [adapter_name] - - missing = set(adapter_name) - set(self.peft_config) - if len(missing) > 0: - raise ValueError( - f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." - f" current loaded adapters are: {list(self.peft_config.keys())}" - ) - - from peft.tuners.tuners_utils import BaseTunerLayer - - _adapters_has_been_set = False - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - if hasattr(module, "set_adapter"): - module.set_adapter(adapter_name) - # Previous versions of PEFT does not support multi-adapter inference - elif not hasattr(module, "set_adapter") and len(adapter_name) != 1: - raise ValueError( - "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT." - " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`" - ) - else: - module.active_adapter = adapter_name - _adapters_has_been_set = True - - if not _adapters_has_been_set: - raise ValueError( - "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." - ) - - def disable_adapters(self) -> None: - r""" - Disable all adapters attached to the model and fallback to inference with the base model only. - - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - """ - check_peft_version(min_version=MIN_PEFT_VERSION) - - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - - from peft.tuners.tuners_utils import BaseTunerLayer - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - if hasattr(module, "enable_adapters"): - module.enable_adapters(enabled=False) - else: - # support for older PEFT versions - module.disable_adapters = True - - def enable_adapters(self) -> None: - """ - Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the - list of adapters to enable. - - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - """ - check_peft_version(min_version=MIN_PEFT_VERSION) - - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - - from peft.tuners.tuners_utils import BaseTunerLayer - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - if hasattr(module, "enable_adapters"): - module.enable_adapters(enabled=True) - else: - # support for older PEFT versions - module.disable_adapters = False - - def active_adapters(self) -> List[str]: - """ - Gets the current list of active adapters of the model. - - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - """ - check_peft_version(min_version=MIN_PEFT_VERSION) - - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. 
Please load an adapter first.") - - from peft.tuners.tuners_utils import BaseTunerLayer - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - return module.active_adapter - def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 8ada0a7c08a5..6b52ea344d41 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -6,7 +6,7 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin +from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ..utils import BaseOutput from .attention import BasicTransformerBlock from .attention_processor import ( @@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput): predicted_image_embedding: torch.FloatTensor -class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): """ A Prior Transformer model. diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4554016bdd53..7b4f9f5594ea 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin +from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( @@ -68,7 +68,7 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor = None -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. @@ -829,6 +829,17 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. 
Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 2fd629b2f8d9..fc8695e064b5 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging +from ..utils import BaseOutput, deprecate, logging from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -503,6 +503,18 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py index a49c77a51b02..c0e224562cf2 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/uvit_2d.py @@ -21,6 +21,7 @@ from torch.utils.checkpoint import checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin from .attention import BasicTransformerBlock, SkipFFTransformerBlock from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -35,7 +36,7 @@ from .resnet import Downsample2D, Upsample2D -class UVit2DModel(ModelMixin, ConfigMixin): +class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 6f95112c3d50..86e3cfefaf4f 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1034,6 +1034,17 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e7a795365ad3..de5dea679ee9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
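With the adapter methods now provided by the dedicated `PeftAdapterMixin` (added in `src/diffusers/loaders/peft.py` above) and mixed into `UNet2DConditionModel`, `PriorTransformer`, `UVit2DModel`, and `WuerstchenPrior`, adapters can be attached and toggled directly on those models. A hedged usage sketch, assuming a recent `peft` is installed; the rank and target module names are illustrative:

```python
# Hedged usage sketch of the PeftAdapterMixin API introduced above.
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Attach a LoRA adapter for training (rank and target modules are illustrative).
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
unet.add_adapter(lora_config, adapter_name="my_lora")

print(unet.active_adapters())  # e.g. ["my_lora"], depending on the PEFT version

unet.set_adapter("my_lora")  # make this adapter the active one
unet.disable_adapters()      # run the base UNet only (as the DPO script does for its reference pass)
unet.enable_adapters()       # re-enable the attached adapters
```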
- import fnmatch import importlib import inspect @@ -27,6 +26,7 @@ import numpy as np import PIL.Image +import requests import torch from huggingface_hub import ( ModelCard, @@ -35,7 +35,7 @@ model_info, snapshot_download, ) -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm @@ -530,6 +530,36 @@ def load_sub_model( return loaded_sub_model +def _fetch_class_library_tuple(module): + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # register the config from the original module, not the dynamo compiled one + not_compiled_module = _unwrap_model(module) + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + return (library, class_name) + + class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines. @@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): _is_onnx = False def register_modules(self, **kwargs): - # import it here to avoid circular import - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - for name, module in kwargs.items(): # retrieve library if module is None or isinstance(module, (tuple, list)) and module[0] is None: register_dict = {name: (None, None)} else: - # register the config from the original module, not the dynamo compiled one - not_compiled_module = _unwrap_model(module) - - library = not_compiled_module.__module__.split(".")[0] - - # check if the module is a pipeline module - module_path_items = not_compiled_module.__module__.split(".") - pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None - - path = not_compiled_module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. 
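The `register_modules` body being removed here, and the `__setattr__` branch further below, now delegate to the new module-level `_fetch_class_library_tuple` helper, which returns the `(library, class_name)` pair recorded in the pipeline config. An illustrative sketch of the expected result for a stock component (importing the private helper directly is for demonstration only):

```python
# Illustrative sketch: the helper maps a pipeline component to the
# (library, class_name) tuple stored in model_index.json.
from diffusers import DDIMScheduler
from diffusers.pipelines.pipeline_utils import _fetch_class_library_tuple

scheduler = DDIMScheduler()
print(_fetch_class_library_tuple(scheduler))  # ("diffusers", "DDIMScheduler")
```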
- if is_pipeline_module: - library = pipeline_dir - elif library not in LOADABLE_CLASSES: - library = not_compiled_module.__module__ - - # retrieve class_name - class_name = not_compiled_module.__class__.__name__ - + library, class_name = _fetch_class_library_tuple(module) register_dict = {name: (library, class_name)} # save model index config @@ -601,7 +605,7 @@ def __setattr__(self, name: str, value: Any): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): if value is not None and self.config[name][0] is not None: - class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) + class_library_tuple = _fetch_class_library_tuple(value) else: class_library_tuple = (None, None) @@ -1654,7 +1658,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: if not local_files_only: try: info = model_info(pretrained_model_name, token=token, revision=revision) - except HTTPError as e: + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") local_files_only = True model_info_call_error = e # save error to reraise it if model is not cached locally diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index d4502639cebc..8b494fa32476 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -20,7 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import UNet2DConditionLoadersMixin +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -34,7 +34,7 @@ from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm -class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): unet_name = "prior" _supports_gradient_checkpointing = True diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 992ae7d1b194..9fb6fad3a267 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,6 +5,7 @@ import numpy as np import torch +from torchvision import transforms from .models import UNet2DConditionModel from .utils import deprecate, is_transformers_available @@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps): return snr +def resolve_interpolation_mode(interpolation_type: str): + """ + Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The + full list of supported enums is documented at + https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. + + Args: + interpolation_type (`str`): + A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`, + `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes + in torchvision. + + Returns: + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` + transform. 
+ """ + if interpolation_type == "bilinear": + interpolation_mode = transforms.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = transforms.InterpolationMode.BICUBIC + elif interpolation_type == "box": + interpolation_mode = transforms.InterpolationMode.BOX + elif interpolation_type == "nearest": + interpolation_mode = transforms.InterpolationMode.NEAREST + elif interpolation_type == "nearest_exact": + interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT + elif interpolation_type == "hamming": + interpolation_mode = transforms.InterpolationMode.HAMMING + elif interpolation_type == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ) + + return interpolation_mode + + def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: r""" Returns: diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index ab8f1b70cb4a..fbc12d74f1aa 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -151,9 +151,7 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): unet_lora_sd = unet_lora_state_dict(unet) # Unload LoRA. - for module in unet.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) + unet.unload_lora() return unet_lora_parameters, unet_lora_sd @@ -230,9 +228,7 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): unet_lora_sd = unet_lora_state_dict(unet) # Unload LoRA. - for module in unet.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) + unet.unload_lora() return unet_lora_sd @@ -1496,7 +1492,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_lora_processors(self): + def test_lora_at_different_scales(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1514,9 +1510,6 @@ def test_lora_processors(self): model.load_attn_procs(lora_params) model.to(torch_device) - # test that attn processors can be set to itself - model.set_attn_processor(model.attn_processors) - with torch.no_grad(): sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample @@ -1548,9 +1541,7 @@ def test_lora_on_off(self, expected_max_diff=1e-3): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample # Unload LoRA. 
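`resolve_interpolation_mode`, added to `training_utils.py` above, lets a training script accept an interpolation name as a plain string (for example from a hypothetical `--interpolation_type` flag) and convert it into the torchvision enum. A hedged usage sketch:

```python
# Hedged usage sketch for the resolve_interpolation_mode helper added above.
from torchvision import transforms

from diffusers.training_utils import resolve_interpolation_mode

interp = resolve_interpolation_mode("bicubic")  # transforms.InterpolationMode.BICUBIC
resize = transforms.Resize((512, 512), interpolation=interp)

# Unsupported strings raise a ValueError listing the accepted modes.
try:
    resolve_interpolation_mode("area")
except ValueError as err:
    print(err)
```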
- for module in model.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) + model.unload_lora() with torch.no_grad(): new_sample = model(**inputs_dict).sample @@ -1595,7 +1586,7 @@ def test_lora_xformers_on_off(self, expected_max_diff=6e-4): @deprecate_after_peft_backend -class UNet3DConditionModelTests(unittest.TestCase): +class UNet3DConditionLoRAModelTests(unittest.TestCase): model_class = UNet3DConditionModel main_input_name = "sample" @@ -1638,7 +1629,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_lora_processors(self): + def test_lora_at_different_scales(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["attention_head_dim"] = 8 @@ -1655,9 +1646,6 @@ def test_lora_processors(self): model.load_attn_procs(unet_lora_params) model.to(torch_device) - # test that attn processors can be set to itself - model.set_attn_processor(model.attn_processors) - with torch.no_grad(): sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
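The test updates above replace the manual `set_lora_layer(None)` loop with the `unload_lora()` helper that this diff adds (behind a deprecation notice) to the UNet models. A hedged sketch of the old-backend usage; the weights path is illustrative, and with the PEFT backend the deprecation message points to `disable_adapters()` instead:

```python
# Hedged sketch: unloading old-backend LoRA layers from a UNet in one call.
# unload_lora() is deprecated by this diff and scheduled for removal in 0.28.0.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.load_attn_procs("path/to/lora_weights")  # illustrative path

# Equivalent to the removed loop:
#   for module in unet.modules():
#       if hasattr(module, "set_lora_layer"):
#           module.set_lora_layer(None)
unet.unload_lora()
```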