diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index a85adfc2bfec..a8987d177b28 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -13,13 +13,13 @@ env: jobs: torch_pipelines_cuda_benchmark_tests: - env: + env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }} name: Torch Core Pipelines CUDA Benchmarking Tests strategy: fail-fast: false max-parallel: 1 - runs-on: + runs-on: group: aws-g6-4xlarge-plus container: image: diffusers/diffusers-pytorch-compile-cuda @@ -59,7 +59,7 @@ jobs: if: ${{ success() }} run: | pip install requests && python utils/notify_benchmarking_status.py --status=success - + - name: Report failure status if: ${{ failure() }} run: | diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml index e1028c77b700..a7a2a809bbeb 100644 --- a/.github/workflows/mirror_community_pipeline.yml +++ b/.github/workflows/mirror_community_pipeline.yml @@ -24,7 +24,7 @@ jobs: mirror_community_pipeline: env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }} - + runs-on: ubuntu-latest steps: # Checkout to correct ref @@ -95,7 +95,7 @@ jobs: if: ${{ success() }} run: | pip install requests && python utils/notify_community_pipelines_mirror.py --status=success - + - name: Report failure status if: ${{ failure() }} run: | diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 248846a8747e..3314c2c1cfb4 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -32,7 +32,7 @@ jobs: fetch-depth: 2 - name: Install dependencies run: | - pip install -e [test] + pip install -e .[test] pip install huggingface_hub - name: Fetch Pipeline Matrix id: fetch_pipeline_matrix diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 16acc87dde42..0aa2a77dbcac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q **Please** keep in mind that the more effort you put into asking or answering a question, the higher the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. -In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. +In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. **NOTE about channels**: [*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. diff --git a/README.md b/README.md index c8a734b8ce12..775f6f5e8289 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi ## Quickstart -Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 27.000+ checkpoints): +Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 30,000+ checkpoints): ```python from diffusers import DiffusionPipeline @@ -211,7 +211,7 @@ Also, say 👋 in our public Discord channel + +# StableCascadeUNet + +A UNet model from the [Stable Cascade pipeline](../pipelines/stable_cascade.md). + +## StableCascadeUNet + +[[autodoc]] models.unets.unet_stable_cascade.StableCascadeUNet diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md index 90b882051a12..aa5a04800e6f 100644 --- a/docs/source/en/api/pipelines/aura_flow.md +++ b/docs/source/en/api/pipelines/aura_flow.md @@ -18,7 +18,7 @@ It was developed by the Fal team and more details about it can be found in [this -AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. +AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index f7d27dfbbf98..095bf76af37f 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License. # Flux -Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs. +Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs. Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux). -Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. +Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). @@ -27,11 +27,11 @@ Flux comes in two variants: * Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) * Guidance-distilled (`black-forest-labs/FLUX.1-dev`) -Both checkpoints have slightly difference usage which we detail below. +Both checkpoints have slightly difference usage which we detail below. ### Timestep-distilled -* `max_sequence_length` cannot be more than 256. +* `max_sequence_length` cannot be more than 256. * `guidance_scale` needs to be 0. * As this is a timestep-distilled model, it benefits from fewer sampling steps. @@ -44,11 +44,11 @@ pipe.enable_model_cpu_offload() prompt = "A cat holding a sign that says hello world" out = pipe( - prompt=prompt, - guidance_scale=0., - height=768, - width=1360, - num_inference_steps=4, + prompt=prompt, + guidance_scale=0., + height=768, + width=1360, + num_inference_steps=4, max_sequence_length=256, ).images[0] out.save("image.png") @@ -57,7 +57,7 @@ out.save("image.png") ### Guidance-distilled * The guidance-distilled variant takes about 50 sampling steps for good-quality generation. -* It doesn't have any limitations around the `max_sequence_length`. +* It doesn't have any limitations around the `max_sequence_length`. ```python import torch @@ -68,10 +68,10 @@ pipe.enable_model_cpu_offload() prompt = "a tiny astronaut hatching from an egg on the moon" out = pipe( - prompt=prompt, - guidance_scale=3.5, - height=768, - width=1360, + prompt=prompt, + guidance_scale=3.5, + height=768, + width=1360, num_inference_steps=50, ).images[0] out.save("image.png") diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md index 6f950f2eef66..cc8aceefc1b1 100644 --- a/docs/source/en/api/pipelines/lumina.md +++ b/docs/source/en/api/pipelines/lumina.md @@ -59,7 +59,7 @@ First, load the pipeline: ```python from diffusers import LuminaText2ImgPipeline -import torch +import torch pipeline = LuminaText2ImgPipeline.from_pretrained( "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 @@ -87,4 +87,4 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w [[autodoc]] LuminaText2ImgPipeline - all - __call__ - + diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index 5b48569f3505..ac12bdb5578d 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -20,11 +20,29 @@ The abstract from the paper is: *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.* +PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers. + +- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor` +- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor` +- Partial identifier as a RegEx: `down_blocks.2`, or `attn1` +- List of identifiers (can be combo of strings and ReGex): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]` + + + +Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly otherwise there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results. + + + ## AnimateDiffPAGPipeline [[autodoc]] AnimateDiffPAGPipeline - all - __call__ +## HunyuanDiTPAGPipeline +[[autodoc]] HunyuanDiTPAGPipeline + - all + - __call__ + ## StableDiffusionPAGPipeline [[autodoc]] StableDiffusionPAGPipeline - all @@ -59,4 +77,4 @@ The abstract from the paper is: ## PixArtSigmaPAGPipeline [[autodoc]] PixArtSigmaPAGPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/docs/source/en/api/pipelines/stable_audio.md b/docs/source/en/api/pipelines/stable_audio.md index 3e7b2857e4eb..96b2678b5027 100644 --- a/docs/source/en/api/pipelines/stable_audio.md +++ b/docs/source/en/api/pipelines/stable_audio.md @@ -16,7 +16,7 @@ Stable Audio was proposed in [Stable Audio Open](https://arxiv.org/abs/2407.1435 Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder. -Stable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT. +Stable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT. The abstract of the paper is the following: *Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.* diff --git a/docs/source/en/training/instructpix2pix.md b/docs/source/en/training/instructpix2pix.md index 3f797ced497d..3a651e5abd2d 100644 --- a/docs/source/en/training/instructpix2pix.md +++ b/docs/source/en/training/instructpix2pix.md @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. [InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be "turn the clouds rainy" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image. -This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use-case. +This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use case. Before running the script, make sure you install the library from source: @@ -117,7 +117,7 @@ optimizer = optimizer_cls( ) ``` -Next, the edited images and and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images. +Next, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images. ```py def preprocess_train(examples): @@ -249,4 +249,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl Congratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to: -- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions. \ No newline at end of file +- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions. diff --git a/docs/source/en/tutorials/fast_diffusion.md b/docs/source/en/tutorials/fast_diffusion.md index e758ea399b59..0f1133dc2dc3 100644 --- a/docs/source/en/tutorials/fast_diffusion.md +++ b/docs/source/en/tutorials/fast_diffusion.md @@ -35,7 +35,7 @@ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu ``` > [!TIP] -> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. +> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum. > If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast). @@ -168,7 +168,7 @@ Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3 > [!TIP] -> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial. +> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial. ### Prevent graph breaks diff --git a/docs/source/en/tutorials/inference_with_big_models.md b/docs/source/en/tutorials/inference_with_big_models.md index b3d1067cfc6e..6700bbad07a4 100644 --- a/docs/source/en/tutorials/inference_with_big_models.md +++ b/docs/source/en/tutorials/inference_with_big_models.md @@ -18,13 +18,13 @@ A modern diffusion model, like [Stable Diffusion XL (SDXL)](../using-diffusers/s * Two text encoders * A UNet for denoising -Usually, the text encoders and the denoiser are much larger compared to the VAE. +Usually, the text encoders and the denoiser are much larger compared to the VAE. As models get bigger and better, it’s possible your model is so big that even a single copy won’t fit in memory. But that doesn’t mean it can’t be loaded. If you have more than one GPU, there is more memory available to store your model. In this case, it’s better to split your model checkpoint into several smaller *checkpoint shards*. When a text encoder checkpoint has multiple shards, like [T5-xxl for SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3), it is automatically handled by the [Transformers](https://huggingface.co/docs/transformers/index) library as it is a required dependency of Diffusers when using the [`StableDiffusion3Pipeline`]. More specifically, Transformers will automatically handle the loading of multiple shards within the requested model class and get it ready so that inference can be performed. -The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library. +The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library. > [!TIP] > Refer to the [Handling big models for inference](https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference) guide for general guidance when working with big models that are hard to fit into memory. @@ -43,7 +43,7 @@ unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB") The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]: ```python -from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline +from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline import torch unet = UNet2DConditionModel.from_pretrained( @@ -57,14 +57,14 @@ image = pipeline("a cute dog running on the grass", num_inference_steps=30).imag image.save("dog.png") ``` -If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you: +If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you: ```diff - pipeline.to("cuda") + pipeline.enable_model_cpu_offload() ``` -In general, we recommend sharding when a checkpoint is more than 5GB (in fp32). +In general, we recommend sharding when a checkpoint is more than 5GB (in fp32). ## Device placement diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index c37dd90fa172..907f93d573a0 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -34,7 +34,7 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda") ``` -Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`. +Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`. ```python pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index 2ed74ab80dbf..ce4c6d1b98c8 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # Pipeline callbacks -The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code! +The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code! > [!TIP] > 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point! @@ -75,7 +75,7 @@ out.images[0].save("official_callback.png")
without SDXLCFGCutoffCallback
- generated image of a a sports car at the road with cfg callback + generated image of a sports car at the road with cfg callback
with SDXLCFGCutoffCallback
diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index 20b7f9429034..cdd687ae8130 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -256,7 +256,7 @@ make_image_grid([init_image, mask_image, output], rows=1, cols=3) ## Guess mode -[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) does not require supplying a prompt to a ControlNet at all! This forces the ControlNet encoder to do it's best to "guess" the contents of the input control map (depth map, pose estimation, canny edge, etc.). +[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) does not require supplying a prompt to a ControlNet at all! This forces the ControlNet encoder to do its best to "guess" the contents of the input control map (depth map, pose estimation, canny edge, etc.). Guess mode adjusts the scale of the output residuals from a ControlNet by a fixed ratio depending on the block depth. The shallowest `DownBlock` corresponds to 0.1, and as the blocks get deeper, the scale increases exponentially such that the scale of the `MidBlock` output becomes 1.0. diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.md b/docs/source/en/using-diffusers/custom_pipeline_overview.md index 341a98a5c897..17ba779b8136 100644 --- a/docs/source/en/using-diffusers/custom_pipeline_overview.md +++ b/docs/source/en/using-diffusers/custom_pipeline_overview.md @@ -289,9 +289,9 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="sche 3. Load an image processor: ```python -from transformers import CLIPFeatureExtractor +from transformers import CLIPImageProcessor -feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor") +feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor") ``` diff --git a/docs/source/en/using-diffusers/inference_with_tcd_lora.md b/docs/source/en/using-diffusers/inference_with_tcd_lora.md index df49fc8475ad..d6fa61be557a 100644 --- a/docs/source/en/using-diffusers/inference_with_tcd_lora.md +++ b/docs/source/en/using-diffusers/inference_with_tcd_lora.md @@ -212,14 +212,14 @@ TCD-LoRA is very versatile, and it can be combined with other adapter types like import torch import numpy as np from PIL import Image -from transformers import DPTFeatureExtractor, DPTForDepthEstimation +from transformers import DPTImageProcessor, DPTForDepthEstimation from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline from diffusers.utils import load_image, make_image_grid from scheduling_tcd import TCDScheduler device = "cuda" depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device) -feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") +feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") def get_depth_map(image): image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device) diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md index e852aec03fd4..26961d959c49 100644 --- a/docs/source/en/using-diffusers/pag.md +++ b/docs/source/en/using-diffusers/pag.md @@ -130,10 +130,10 @@ prompt = "a dog catching a frisbee in the jungle" generator = torch.Generator(device="cpu").manual_seed(0) image = pipeline( - prompt, - image=init_image, - strength=0.8, - guidance_scale=guidance_scale, + prompt, + image=init_image, + strength=0.8, + guidance_scale=guidance_scale, pag_scale=pag_scale, generator=generator).images[0] ``` @@ -161,14 +161,14 @@ pipeline_inpaint = AutoPipelineForInpaiting.from_pretrained("stabilityai/stable- pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_inpaint, enable_pag=True) ``` -This still works when your pipeline has a different task: +This still works when your pipeline has a different task: ```py pipeline_t2i = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_t2i, enable_pag=True) ``` -Let's generate an image! +Let's generate an image! ```py img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" @@ -258,7 +258,7 @@ for pag_scale in [0.0, 3.0]: -## PAG with IP-Adapter +## PAG with IP-Adapter [IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-Adapter loaded. @@ -317,7 +317,7 @@ PAG reduces artifacts and improves the overall compposition. -## Configure parameters +## Configure parameters ### pag_applied_layers diff --git a/docs/source/ko/conceptual/philosophy.md b/docs/source/ko/conceptual/philosophy.md index 5d49c075a165..fab2a4d6d3ab 100644 --- a/docs/source/ko/conceptual/philosophy.md +++ b/docs/source/ko/conceptual/philosophy.md @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# 철학 [[philosophy]] +# 철학 [[philosophy]] 🧨 Diffusers는 다양한 모달리티에서 **최신의** 사전 훈련된 diffusion 모델을 제공합니다. 그 목적은 추론과 훈련을 위한 **모듈식 툴박스**로 사용되는 것입니다. diff --git a/docs/source/ko/using-diffusers/loading.md b/docs/source/ko/using-diffusers/loading.md index 39cd228af401..2106b91a68cf 100644 --- a/docs/source/ko/using-diffusers/loading.md +++ b/docs/source/ko/using-diffusers/loading.md @@ -307,7 +307,7 @@ print(pipeline) 위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다. -- `"feature_extractor"`: [`~transformers.CLIPFeatureExtractor`]의 인스턴스 +- `"feature_extractor"`: [`~transformers.CLIPImageProcessor`]의 인스턴스 - `"safety_checker"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) - `"scheduler"`: [`PNDMScheduler`]의 인스턴스 - `"text_encoder"`: [`~transformers.CLIPTextModel`]의 인스턴스 diff --git a/docs/source/ko/using-diffusers/sdxl_turbo.md b/docs/source/ko/using-diffusers/sdxl_turbo.md index 766ac0f10a7d..99b96fd3b875 100644 --- a/docs/source/ko/using-diffusers/sdxl_turbo.md +++ b/docs/source/ko/using-diffusers/sdxl_turbo.md @@ -52,7 +52,7 @@ pipeline = pipeline.to("cuda") Text-to-image의 경우 텍스트 프롬프트를 전달합니다. 기본적으로 SDXL Turbo는 512x512 이미지를 생성하며, 이 해상도에서 최상의 결과를 제공합니다. `height` 및 `width` 매개 변수를 768x768 또는 1024x1024로 설정할 수 있지만 이 경우 품질 저하를 예상할 수 있습니다. -모델이 `guidance_scale` 없이 학습되었으므로 이를 0.0으로 설정해 비활성화해야 합니다. 단일 추론 스텝만으로도 고품질 이미지를 생성할 수 있습니다. +모델이 `guidance_scale` 없이 학습되었으므로 이를 0.0으로 설정해 비활성화해야 합니다. 단일 추론 스텝만으로도 고품질 이미지를 생성할 수 있습니다. 스텝 수를 2, 3 또는 4로 늘리면 이미지 품질이 향상됩니다. ```py @@ -74,7 +74,7 @@ image ## Image-to-image -Image-to-image 생성의 경우 `num_inference_steps * strength`가 1보다 크거나 같은지 확인하세요. +Image-to-image 생성의 경우 `num_inference_steps * strength`가 1보다 크거나 같은지 확인하세요. Image-to-image 파이프라인은 아래 예제에서 `0.5 * 2.0 = 1` 스텝과 같이 `int(num_inference_steps * strength)` 스텝으로 실행됩니다. ```py diff --git a/docs/source/ko/using-diffusers/svd.md b/docs/source/ko/using-diffusers/svd.md index 678e21728ad4..7c5b9f09e690 100644 --- a/docs/source/ko/using-diffusers/svd.md +++ b/docs/source/ko/using-diffusers/svd.md @@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License. 시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요: ```py -!pip install -q -U diffusers transformers accelerate +!pip install -q -U diffusers transformers accelerate ``` 이 모델에는 [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)와 [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) 두 가지 종류가 있습니다. SVD 체크포인트는 14개의 프레임을 생성하도록 학습되었고, SVD-XT 체크포인트는 25개의 프레임을 생성하도록 파인튜닝되었습니다. diff --git a/docs/source/ko/using-diffusers/textual_inversion_inference.md b/docs/source/ko/using-diffusers/textual_inversion_inference.md index 1b52fee923b3..39fab939a704 100644 --- a/docs/source/ko/using-diffusers/textual_inversion_inference.md +++ b/docs/source/ko/using-diffusers/textual_inversion_inference.md @@ -24,7 +24,7 @@ import PIL from PIL import Image from diffusers import StableDiffusionPipeline -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer def image_grid(imgs, rows, cols): diff --git a/examples/community/README.md b/examples/community/README.md index 652d65f900fe..090bb980b221 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1435,9 +1435,9 @@ import requests import torch from diffusers import DiffusionPipeline from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPModel +from transformers import CLIPImageProcessor, CLIPModel -feature_extractor = CLIPFeatureExtractor.from_pretrained( +feature_extractor = CLIPImageProcessor.from_pretrained( "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" ) clip_model = CLIPModel.from_pretrained( @@ -1487,17 +1487,16 @@ NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes. ```python import torch from diffusers import DDIMScheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline +from diffusers.pipelines import DiffusionPipeline # Use the DDIMScheduler scheduler here instead -scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", - subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler") -pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", - custom_pipeline="stable_diffusion_tensorrt_txt2img", - variant='fp16', - torch_dtype=torch.float16, - scheduler=scheduler,) +pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", + custom_pipeline="stable_diffusion_tensorrt_txt2img", + variant='fp16', + torch_dtype=torch.float16, + scheduler=scheduler,) # re-use cached folder to save ONNX models and TensorRT Engines pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',) @@ -2123,7 +2122,7 @@ import torch import open_clip from open_clip import SimpleTokenizer from diffusers import DiffusionPipeline -from transformers import CLIPFeatureExtractor, CLIPModel +from transformers import CLIPImageProcessor, CLIPModel def download_image(url): @@ -2131,7 +2130,7 @@ def download_image(url): return PIL.Image.open(BytesIO(response.content)).convert("RGB") # Loading additional models -feature_extractor = CLIPFeatureExtractor.from_pretrained( +feature_extractor = CLIPImageProcessor.from_pretrained( "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" ) clip_model = CLIPModel.from_pretrained( @@ -2231,12 +2230,12 @@ from io import BytesIO from PIL import Image import torch from diffusers import PNDMScheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline +from diffusers.pipelines import DiffusionPipeline # Use the PNDMScheduler scheduler here instead scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler") -pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting", +pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting", custom_pipeline="stable_diffusion_tensorrt_inpaint", variant='fp16', torch_dtype=torch.float16, diff --git a/examples/community/clip_guided_images_mixing_stable_diffusion.py b/examples/community/clip_guided_images_mixing_stable_diffusion.py index 75b7df84dc77..f9a4b12ad20f 100644 --- a/examples/community/clip_guided_images_mixing_stable_diffusion.py +++ b/examples/community/clip_guided_images_mixing_stable_diffusion.py @@ -7,7 +7,7 @@ import torch from torch.nn import functional as F from torchvision import transforms -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -86,7 +86,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler], - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, coca_model=None, coca_tokenizer=None, coca_transform=None, diff --git a/examples/community/clip_guided_stable_diffusion_img2img.py b/examples/community/clip_guided_stable_diffusion_img2img.py index 0f3de94e2cdb..91c74b9ffa74 100644 --- a/examples/community/clip_guided_stable_diffusion_img2img.py +++ b/examples/community/clip_guided_stable_diffusion_img2img.py @@ -7,7 +7,7 @@ from torch import nn from torch.nn import functional as F from torchvision import transforms -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -32,9 +32,9 @@ import torch from diffusers import DiffusionPipeline from PIL import Image - from transformers import CLIPFeatureExtractor, CLIPModel + from transformers import CLIPImageProcessor, CLIPModel - feature_extractor = CLIPFeatureExtractor.from_pretrained( + feature_extractor = CLIPImageProcessor.from_pretrained( "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" ) clip_model = CLIPModel.from_pretrained( @@ -139,7 +139,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler], - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py index 77071b023106..779dc3c2b4f1 100644 --- a/examples/community/fresco_v2v.py +++ b/examples/community/fresco_v2v.py @@ -2436,7 +2436,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/examples/community/mixture_canvas.py b/examples/community/mixture_canvas.py index 7196ee9587f2..2bb054a123d0 100644 --- a/examples/community/mixture_canvas.py +++ b/examples/community/mixture_canvas.py @@ -9,7 +9,7 @@ from numpy import exp, pi, sqrt from torchvision.transforms.functional import resize from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -275,7 +275,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( diff --git a/examples/community/mixture_tiling.py b/examples/community/mixture_tiling.py index 7e3d592d8514..867bce0d9eb8 100644 --- a/examples/community/mixture_tiling.py +++ b/examples/community/mixture_tiling.py @@ -15,7 +15,7 @@ try: from ligo.segments import segment - from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer except ImportError: raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline") @@ -144,7 +144,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py index 5fc6d8af03c4..ae495979f366 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py @@ -189,7 +189,7 @@ class StableDiffusionXLControlNetAdapterPipeline( safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py index 07954f013295..94ca71cf7b1b 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -332,7 +332,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline( safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config diff --git a/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py b/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py index fdd229570daf..c51a5132a772 100644 --- a/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py +++ b/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py @@ -1002,7 +1002,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/examples/community/pipeline_stable_diffusion_xl_instantid.py b/examples/community/pipeline_stable_diffusion_xl_instantid.py index 18dcd0a13a98..8d28c2fbd348 100644 --- a/examples/community/pipeline_stable_diffusion_xl_instantid.py +++ b/examples/community/pipeline_stable_diffusion_xl_instantid.py @@ -991,7 +991,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py index af1c82dfad95..95bb37ce02b7 100644 --- a/examples/community/pipeline_zero1to3.py +++ b/examples/community/pipeline_zero1to3.py @@ -9,7 +9,7 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection # from ...configuration_utils import FrozenDict # from ...models import AutoencoderKL, UNet2DConditionModel @@ -87,7 +87,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. cc_projection ([`CCProjection`]): Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size. @@ -102,7 +102,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, cc_projection: CCProjection, requires_safety_checker: bool = True, ): diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index cad71338faed..8a022987ba9d 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -3,7 +3,7 @@ import torch import torchvision.transforms.functional as FF -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -69,7 +69,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__( diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index 6e25b92603d5..d9c616ab5ebc 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -864,7 +864,7 @@ def __call__( ) if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] @@ -1038,7 +1038,7 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None): ) if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [ diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index e4d8e12f85a9..577c7712e771 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -752,7 +752,7 @@ def hacked_UpBlock2D_forward( ) if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index 388992a740ec..123892f6229a 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -18,7 +18,7 @@ import intel_extension_for_pytorch as ipex import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin @@ -86,7 +86,7 @@ class StableDiffusionIPEXPipeline( safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -100,7 +100,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py index 40ad38bfe903..91540d1f4159 100755 --- a/examples/community/stable_diffusion_tensorrt_img2img.py +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -42,7 +42,7 @@ network_from_onnx_path, save_engine, ) -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict, deprecate @@ -60,7 +60,7 @@ """ Installation instructions python3 -m pip install --upgrade transformers diffusers>=0.16.0 -python3 -m pip install --upgrade tensorrt-cu12==10.2.0 +python3 -m pip install --upgrade tensorrt~=10.2.0 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime """ @@ -659,7 +659,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline): r""" Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion. - This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: @@ -679,7 +679,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -693,7 +693,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae", "vae_encoder"], diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py index 8bacd050571a..b6f6711a53e7 100755 --- a/examples/community/stable_diffusion_tensorrt_inpaint.py +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -18,8 +18,7 @@ import gc import os from collections import OrderedDict -from copy import copy -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import onnx @@ -27,9 +26,11 @@ import PIL.Image import tensorrt as trt import torch +from cuda import cudart from huggingface_hub import snapshot_download from huggingface_hub.utils import validate_hf_hub_args from onnx import shape_inference +from packaging import version from polygraphy import cuda from polygraphy.backend.common import bytes_from_path from polygraphy.backend.onnx.loader import fold_constants @@ -41,24 +42,29 @@ network_from_onnx_path, save_engine, ) -from polygraphy.backend.trt import util as trt_util -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict, deprecate +from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( - StableDiffusionInpaintPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import ( + prepare_mask_and_masked_image, + retrieve_latents, +) from diffusers.schedulers import DDIMScheduler from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor """ Installation instructions python3 -m pip install --upgrade transformers diffusers>=0.16.0 -python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade tensorrt~=10.2.0 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime """ @@ -88,10 +94,6 @@ torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} -def device_view(t): - return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) - - def preprocess_image(image): """ image: torch.Tensor @@ -125,10 +127,8 @@ def build( onnx_path, fp16, input_profile=None, - enable_preview=False, enable_all_tactics=False, timing_cache=None, - workspace_size=0, ): logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") p = Profile() @@ -137,20 +137,13 @@ def build( assert len(dims) == 3 p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - config_kwargs = {} - - config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] - if enable_preview: - # Faster dynamic shapes made optional since it increases engine build time. - config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) - if workspace_size > 0: - config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + extra_build_args = {} if not enable_all_tactics: - config_kwargs["tactic_sources"] = [] + extra_build_args["tactic_sources"] = [] engine = engine_from_network( network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), - config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args), save_timing_cache=timing_cache, ) save_engine(engine, path=self.engine_path) @@ -163,28 +156,24 @@ def activate(self): self.context = self.engine.create_execution_context() def allocate_buffers(self, shape_dict=None, device="cuda"): - for idx in range(trt_util.get_bindings_per_profile(self.engine)): - binding = self.engine[idx] - if shape_dict and binding in shape_dict: - shape = shape_dict[binding] + for binding in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(binding) + if shape_dict and name in shape_dict: + shape = shape_dict[name] else: - shape = self.engine.get_binding_shape(binding) - dtype = trt.nptype(self.engine.get_binding_dtype(binding)) - if self.engine.binding_is_input(binding): - self.context.set_binding_shape(idx, shape) + shape = self.engine.get_tensor_shape(name) + dtype = trt.nptype(self.engine.get_tensor_dtype(name)) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + self.context.set_input_shape(name, shape) tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) - self.tensors[binding] = tensor - self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + self.tensors[name] = tensor def infer(self, feed_dict, stream): - start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) - # shallow copy of ordered dict - device_buffers = copy(self.buffers) for name, buf in feed_dict.items(): - assert isinstance(buf, cuda.DeviceView) - device_buffers[name] = buf - bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] - noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + self.tensors[name].copy_(buf) + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + noerror = self.context.execute_async_v3(stream) if not noerror: raise ValueError("ERROR: inference failed.") @@ -325,10 +314,8 @@ def build_engines( force_engine_rebuild=False, static_batch=False, static_shape=True, - enable_preview=False, enable_all_tactics=False, timing_cache=None, - max_workspace_size=0, ): built_engines = {} if not os.path.isdir(onnx_dir): @@ -393,9 +380,7 @@ def build_engines( static_batch=static_batch, static_shape=static_shape, ), - enable_preview=enable_preview, timing_cache=timing_cache, - workspace_size=max_workspace_size, ) built_engines[model_name] = engine @@ -674,11 +659,11 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False) return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) -class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): +class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline): r""" Pipeline for inpainting using TensorRT accelerated Stable Diffusion. - This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: @@ -698,10 +683,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + def __init__( self, vae: AutoencoderKL, @@ -710,7 +697,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae", "vae_encoder"], @@ -722,24 +709,86 @@ def __init__( onnx_dir: str = "onnx", # TensorRT engine build parameters engine_dir: str = "engine", - build_preview_features: bool = True, force_engine_rebuild: bool = False, timing_cache: str = "timing_cache", ): - super().__init__( - vae, - text_encoder, - tokenizer, - unet, - scheduler, + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, - requires_safety_checker=requires_safety_checker, ) - self.vae.forward = self.vae.decode - self.stages = stages self.image_height, self.image_width = image_height, image_width self.inpaint = True @@ -750,7 +799,6 @@ def __init__( self.timing_cache = timing_cache self.build_static_batch = False self.build_dynamic_shape = False - self.build_preview_features = build_preview_features self.max_batch_size = max_batch_size # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. @@ -761,6 +809,11 @@ def __init__( self.models = {} # loaded in __loadModels() self.engine = {} # loaded in build_engines() + self.vae.forward = self.vae.decode + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + def __loadModels(self): # Load pipeline models self.embedding_dim = self.text_encoder.config.hidden_size @@ -779,6 +832,112 @@ def __loadModels(self): if "vae_encoder" in self.stages: self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker( + self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype + ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: + r""" + Runs the safety checker on the given image. + Args: + image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked. + device (torch.device): The device to run the safety checker on. + dtype (torch.dtype): The data type of the input image. + Returns: + (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and + a boolean indicating whether the image has a NSFW (Not Safe for Work) concept. + """ + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + @classmethod @validate_hf_hub_args def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -826,7 +985,6 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dt force_engine_rebuild=self.force_engine_rebuild, static_batch=self.build_static_batch, static_shape=not self.build_dynamic_shape, - enable_preview=self.build_preview_features, timing_cache=self.timing_cache, ) @@ -850,9 +1008,7 @@ def __preprocess_images(self, batch_size, images=()): return tuple(init_images) def __encode_image(self, init_image): - init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[ - "latent" - ] + init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"] init_latents = 0.18215 * init_latents return init_latents @@ -881,9 +1037,8 @@ def __encode_prompt(self, prompt, negative_prompt): .to(self.torch_device) ) - text_input_ids_inp = device_view(text_input_ids) # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[ "text_embeddings" ].clone() @@ -899,8 +1054,7 @@ def __encode_prompt(self, prompt, negative_prompt): .input_ids.type(torch.int32) .to(self.torch_device) ) - uncond_input_ids_inp = device_view(uncond_input_ids) - uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[ "text_embeddings" ] @@ -924,18 +1078,15 @@ def __denoise_latent( # Predict the noise residual timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - sample_inp = device_view(latent_model_input) - timestep_inp = device_view(timestep_float) - embeddings_inp = device_view(text_embeddings) noise_pred = runEngine( self.engine["unet"], - {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, self.stream, )["latent"] # Perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample @@ -943,12 +1094,12 @@ def __denoise_latent( return latents def __decode_latent(self, latents): - images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] + images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"] images = (images / 2 + 0.5).clamp(0, 1) return images.cpu().permute(0, 2, 3, 1).float().numpy() def __loadResources(self, image_height, image_width, batch_size): - self.stream = cuda.Stream() + self.stream = cudart.cudaStreamCreate()[1] # Allocate buffers for TensorRT engine bindings for model_name, obj in self.models.items(): @@ -1112,5 +1263,6 @@ def __call__( # VAE decode latent images = self.__decode_latent(latents) + images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) - return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index 6072a357bc5d..f8761053ed1a 100755 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -18,17 +18,19 @@ import gc import os from collections import OrderedDict -from copy import copy -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import onnx import onnx_graphsurgeon as gs +import PIL.Image import tensorrt as trt import torch +from cuda import cudart from huggingface_hub import snapshot_download from huggingface_hub.utils import validate_hf_hub_args from onnx import shape_inference +from packaging import version from polygraphy import cuda from polygraphy.backend.common import bytes_from_path from polygraphy.backend.onnx.loader import fold_constants @@ -40,23 +42,25 @@ network_from_onnx_path, save_engine, ) -from polygraphy.backend.trt import util as trt_util -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict, deprecate +from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( - StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor """ Installation instructions python3 -m pip install --upgrade transformers diffusers>=0.16.0 -python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade tensorrt~=10.2.0 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime """ @@ -86,10 +90,6 @@ torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} -def device_view(t): - return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) - - class Engine: def __init__(self, engine_path): self.engine_path = engine_path @@ -110,10 +110,8 @@ def build( onnx_path, fp16, input_profile=None, - enable_preview=False, enable_all_tactics=False, timing_cache=None, - workspace_size=0, ): logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") p = Profile() @@ -122,20 +120,13 @@ def build( assert len(dims) == 3 p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - config_kwargs = {} - - config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] - if enable_preview: - # Faster dynamic shapes made optional since it increases engine build time. - config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) - if workspace_size > 0: - config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + extra_build_args = {} if not enable_all_tactics: - config_kwargs["tactic_sources"] = [] + extra_build_args["tactic_sources"] = [] engine = engine_from_network( network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), - config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args), save_timing_cache=timing_cache, ) save_engine(engine, path=self.engine_path) @@ -148,28 +139,24 @@ def activate(self): self.context = self.engine.create_execution_context() def allocate_buffers(self, shape_dict=None, device="cuda"): - for idx in range(trt_util.get_bindings_per_profile(self.engine)): - binding = self.engine[idx] - if shape_dict and binding in shape_dict: - shape = shape_dict[binding] + for binding in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(binding) + if shape_dict and name in shape_dict: + shape = shape_dict[name] else: - shape = self.engine.get_binding_shape(binding) - dtype = trt.nptype(self.engine.get_binding_dtype(binding)) - if self.engine.binding_is_input(binding): - self.context.set_binding_shape(idx, shape) + shape = self.engine.get_tensor_shape(name) + dtype = trt.nptype(self.engine.get_tensor_dtype(name)) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + self.context.set_input_shape(name, shape) tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) - self.tensors[binding] = tensor - self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + self.tensors[name] = tensor def infer(self, feed_dict, stream): - start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) - # shallow copy of ordered dict - device_buffers = copy(self.buffers) for name, buf in feed_dict.items(): - assert isinstance(buf, cuda.DeviceView) - device_buffers[name] = buf - bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] - noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + self.tensors[name].copy_(buf) + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + noerror = self.context.execute_async_v3(stream) if not noerror: raise ValueError("ERROR: inference failed.") @@ -310,10 +297,8 @@ def build_engines( force_engine_rebuild=False, static_batch=False, static_shape=True, - enable_preview=False, enable_all_tactics=False, timing_cache=None, - max_workspace_size=0, ): built_engines = {} if not os.path.isdir(onnx_dir): @@ -378,9 +363,7 @@ def build_engines( static_batch=static_batch, static_shape=static_shape, ), - enable_preview=enable_preview, timing_cache=timing_cache, - workspace_size=max_workspace_size, ) built_engines[model_name] = engine @@ -588,11 +571,11 @@ def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) -class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): +class TensorRTStableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion. - This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: @@ -612,10 +595,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( self, vae: AutoencoderKL, @@ -624,7 +609,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae"], @@ -632,28 +617,90 @@ def __init__( image_width: int = 768, max_batch_size: int = 16, # ONNX export parameters - onnx_opset: int = 17, + onnx_opset: int = 18, onnx_dir: str = "onnx", # TensorRT engine build parameters engine_dir: str = "engine", - build_preview_features: bool = True, force_engine_rebuild: bool = False, timing_cache: str = "timing_cache", ): - super().__init__( - vae, - text_encoder, - tokenizer, - unet, - scheduler, + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, - requires_safety_checker=requires_safety_checker, ) - self.vae.forward = self.vae.decode - self.stages = stages self.image_height, self.image_width = image_height, image_width self.inpaint = False @@ -664,7 +711,6 @@ def __init__( self.timing_cache = timing_cache self.build_static_batch = False self.build_dynamic_shape = False - self.build_preview_features = build_preview_features self.max_batch_size = max_batch_size # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. @@ -675,6 +721,11 @@ def __init__( self.models = {} # loaded in __loadModels() self.engine = {} # loaded in build_engines() + self.vae.forward = self.vae.decode + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + def __loadModels(self): # Load pipeline models self.embedding_dim = self.text_encoder.config.hidden_size @@ -691,6 +742,75 @@ def __loadModels(self): if "vae" in self.stages: self.models["vae"] = make_VAE(self.vae, **models_args) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Union[torch.Generator, List[torch.Generator]], + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Prepare the latent vectors for diffusion. + Args: + batch_size (int): The number of samples in the batch. + num_channels_latents (int): The number of channels in the latent vectors. + height (int): The height of the latent vectors. + width (int): The width of the latent vectors. + dtype (torch.dtype): The data type of the latent vectors. + device (torch.device): The device to place the latent vectors on. + generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation. + latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated. + Returns: + torch.Tensor: The prepared latent vectors. + """ + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker( + self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype + ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: + r""" + Runs the safety checker on the given image. + Args: + image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked. + device (torch.device): The device to run the safety checker on. + dtype (torch.dtype): The data type of the input image. + Returns: + (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and + a boolean indicating whether the image has a NSFW (Not Safe for Work) concept. + """ + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + @classmethod @validate_hf_hub_args def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -738,7 +858,6 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dt force_engine_rebuild=self.force_engine_rebuild, static_batch=self.build_static_batch, static_shape=not self.build_dynamic_shape, - enable_preview=self.build_preview_features, timing_cache=self.timing_cache, ) @@ -769,9 +888,8 @@ def __encode_prompt(self, prompt, negative_prompt): .to(self.torch_device) ) - text_input_ids_inp = device_view(text_input_ids) # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[ "text_embeddings" ].clone() @@ -787,8 +905,7 @@ def __encode_prompt(self, prompt, negative_prompt): .input_ids.type(torch.int32) .to(self.torch_device) ) - uncond_input_ids_inp = device_view(uncond_input_ids) - uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[ "text_embeddings" ] @@ -812,18 +929,15 @@ def __denoise_latent( # Predict the noise residual timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - sample_inp = device_view(latent_model_input) - timestep_inp = device_view(timestep_float) - embeddings_inp = device_view(text_embeddings) noise_pred = runEngine( self.engine["unet"], - {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, self.stream, )["latent"] # Perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample @@ -831,12 +945,12 @@ def __denoise_latent( return latents def __decode_latent(self, latents): - images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] + images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"] images = (images / 2 + 0.5).clamp(0, 1) return images.cpu().permute(0, 2, 3, 1).float().numpy() def __loadResources(self, image_height, image_width, batch_size): - self.stream = cuda.Stream() + self.stream = cudart.cudaStreamCreate()[1] # Allocate buffers for TensorRT engine bindings for model_name, obj in self.models.items(): diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md index 052e383ef6f0..6f41c395629a 100644 --- a/examples/dreambooth/README_sd3.md +++ b/examples/dreambooth/README_sd3.md @@ -148,12 +148,12 @@ accelerate launch train_dreambooth_lora_sd3.py \ ``` ### Text Encoder Training -Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported. +Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] -> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL). -By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. +> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL). +By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. To perform DreamBooth LoRA with text-encoder training, run: ```bash @@ -185,4 +185,4 @@ accelerate launch train_dreambooth_lora_sd3.py \ 1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities. 2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917). -3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well. \ No newline at end of file +3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well. \ No newline at end of file diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 615eb834ac24..88a5d93d8edf 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -43,7 +43,7 @@ from torch.utils.data import default_collate from torchvision import transforms from tqdm.auto import tqdm -from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig +from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig from webdataset.tariterators import ( base_plus_ext, tar_file_expander, @@ -205,7 +205,7 @@ def __init__( pin_memory: bool = False, persistent_workers: bool = False, control_type: str = "canny", - feature_extractor: Optional[DPTFeatureExtractor] = None, + feature_extractor: Optional[DPTImageProcessor] = None, ): if not isinstance(train_shards_path_or_url, str): train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] @@ -1011,7 +1011,7 @@ def main(args): controlnet = pre_controlnet if args.control_type == "depth": - feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas") depth_model.requires_grad_(False) else: diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index c4c6292a9cde..571f1a0323a2 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -45,7 +45,7 @@ " UniPCMultistepScheduler,\n", " EulerDiscreteScheduler,\n", ")\n", - "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n", + "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", diff --git a/examples/research_projects/promptdiffusion/README.md b/examples/research_projects/promptdiffusion/README.md index 7d76b1baa3df..33ffec312501 100644 --- a/examples/research_projects/promptdiffusion/README.md +++ b/examples/research_projects/promptdiffusion/README.md @@ -46,5 +46,4 @@ pipe.enable_model_cpu_offload() # generate image generator = torch.manual_seed(0) image = pipe("a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query).images[0] - ``` diff --git a/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py b/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py index 76b7b133ad70..26b56a21e865 100644 --- a/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py +++ b/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py @@ -2051,7 +2051,7 @@ def download_promptdiffusion_from_original_ckpt( default=512, type=int, help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" " Base. Use 768 for Stable Diffusion v2." ), ) diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py index 74f75f7dce35..e035cdc9ac6b 100644 --- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py +++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py @@ -1253,7 +1253,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py index 201acb95aabd..f8093a3f217d 100644 --- a/examples/research_projects/rdm/pipeline_rdm.py +++ b/examples/research_projects/rdm/pipeline_rdm.py @@ -4,7 +4,7 @@ import torch from PIL import Image from retriever import Retriever, normalize_images, preprocess_images -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -47,7 +47,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin): scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -65,7 +65,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, retriever: Optional[Retriever] = None, ): super().__init__() diff --git a/examples/research_projects/rdm/retriever.py b/examples/research_projects/rdm/retriever.py index 6be9785a21f3..4ae4989bd8bb 100644 --- a/examples/research_projects/rdm/retriever.py +++ b/examples/research_projects/rdm/retriever.py @@ -6,7 +6,7 @@ import torch from datasets import Dataset, load_dataset from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig +from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig from diffusers import logging @@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]): return images -def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor: +def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor: """ Preprocesses a list of images into a batch of tensors. @@ -95,14 +95,12 @@ def init_index(self): def build_index( self, model=None, - feature_extractor: CLIPFeatureExtractor = None, + feature_extractor: CLIPImageProcessor = None, torch_dtype=torch.float32, ): if not self.index_initialized: model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype) - feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained( - self.config.clip_name_or_path - ) + feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path) self.dataset = get_dataset_with_emb_from_clip_model( self.dataset, model, @@ -136,7 +134,7 @@ def __init__( index: Index = None, dataset: Dataset = None, model=None, - feature_extractor: CLIPFeatureExtractor = None, + feature_extractor: CLIPImageProcessor = None, ): self.config = config self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor) @@ -148,7 +146,7 @@ def from_pretrained( index: Index = None, dataset: Dataset = None, model=None, - feature_extractor: CLIPFeatureExtractor = None, + feature_extractor: CLIPImageProcessor = None, **kwargs, ): config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs) @@ -156,7 +154,7 @@ def from_pretrained( @staticmethod def _build_index( - config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None + config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None ): dataset = dataset or load_dataset(config.dataset_name) dataset = dataset[config.dataset_set] diff --git a/examples/research_projects/sd3_lora_colab/README.md b/examples/research_projects/sd3_lora_colab/README.md index d90a1c9f0ae2..b7d7eedfb5dc 100644 --- a/examples/research_projects/sd3_lora_colab/README.md +++ b/examples/research_projects/sd3_lora_colab/README.md @@ -11,28 +11,28 @@ huggingface-cli login This will also allow us to push the trained model parameters to the Hugging Face Hub platform. -For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above. +For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above. ## How We make use of several techniques to make this possible: -* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB. +* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB. * In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of: * 8bit Adam for optimization through the `bitsandbytes` library. * Gradient checkpointing and gradient accumulation. * FP16 precision. - * Flash attention through `F.scaled_dot_product_attention()`. + * Flash attention through `F.scaled_dot_product_attention()`. -Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB. +Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB. ## Gotchas This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of: -* Training of text encoders is purposefully disabled. -* Techniques such as prior-preservation is unsupported. +* Training of text encoders is purposefully disabled. +* Techniques such as prior-preservation is unsupported. * Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate. Hopefully, this project gives you a template to extend it further to suit your needs. diff --git a/examples/research_projects/sdxl_flax/sdxl_single_aot.py b/examples/research_projects/sdxl_flax/sdxl_single_aot.py index 58447fd86daf..08bd13902aa9 100644 --- a/examples/research_projects/sdxl_flax/sdxl_single_aot.py +++ b/examples/research_projects/sdxl_flax/sdxl_single_aot.py @@ -18,7 +18,7 @@ NUM_DEVICES = jax.device_count() # 1. Let's start by downloading the model and loading it into our pipeline class -# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and +# Adhering to JAX's functional approach, the model's parameters are returned separately and # will have to be passed to the pipeline during inference pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True @@ -69,7 +69,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed): # to the function and tell JAX which are static arguments, that is, arguments that # are known at compile time and won't change. In our case, it is num_inference_steps, # height, width and return_latents. -# Once the function is compiled, these parameters are ommited from future calls and +# Once the function is compiled, these parameters are omitted from future calls and # cannot be changed without modifying the code and recompiling. def aot_compile( prompt=default_prompt, diff --git a/scripts/convert_original_controlnet_to_diffusers.py b/scripts/convert_original_controlnet_to_diffusers.py index 44b22c33fe65..92aad4f09e70 100644 --- a/scripts/convert_original_controlnet_to_diffusers.py +++ b/scripts/convert_original_controlnet_to_diffusers.py @@ -42,7 +42,7 @@ default=512, type=int, help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" " Base. Use 768 for Stable Diffusion v2." ), ) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 58f0ad292ead..7e7925b0a412 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -67,7 +67,7 @@ default=None, type=int, help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" " Base. Use 768 for Stable Diffusion v2." ), ) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d58bbdac1867..39da57cec06c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -12,6 +12,7 @@ is_note_seq_available, is_onnx_available, is_scipy_available, + is_sentencepiece_available, is_torch_available, is_torchsde_available, is_transformers_available, @@ -246,12 +247,11 @@ "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", - "ChatGLMModel", - "ChatGLMTokenizer", "CLIPImageProjection", "CycleDiffusionPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", + "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", @@ -385,6 +385,19 @@ else: _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]) +try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPipeline"]) + try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() @@ -669,12 +682,11 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, - ChatGLMModel, - ChatGLMTokenizer, CLIPImageProjection, CycleDiffusionPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, + HunyuanDiTPAGPipeline, HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, @@ -703,8 +715,6 @@ KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, - KolorsImg2ImgPipeline, - KolorsPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, LattePipeline, @@ -802,6 +812,13 @@ else: from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline + try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403 + else: + from .pipelines import KolorsImg2ImgPipeline, KolorsPipeline try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 5db13825c9eb..bccd37ddc42f 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -66,6 +66,7 @@ def text_encoder_attn_modules(text_encoder): "SD3LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "LoraLoaderMixin", + "FluxLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -83,6 +84,7 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + FluxLoraLoaderMixin, LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 73273618956a..f612cc0c6e53 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1475,6 +1475,481 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t super().unfuse_lora(components=components) +class FluxLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`FluxTransformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index fd6c639a7cdf..89d6a28b14dd 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -32,6 +32,7 @@ "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, + "FluxTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 855085c0d933..80b462ef6a4f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -539,7 +539,7 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten return tensor def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Compute the attention scores. @@ -1785,6 +1785,11 @@ def __call__( key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( @@ -2142,6 +2147,253 @@ def __call__( return hidden_states +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is @@ -2314,6 +2566,11 @@ def __call__( key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( @@ -3458,4 +3715,6 @@ def __init__(self): CustomDiffusionAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, + PAGCFGHunyuanAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, ] diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index bc1273aaab7d..cb577e33c670 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -32,10 +32,7 @@ from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_3d_blocks import ( - CrossAttnDownBlockMotion, - DownBlockMotion, -) +from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -317,7 +314,6 @@ def __init__( temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, ) elif down_block_type == "DownBlockMotion": down_block = DownBlockMotion( @@ -334,7 +330,6 @@ def __init__( add_downsample=not is_final_block, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, - temporal_double_self_attention=False, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) else: diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index 0fa21755f09c..f676a70f060a 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -285,7 +285,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin): upcast_attention (`bool`, defaults to `True`): Whether the attention computation should always be upcasted. max_norm_num_groups (`int`, defaults to 32): - Maximum number of groups in group normal. The actual number will the the largest divisor of the respective + Maximum number of groups in group normal. The actual number will be the largest divisor of the respective channels, that is <= max_norm_num_groups. """ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2821ce0330fc..a81f9e17cd0e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -302,7 +302,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: - `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`. + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. """ start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) @@ -902,7 +902,7 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) if self.use_style_cond_and_image_meta_size: - # extra condition2: image meta size embdding + # extra condition2: image meta size embedding image_meta_size = self.size_proj(image_meta_size.view(-1)) image_meta_size = image_meta_size.to(dtype=hidden_dtype) image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 73ccc03b38c4..391ca1418d34 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import PeftAdapterMixin from ...models.attention import FeedForward from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 from ...models.modeling_utils import ModelMixin @@ -65,7 +65,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) - return emb.unsqueeze(1) @@ -123,6 +122,7 @@ def forward( ) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) hidden_states = gate * self.proj_out(hidden_states) hidden_states = residual + hidden_states @@ -227,7 +227,7 @@ def forward( return encoder_hidden_states, hidden_states -class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ The Transformer model introduced in Flux. @@ -259,12 +259,13 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], ): super().__init__() self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56]) + self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) @@ -302,6 +303,10 @@ def __init__( self.gradient_checkpointing = False + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 51c743a14d40..8b472a89e13d 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -27,17 +27,58 @@ TemporalConvLayer, Upsample2D, ) -from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from ..transformers.transformer_temporal import ( TransformerSpatioTemporalModel, TransformerTemporalModel, ) +from .unet_motion_model import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class DownBlockMotion(DownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead." + deprecate("DownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead." + deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UpBlockMotion(UpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead." + deprecate("UpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead." + deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead." + deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + def get_down_block( down_block_type: str, num_layers: int, @@ -64,8 +105,6 @@ def get_down_block( ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", - "DownBlockMotion", - "CrossAttnDownBlockMotion", "DownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", ]: @@ -105,49 +144,6 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, dropout=dropout, ) - if down_block_type == "DownBlockMotion": - return DownBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif down_block_type == "CrossAttnDownBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") - return CrossAttnDownBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV return DownBlockSpatioTemporal( @@ -203,8 +199,6 @@ def get_up_block( ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", - "UpBlockMotion", - "CrossAttnUpBlockMotion", "UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", ]: @@ -246,51 +240,6 @@ def get_up_block( resolution_idx=resolution_idx, dropout=dropout, ) - if up_block_type == "UpBlockMotion": - return UpBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif up_block_type == "CrossAttnUpBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") - return CrossAttnUpBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif up_block_type == "UpBlockSpatioTemporal": # added for SDV return UpBlockSpatioTemporal( @@ -947,924 +896,6 @@ def forward( return hidden_states -class DownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - temporal_num_attention_heads: Union[int, Tuple[int]] = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_double_self_attention: bool = True, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" - ) - - # support for variable number of attention head per temporal layers - if isinstance(temporal_num_attention_heads, int): - temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers - elif len(temporal_num_attention_heads) != num_layers: - raise ValueError( - f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads[i], - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads[i], - double_self_attention=temporal_double_self_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - num_frames: int = 1, - *args, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - output_states = () - - blocks = zip(self.resnets, self.motion_modules) - for resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - add_downsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - temporal_double_self_attention: bool = True, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - double_self_attention=temporal_double_self_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - encoder_attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - additional_residuals: Optional[torch.Tensor] = None, - ): - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - output_states = () - - blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) - for i, (resnet, attn, motion_module) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnUpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.attentions, self.motion_modules) - for resnet, attn, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - upsample_size=None, - num_frames: int = 1, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.motion_modules) - - for resnet, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UNetMidBlockCrossAttnMotion(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - cross_attention_dim: int = 1280, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_num_attention_heads: int = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - motion_modules = [] - - for i in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - attention_head_dim=in_channels // temporal_num_attention_heads, - in_channels=in_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - activation_fn="geglu", - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - hidden_states = self.resnets[0](hidden_states, temb) - - blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) - for attn, resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - class MidBlockTemporalDecoder(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index c8ea0ecc3feb..e96867bc3ed0 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import torch @@ -20,7 +22,9 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import logging +from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention import BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -35,24 +39,1094 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from ..transformers.transformer_temporal import TransformerTemporalModel +from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel -from .unet_3d_blocks import ( - CrossAttnDownBlockMotion, - CrossAttnUpBlockMotion, - DownBlockMotion, - UNetMidBlockCrossAttnMotion, - UpBlockMotion, - get_down_block, - get_up_block, -) -from .unet_3d_condition import UNet3DConditionOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@dataclass +class UNetMotionOutput(BaseOutput): + """ + The output of [`UNetMotionOutput`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + +class AnimateDiffTransformer3D(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + """ + The [`AnimateDiffTransformer3D`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + torch.Tensor: + The output tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + return output + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: Union[int, Tuple[int]] = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads[i], + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads[i], + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + num_frames: int = 1, + *args, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.Tensor] = None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size=None, + num_frames: int = 1, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + class MotionModules(nn.Module): def __init__( self, @@ -79,7 +1153,7 @@ def __init__( for i in range(layers_per_block): self.motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3D( in_channels=in_channels, num_layers=transformer_layers_per_block[i], norm_num_groups=norm_num_groups, @@ -394,26 +1468,45 @@ def __init__( output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - dual_cross_attention=False, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=transformer_layers_per_block[i], - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - ) + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + downsample_padding=downsample_padding, + add_downsample=not is_final_block, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + downsample_padding=downsample_padding, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + self.down_blocks.append(down_block) # mid @@ -488,27 +1581,47 @@ def __init__( else: add_upsample = False - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=False, - resolution_idx=i, - use_linear_projection=use_linear_projection, - temporal_num_attention_heads=reversed_motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=reverse_transformer_layers_per_block[i], - temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], - ) + if up_block_type == "CrossAttnUpBlockMotion": + up_block = CrossAttnUpBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reverse_transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=reversed_num_attention_heads[i], + cross_attention_dim=reversed_cross_attention_dim[i], + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + elif up_block_type == "UpBlockMotion": + up_block = UpBlockMotion( + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=add_upsample, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`" + ) + self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -958,7 +2071,7 @@ def forward( down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: + ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]: r""" The [`UNetMotionModel`] forward method. @@ -984,12 +2097,12 @@ def forward( mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain tuple. Returns: - [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. @@ -1173,4 +2286,4 @@ def forward( if not return_dict: return (sample,) - return UNet3DConditionOutput(sample=sample) + return UNetMotionOutput(sample=sample) diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index fbaa14365830..f20bd94edffa 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -87,7 +87,7 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_ The optimizer for which to schedule the learning rate. step_rules (`string`): The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate - if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 10f6c4a92054..69003764b373 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -10,6 +10,7 @@ is_librosa_available, is_note_seq_available, is_onnx_available, + is_sentencepiece_available, is_torch_available, is_torch_npu_available, is_transformers_available, @@ -145,6 +146,7 @@ _import_structure["pag"].extend( [ "AnimateDiffPAGPipeline", + "HunyuanDiTPAGPipeline", "StableDiffusionPAGPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionXLPAGPipeline", @@ -204,12 +206,6 @@ "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", ] - _import_structure["kolors"] = [ - "KolorsPipeline", - "KolorsImg2ImgPipeline", - "ChatGLMModel", - "ChatGLMTokenizer", - ] _import_structure["latent_consistency_models"] = [ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", @@ -349,6 +345,22 @@ "StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline", ] + +try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import ( + dummy_torch_and_transformers_and_sentencepiece_objects, + ) + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) +else: + _import_structure["kolors"] = [ + "KolorsPipeline", + "KolorsImg2ImgPipeline", + ] + try: if not is_flax_available(): raise OptionalDependencyNotAvailable() @@ -506,12 +518,6 @@ Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) - from .kolors import ( - ChatGLMModel, - ChatGLMTokenizer, - KolorsImg2ImgPipeline, - KolorsPipeline, - ) from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, @@ -532,6 +538,7 @@ from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, + HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, @@ -640,6 +647,17 @@ StableDiffusionXLKDiffusionPipeline, ) + try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import * + else: + from .kolors import ( + KolorsImg2ImgPipeline, + KolorsPipeline, + ) + try: if not is_flax_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 854cfaa47b7a..8c74c4797de6 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -18,6 +18,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin +from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, @@ -47,9 +48,9 @@ KandinskyV22Pipeline, ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline -from .kolors import KolorsImg2ImgPipeline, KolorsPipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pag import ( + HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, @@ -85,6 +86,7 @@ ("stable-diffusion-3", StableDiffusion3Pipeline), ("if", IFPipeline), ("hunyuan", HunyuanDiTPipeline), + ("hunyuan-pag", HunyuanDiTPAGPipeline), ("kandinsky", KandinskyCombinedPipeline), ("kandinsky22", KandinskyV22CombinedPipeline), ("kandinsky3", Kandinsky3Pipeline), @@ -101,7 +103,6 @@ ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline), ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), - ("kolors", KolorsPipeline), ("flux", FluxPipeline), ] ) @@ -119,7 +120,6 @@ ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), - ("kolors", KolorsImg2ImgPipeline), ] ) @@ -158,6 +158,12 @@ ] ) +if is_sentencepiece_available(): + from .kolors import KolorsPipeline + + AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline + SUPPORTED_TASKS_MAPPINGS = [ AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 0bde37244f5a..750b98554dee 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1272,7 +1272,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index f4c9881b4acf..3e88ace661b8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1244,7 +1244,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index aa46f4e9b617..9f7d464f9a91 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1408,7 +1408,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index c18c957a7c92..1c2f7bf369a5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1739,7 +1739,7 @@ def denoising_value_valid(dnv): ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 490b37f47dc2..6cd9db08c2dc 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1487,7 +1487,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 8acd290c2671..278d8d953e9b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -76,13 +76,13 @@ >>> import numpy as np >>> from PIL import Image - >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation + >>> from transformers import DPTImageProcessor, DPTForDepthEstimation >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") - >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( ... "diffusers/controlnet-depth-sdxl-1.0-small", ... variant="fp16", @@ -1551,7 +1551,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py index 5b6fc2b393c0..8a2cc08dbb2b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -23,7 +23,7 @@ from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel from ...schedulers import ( @@ -149,7 +149,7 @@ def __init__( FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py index 701e7a3a81b2..9e91986896bd 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ....image_processor import VaeImageProcessor from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin @@ -66,8 +66,8 @@ class StableDiffusionModelEditingPipeline( Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. - feature_extractor ([`~transformers.CLIPFeatureExtractor`]): - A `CLIPFeatureExtractor` to extract features from generated images; used as inputs to the `safety_checker`. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. with_to_k ([`bool`]): Whether to edit the key projection matrices along with the value projection matrices. with_augs ([`list`]): @@ -86,7 +86,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: SchedulerMixin, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, with_to_k: bool = True, with_augs: list = AUGS_CONST, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 33a7c3c7c6cc..513feaabc41f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -17,15 +17,10 @@ import numpy as np import torch -from transformers import ( - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import SD3LoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,7 +137,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. @@ -155,22 +150,18 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModelWithProjection`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, - with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` - as its dimension. - text_encoder_2 ([`CLIPTextModelWithProjection`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), - specifically the - [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) - variant. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`CLIPTokenizer`): + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): Second Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" @@ -323,9 +314,6 @@ def encode_prompt( pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ @@ -333,7 +321,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -366,16 +354,18 @@ def encode_prompt( ) if self.text_encoder is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=self.text_encoder.dtype) + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -748,7 +738,6 @@ def __call__( else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/src/diffusers/pipelines/kolors/__init__.py b/src/diffusers/pipelines/kolors/__init__.py index 843ee93c257f..671d22e9f433 100644 --- a/src/diffusers/pipelines/kolors/__init__.py +++ b/src/diffusers/pipelines/kolors/__init__.py @@ -5,6 +5,7 @@ OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, + is_sentencepiece_available, is_torch_available, is_transformers_available, ) @@ -14,12 +15,12 @@ _import_structure = {} try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 + from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) else: _import_structure["pipeline_kolors"] = ["KolorsPipeline"] _import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"] @@ -28,10 +29,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * + from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import * else: from .pipeline_kolors import KolorsPipeline diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index b80064eb5e9a..5bdbaaac0829 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] @@ -41,6 +42,7 @@ else: from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7c9bb2d098d2..728f730c9904 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import Dict, List, Tuple, Union + import torch +import torch.nn as nn from ...models.attention_processor import ( + Attention, + AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) @@ -25,140 +31,56 @@ class PAGMixin: - r"""Mixin class for PAG.""" - - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: - "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` - can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be - in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}" - """ - - layer_splits = layer.split(".") - - if len(layer_splits) > 3: - raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") - - if len(layer_splits) >= 1: - if layer_splits[0] not in ["down", "mid", "up"]: - raise ValueError( - f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" - ) - - if len(layer_splits) >= 2: - if not layer_splits[1].startswith("block_"): - raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") - - if len(layer_splits) == 3: - layer_2 = layer_splits[2] - if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"): - raise ValueError( - f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'" - ) + r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): r""" Set the attention processor for the PAG layers. """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() + pag_attn_processors = self._pag_attn_processors + if pag_attn_processors is None: + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() + model: nn.Module = self.transformer - def is_self_attn(module_name): + def is_self_attn(module: nn.Module) -> bool: r""" Check if the module is self-attention module based on its name. """ - return "attn1" in module_name and "to" not in name + return isinstance(module, Attention) and not module.is_cross_attention - def get_block_type(module_name): - r""" - Get the block type from the module name. Can be "down", "mid", "up". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down" - return module_name.split(".")[0].split("_")[0] + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name - def get_block_index(module_name): - r""" - Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" - module_name_splits = module_name.split(".") - block_index = module_name_splits[1] - if "attentions" in block_index or "motion_modules" in block_index: - return "block_0" - else: - return f"block_{block_index}" - - def get_attn_index(module_name): - r""" - Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0", - "motion_modules_1", ... - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - module_name_split = module_name.split(".") - mid_name = module_name_split[1] - down_name = module_name_split[2] - if "attentions" in down_name: - return f"attentions_{module_name_split[3]}" - if "attentions" in mid_name: - return f"attentions_{module_name_split[2]}" - if "motion_modules" in down_name: - return f"motion_modules_{module_name_split[3]}" - if "motion_modules" in mid_name: - return f"motion_modules_{module_name_split[2]}" - - for pag_layer_input in pag_applied_layers: + for layer_id in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] - pag_layer_input_splits = pag_layer_input.split(".") - - if len(pag_layer_input_splits) == 1: - # when the layer input only contains block_type. e.g. "mid", "down", "up" - block_type = pag_layer_input_splits[0] - for name, module in self.unet.named_modules(): - if is_self_attn(name) and get_block_type(name) == block_type: - target_modules.append(module) - - elif len(pag_layer_input_splits) == 2: - # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - ): - target_modules.append(module) - - elif len(pag_layer_input_splits) == 3: - # when the layer input contains block_type, block_index and attention_index. - # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - attn_index = pag_layer_input_splits[2] - - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - and get_attn_index(name) == attn_index - ): - target_modules.append(module) + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") for module in target_modules: module.processor = pag_attn_proc @@ -221,239 +143,95 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free cond = torch.cat([uncond, cond], dim=0) return cond - def set_pag_applied_layers(self, pag_applied_layers): + def set_pag_applied_layers( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - """ + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors @property - def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ + def pag_scale(self) -> float: + r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property - def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ + def pag_adaptive_scale(self) -> float: + r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property - def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ + def do_pag_adaptive_scaling(self) -> bool: + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ + def do_perturbed_attention_guidance(self) -> bool: + r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def pag_attn_processors(self): + def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model with the key as the name of the layer. """ - processors = {} - for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): - processors[name] = proc - return processors + if self._pag_attn_processors is None: + return {} + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} -class PixArtPAGMixin: - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. - """ - - # Check if the layer index is valid (should be int or str of int) - if isinstance(layer, int): - return # Valid layer index - - if isinstance(layer, str): - if layer.isdigit(): - return # Valid layer index - - # If it is not a valid layer index, raise a ValueError - raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.") - - def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() - else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() - - def is_self_attn(module_name): - r""" - Check if the module is self-attention module based on its name. - """ - return ( - "attn1" in module_name and len(module_name.split(".")) == 3 - ) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ... - - def get_block_index(module_name): - r""" - Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # transformer_blocks.23.attn -> "23" - return module_name.split(".")[1] - - for pag_layer_input in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the transformer model - target_modules = [] - - block_index = str(pag_layer_input) - - for name, module in self.transformer.named_modules(): - if is_self_attn(name) and get_block_index(name) == block_index: - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") - - for module in target_modules: - module.processor = pag_attn_proc - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers - def set_pag_applied_layers(self, pag_applied_layers): - r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) - - self.pag_applied_layers = pag_applied_layers - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale - def _get_pag_scale(self, t): - r""" - Get the scale factor for the perturbed attention guidance at timestep `t`. - """ - - if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) - if signal_scale < 0: - signal_scale = 0 - return signal_scale - else: - return self.pag_scale - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): - r""" - Apply perturbed attention guidance to the noise prediction. - - Args: - noise_pred (torch.Tensor): The noise prediction tensor. - do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. - guidance_scale (float): The scale factor for the guidance term. - t (int): The current time step. - - Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. - """ - pag_scale = self._get_pag_scale(t) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) + processors = {} + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer else: - noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - return noise_pred - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance - def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): - """ - Prepares the perturbed attention guidance for the PAG model. - - Args: - cond (torch.Tensor): The conditional input tensor. - uncond (torch.Tensor): The unconditional input tensor. - do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. - - Returns: - torch.Tensor: The prepared perturbed attention guidance tensor. - """ + raise ValueError("No denoiser module found.") - cond = torch.cat([cond] * 2, dim=0) - - if do_classifier_free_guidance: - cond = torch.cat([uncond, cond], dim=0) - return cond - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale - def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ - return self._pag_scale - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale - def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ - return self._pag_adaptive_scale - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling - def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance - def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ - return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer - def pag_attn_processors(self): - r""" - Returns: - `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model - with the key as the name of the layer. - """ - - processors = {} - for name, proc in self.transformer.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + for name, proc in denoiser_module.attn_processors.items(): + if proc.__class__ in valid_attn_processors: processors[name] = proc + return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index 6dc21c9d4538..9bac883b5c99 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -1249,7 +1249,7 @@ def __call__( ) if guess_mode and self.do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. + # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100644 index 000000000000..63126cc5aae9 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,953 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[14], + ... ).to("cuda") + + >>> # prompt = "an astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + requires_safety_checker: bool = True, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, + pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 1188ffe52ed7..8e5e6cbaf5ad 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -40,7 +40,7 @@ ASPECT_RATIO_1024_BIN, ) from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN -from .pag_utils import PixArtPAGMixin +from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -61,7 +61,7 @@ >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", ... torch_dtype=torch.float16, - ... pag_applied_layers=[14], + ... pag_applied_layers=["blocks.14"], ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") @@ -132,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin): +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): r""" [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation using PixArt-Sigma. @@ -164,7 +164,7 @@ def __init__( vae: AutoencoderKL, transformer: PixArtTransformer2DModel, scheduler: KarrasDiffusionSchedulers, - pag_applied_layers: Union[str, List[str]] = "1", # 1st transformer block + pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index e37506a60c61..b3b103742061 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -129,7 +129,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, - pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"] + pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] ): super().__init__() if isinstance(unet, UNet2DConditionModel): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index dddda6851e15..495e0e1a83fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -20,7 +20,7 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation +from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -111,7 +111,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, - feature_extractor: DPTFeatureExtractor, + feature_extractor: DPTImageProcessor, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py index 62584beec6a9..52ccd5612776 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -18,7 +18,7 @@ import PIL.Image import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin @@ -138,7 +138,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py index 67b9b927f210..c6748ad418fe 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -19,7 +19,7 @@ import PIL.Image import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPProcessor, CLIPTextModel, CLIPTokenizer, @@ -193,7 +193,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 55a8694c16e9..3cb7c26bb6a2 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -19,7 +19,7 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin @@ -209,7 +209,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -225,7 +225,7 @@ def __init__( adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 287298b87a73..0ea197e42e62 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -237,7 +237,7 @@ class StableDiffusionXLAdapterPipeline( safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d2633d2ec9e7..c7ea2bcc5b7f 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -78,6 +78,7 @@ is_peft_version, is_safetensors_available, is_scipy_available, + is_sentencepiece_available, is_tensorboard_available, is_timm_available, is_torch_available, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py new file mode 100644 index 000000000000..a70d003f7fc6 --- /dev/null +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class KolorsImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + +class KolorsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3e9a33503906..ad3a1663daa2 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -242,22 +242,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class ChatGLMModel(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class ChatGLMTokenizer(metaclass=DummyObject): +class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -272,7 +257,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CLIPImageProjection(metaclass=DummyObject): +class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -287,7 +272,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CycleDiffusionPipeline(metaclass=DummyObject): +class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -302,7 +287,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class FluxPipeline(metaclass=DummyObject): +class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -317,7 +302,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanDiTControlNetPipeline(metaclass=DummyObject): +class HunyuanDiTPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -752,36 +737,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class KolorsImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class KolorsPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 44477df2e220..09cb715a6068 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -294,6 +294,13 @@ except importlib_metadata.PackageNotFoundError: _torchvision_available = False +_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None +try: + _sentencepiece_version = importlib_metadata.version("sentencepiece") + logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") +except importlib_metadata.PackageNotFoundError: + _sentencepiece_available = False + _matplotlib_available = importlib.util.find_spec("matplotlib") is not None try: _matplotlib_version = importlib_metadata.version("matplotlib") @@ -436,6 +443,10 @@ def is_google_colab(): return _is_google_colab +def is_sentencepiece_available(): + return _sentencepiece_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -553,6 +564,12 @@ def is_google_colab(): {0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors` """ +# docstyle-ignore +SENTENCEPIECE_IMPORT_ERROR = """ +{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece` +""" + + # docstyle-ignore BITSANDBYTES_IMPORT_ERROR = """ {0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes` @@ -581,6 +598,7 @@ def is_google_colab(): ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), + ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ] ) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py new file mode 100644 index 000000000000..c0f0684ac4de --- /dev/null +++ b/tests/lora/test_lora_layers_flux.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +import torch +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = FluxPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + uses_flow_matching = True + transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 46b965ec33d9..0aee4f57c2c6 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.repocard import RepoCard from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer from diffusers import ( AutoPipelineForImage2Image, @@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 4, } + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + + @property + def output_shape(self): + return (1, 64, 64, 3) def setUp(self): super().setUp() diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 9ce559be7f06..31c62f27a75a 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -15,10 +15,9 @@ import sys import unittest -from diffusers import ( - FlowMatchEulerDiscreteScheduler, - StableDiffusion3Pipeline, -) +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel + +from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device @@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 32, "patch_size": 1, @@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "pooled_projection_dim": 64, "out_channels": 4, } + transformer_cls = SD3Transformer2DModel vae_kwargs = { "sample_size": 32, "in_channels": 3, @@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_three_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip" + tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip" + tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder" + text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2" + text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 32, 32, 3) @require_torch_gpu def test_sd3_lora(self): diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index f6ca4f304eb9..f00f7b193abf 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -22,6 +22,7 @@ import numpy as np import torch from packaging import version +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( ControlNetModel, @@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): "latent_channels": 4, "sample_size": 128, } + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + + @property + def output_shape(self): + return (1, 64, 64, 3) def setUp(self): super().setUp() diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ca2e92832229..283b9f534766 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import tempfile import unittest @@ -19,14 +20,12 @@ import numpy as np import torch -from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import ( AutoencoderKL, DDIMScheduler, FlowMatchEulerDiscreteScheduler, LCMScheduler, - SD3Transformer2DModel, UNet2DConditionModel, ) from diffusers.utils.import_utils import is_peft_available @@ -72,9 +71,19 @@ class PeftLoraLoaderMixinTests: pipeline_class = None scheduler_cls = None scheduler_kwargs = None + uses_flow_matching = False + has_two_text_encoders = False has_three_text_encoders = False + text_encoder_cls, text_encoder_id = None, None + text_encoder_2_cls, text_encoder_2_id = None, None + text_encoder_3_cls, text_encoder_3_id = None, None + tokenizer_cls, tokenizer_id = None, None + tokenizer_2_cls, tokenizer_2_id = None, None + tokenizer_3_cls, tokenizer_3_id = None, None + unet_kwargs = None + transformer_cls = None transformer_kwargs = None vae_kwargs = None @@ -91,28 +100,23 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): if self.unet_kwargs is not None: unet = UNet2DConditionModel(**self.unet_kwargs) else: - transformer = SD3Transformer2DModel(**self.transformer_kwargs) + transformer = self.transformer_cls(**self.transformer_kwargs) scheduler = scheduler_cls(**self.scheduler_kwargs) torch.manual_seed(0) vae = AutoencoderKL(**self.vae_kwargs) - if not self.has_three_text_encoders: - text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") - tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) - if self.has_two_text_encoders: - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2") - tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + if self.text_encoder_2_cls is not None: + text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id) - if self.has_three_text_encoders: - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2") - text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + if self.text_encoder_3_cls is not None: + text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id) text_lora_config = LoraConfig( r=rank, @@ -130,45 +134,39 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): use_dora=use_dora, ) - if self.has_two_text_encoders or self.has_three_text_encoders: - if self.unet_kwargs is not None: - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "image_encoder": None, - "feature_extractor": None, - } - elif self.has_three_text_encoders and self.transformer_kwargs is not None: - pipeline_components = { - "transformer": transformer, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "text_encoder_3": text_encoder_3, - "tokenizer_3": tokenizer_3, - } - else: - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - "image_encoder": None, - } + pipeline_components = { + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + # Denoiser + if self.unet_kwargs is not None: + pipeline_components.update({"unet": unet}) + elif self.transformer_kwargs is not None: + pipeline_components.update({"transformer": transformer}) + + # Remaining text encoders. + if self.text_encoder_2_cls is not None: + pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) + if self.text_encoder_3_cls is not None: + pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) + + # Remaining stuff + init_params = inspect.signature(self.pipeline_class.__init__).parameters + if "safety_checker" in init_params: + pipeline_components.update({"safety_checker": None}) + if "feature_extractor" in init_params: + pipeline_components.update({"feature_extractor": None}) + if "image_encoder" in init_params: + pipeline_components.update({"image_encoder": None}) return pipeline_components, text_lora_config, denoiser_lora_config + @property + def output_shape(self): + raise NotImplementedError + def get_dummy_inputs(self, with_generator=True): batch_size = 1 sequence_length = 10 @@ -205,9 +203,7 @@ def test_simple_inference(self): Tests a simple inference and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -217,8 +213,7 @@ def test_simple_inference(self): _, _, inputs = self.get_dummy_inputs() output_no_lora = pipe(**inputs).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) def test_simple_inference_with_text_lora(self): """ @@ -226,9 +221,7 @@ def test_simple_inference_with_text_lora(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -238,17 +231,18 @@ def test_simple_inference_with_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -261,9 +255,7 @@ def test_simple_inference_with_text_lora_and_scale(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -273,17 +265,18 @@ def test_simple_inference_with_text_lora_and_scale(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -322,9 +315,7 @@ def test_simple_inference_with_text_lora_fused(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -334,26 +325,27 @@ def test_simple_inference_with_text_lora_fused(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() # Fusing should still keep the LoRA layers self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertFalse( @@ -366,9 +358,7 @@ def test_simple_inference_with_text_lora_unloaded(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -378,17 +368,18 @@ def test_simple_inference_with_text_lora_unloaded(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -397,10 +388,11 @@ def test_simple_inference_with_text_lora_unloaded(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertFalse( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -413,9 +405,7 @@ def test_simple_inference_with_text_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA. """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -425,31 +415,32 @@ def test_simple_inference_with_text_lora_save_load(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images with tempfile.TemporaryDirectory() as tmpdirname: text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) if self.has_two_text_encoders or self.has_three_text_encoders: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - safe_serialization=False, - ) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + text_encoder_lora_layers=text_encoder_state_dict, + text_encoder_2_lora_layers=text_encoder_2_state_dict, + safe_serialization=False, + ) else: self.pipeline_class.save_lora_weights( save_directory=tmpdirname, @@ -457,6 +448,14 @@ def test_simple_inference_with_text_lora_save_load(self): safe_serialization=False, ) + if self.has_two_text_encoders: + if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules: + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + text_encoder_lora_layers=text_encoder_state_dict, + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() @@ -466,9 +465,10 @@ def test_simple_inference_with_text_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), @@ -482,9 +482,7 @@ def test_simple_inference_with_partial_text_lora(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) @@ -503,8 +501,7 @@ def test_simple_inference_with_partial_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") @@ -517,17 +514,18 @@ def test_simple_inference_with_partial_text_lora(self): } if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - state_dict.update( - { - f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() - if "text_model.encoder.layers.4" not in module_name - } - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + state_dict.update( + { + f"text_encoder_2.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name + } + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -549,9 +547,7 @@ def test_simple_inference_save_pretrained(self): Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -561,17 +557,17 @@ def test_simple_inference_save_pretrained(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -587,10 +583,11 @@ def test_simple_inference_save_pretrained(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), + "Lora not correctly set in text encoder 2", + ) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images @@ -604,14 +601,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -621,8 +614,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -635,10 +627,11 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -650,32 +643,23 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): else: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - if self.has_two_text_encoders or self.has_three_text_encoders: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs = { + "save_directory": tmpdirname, + "text_encoder_lora_layers": text_encoder_state_dict, + "safe_serialization": False, + } - if self.unet_kwargs is not None: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - unet_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) - else: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - transformer_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) + if self.unet_kwargs is not None: + saving_kwargs.update({"unet_lora_layers": denoiser_state_dict}) else: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - unet_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) + saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict}) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict}) + + self.pipeline_class.save_lora_weights(**saving_kwargs) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() @@ -688,9 +672,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), @@ -703,9 +688,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -715,8 +698,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -728,10 +710,11 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -775,9 +758,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): and makes sure it works as expected - with unet """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -787,8 +768,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -801,10 +781,11 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() # Fusing should still keep the LoRA layers @@ -813,9 +794,10 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertFalse( @@ -828,9 +810,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -840,8 +820,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -853,10 +832,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -869,10 +849,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertFalse( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -886,9 +867,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -908,10 +887,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() @@ -926,9 +906,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + ) # Fuse and unfuse should lead to the same results self.assertTrue( @@ -942,9 +923,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): multiple adapters and set them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -972,11 +951,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1023,9 +1003,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self): return scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1047,10 +1025,11 @@ def test_simple_inference_with_text_denoiser_block_scale(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) weights_1 = {"text_encoder": 2, "unet": {"down": 5}} pipe.set_adapters("adapter-1", weights_1) @@ -1090,9 +1069,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): return scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1120,11 +1097,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) scales_1 = {"text_encoder": 2, "unet": {"down": 5}} scales_2 = {"unet": {"down": 5, "mid": 5}} @@ -1170,7 +1148,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" - if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": + if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]: return def updown_options(blocks_with_tf, layers_per_block, value): @@ -1249,7 +1227,9 @@ def all_possible_dict_opts(unet, value): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): # test if lora block scales can be set with this scale_dict @@ -1264,9 +1244,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): multiple adapters and set/delete them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1294,11 +1272,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1370,9 +1350,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): multiple adapters and set them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1400,11 +1378,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1453,9 +1433,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): @skip_mps def test_lora_fuse_nan(self): scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1501,9 +1479,7 @@ def test_get_adapters(self): are the expected results """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1539,9 +1515,7 @@ def test_get_list_adapters(self): are the expected results """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1607,9 +1581,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): and makes sure it works as expected - with unet and multi-adapter case """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1619,8 +1591,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") if self.unet_kwargs is not None: @@ -1640,11 +1611,13 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) # set them to multi-adapter inference mode pipe.set_adapters(["adapter-1", "adapter-2"]) @@ -1676,9 +1649,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): @require_peft_version_greater(peft_version="0.9.0") def test_simple_inference_with_dora(self): scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components( @@ -1690,8 +1661,7 @@ def test_simple_inference_with_dora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_dora_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -1704,10 +1674,12 @@ def test_simple_inference_with_dora(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -1723,9 +1695,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1760,7 +1730,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): _ = pipe(**inputs, generator=torch.manual_seed(0)).images def test_modify_padding_mode(self): - if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": + if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]: return def set_pad_mode(network, mode="circular"): @@ -1769,9 +1739,7 @@ def set_pad_mode(network, mode="circular"): module.padding_mode = mode scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py new file mode 100644 index 000000000000..d1c85537b00b --- /dev/null +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FluxTransformer2DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = FluxTransformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) + text_ids = torch.randn((batch_size, sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((batch_size, height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + } + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py new file mode 100644 index 000000000000..3fe0a6098045 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_latte.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LatteTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class LatteTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LatteTransformer3DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 1 + height = width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "enable_temporal_attentions": True, + } + + @property + def input_shape(self): + return (4, 1, 8, 8) + + @property + def output_shape(self): + return (8, 1, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 8, + "num_layers": 1, + "patch_size": 2, + "attention_head_dim": 4, + "num_attention_heads": 2, + "caption_channels": 8, + "in_channels": 4, + "cross_attention_dim": 8, + "out_channels": 8, + "attention_bias": True, + "activation_fn": "gelu-approximate", + "num_embeds_ada_norm": 1000, + "norm_type": "ada_norm_single", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output( + expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + ) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 0dc13911c55b..b2744e3f0ad4 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -13,42 +13,27 @@ torch_device, ) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin -@unittest.skip("Tests needs to be revisited.") +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxPipeline - params = frozenset( - [ - "prompt", - "height", - "width", - "guidance_scale", - "negative_prompt", - "prompt_embeds", - "negative_prompt_embeds", - ] - ) - batch_params = frozenset(["prompt", "negative_prompt"]) + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) def get_dummy_components(self): torch.manual_seed(0) transformer = FluxTransformer2DModel( - sample_size=32, patch_size=1, in_channels=4, num_layers=1, - attention_head_dim=8, - num_attention_heads=4, - caption_projection_dim=32, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, joint_attention_dim=32, - pooled_projection_dim=64, - out_channels=4, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], ) clip_text_encoder_config = CLIPTextConfig( bos_token_id=0, @@ -80,7 +65,7 @@ def get_dummy_components(self): out_channels=3, block_out_channels=(4,), layers_per_block=1, - latent_channels=4, + latent_channels=1, norm_num_groups=1, use_quant_conv=False, use_post_quant_conv=False, @@ -111,6 +96,9 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, "output_type": "np", } return inputs @@ -128,22 +116,8 @@ def test_flux_different_prompts(self): max_diff = np.abs(output_same_prompt - output_different_prompts).max() # Outputs should be different here - assert max_diff > 1e-2 - - def test_flux_different_negative_prompts(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - - inputs = self.get_dummy_inputs(torch_device) - output_same_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt_2"] = "deformed" - output_different_prompts = pipe(**inputs).images[0] - - max_diff = np.abs(output_same_prompt - output_different_prompts).max() - - # Outputs should be different here - assert max_diff > 1e-2 + # For some reasons, they don't show large differences + assert max_diff > 1e-6 def test_flux_prompt_embeds(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) @@ -154,71 +128,21 @@ def test_flux_prompt_embeds(self): inputs = self.get_dummy_inputs(torch_device) prompt = inputs.pop("prompt") - do_classifier_free_guidance = inputs["guidance_scale"] > 1 - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - text_ids, - ) = pipe.encode_prompt( + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( prompt, prompt_2=None, - prompt_3=None, - do_classifier_free_guidance=do_classifier_free_guidance, device=torch_device, + max_sequence_length=inputs["max_sequence_length"], ) output_with_embeds = pipe( prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, **inputs, ).images[0] max_diff = np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 - def test_fused_qkv_projections(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - original_image_slice = image[0, -3:, -3:, -1] - - # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added - # to the pipeline level. - pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_fused = image[0, -3:, -3:, -1] - - pipe.transformer.unfuse_qkv_projections() - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_disabled = image[0, -3:, -3:, -1] - - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." - @slow @require_torch_gpu diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index 3f7fcaf59575..719a5ef101e7 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -20,12 +20,11 @@ from diffusers import ( AutoencoderKL, - ChatGLMModel, - ChatGLMTokenizer, EulerDiscreteScheduler, KolorsPipeline, UNet2DConditionModel, ) +from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer from diffusers.utils.testing_utils import enable_full_determinism from ..pipeline_params import ( diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 8f637b991056..2cd80921e932 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -429,7 +429,10 @@ def test_pag_applied_layers(self): pipe.set_progress_bar_config(disable=None) # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers - all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k] + # Note that for motion modules in AnimateDiff, both attn1 and attn2 are self-attention + all_self_attn_layers = [ + k for k in pipe.unet.attn_processors.keys() if "attn1" in k or ("motion_modules" in k and "attn2" in k) + ] original_attn_procs = pipe.unet.attn_processors pag_layers = [ "down", @@ -439,12 +442,13 @@ def test_pag_applied_layers(self): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) - # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e. + # pag_applied_layers = ["mid"], or ["mid_block.0"] should apply to all self-attention layers in mid_block, i.e. # mid_block.motion_modules.0.transformer_blocks.0.attn1.processor # mid_block.attentions.0.transformer_blocks.0.attn1.processor all_self_attn_mid_layers = [ - "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn2.processor", ] pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid"] @@ -452,17 +456,17 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"] + pag_layers = ["mid_block.(attentions|motion_modules)"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -474,19 +478,19 @@ def test_pag_applied_layers(self): pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["down"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 6 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert (len(pipe.pag_attn_processors)) == 4 + assert (len(pipe.pag_attn_processors)) == 6 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 2 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.motion_modules_2"] + pag_layers = ["motion_modules.42"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py new file mode 100644 index 000000000000..db0e257760ed --- /dev/null +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, BertModel, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + HunyuanDiT2DModel, + HunyuanDiTPAGPipeline, + HunyuanDiTPipeline, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanDiTPAGPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanDiT2DModel( + sample_size=16, + num_layers=2, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + in_channels=4, + cross_attention_dim=32, + cross_attention_dim_t5=32, + pooled_projection_dim=16, + hidden_size=24, + activation_fn="gelu-approximate", + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDPMScheduler() + text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel") + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "use_resolution_binning": False, + "pag_scale": 0.0, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 16, 16, 3)) + expected_slice = np.array( + [0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_sequential_cpu_offload_forward_pass(self): + # TODO(YiYi) need to fix later + pass + + def test_sequential_offload_forward_pass_twice(self): + # TODO(YiYi) need to fix later + pass + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-3, + ) + + def test_save_load_optional_components(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0) + + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = pipe.encode_prompt( + prompt, + device=torch_device, + dtype=torch.float32, + text_encoder_index=1, + ) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + + def test_feed_forward_chunking(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_no_chunking = image[0, -3:, -3:, -1] + + pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0) + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_chunking = image[0, -3:, -3:, -1] + + max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max() + self.assertLess(max_diff, 1e-4) + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image = pipe(**inputs)[0] + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_fused = pipe(**inputs)[0] + image_slice_fused = image_fused[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_disabled = pipe(**inputs)[0] + image_slice_disabled = image_disabled[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + pipe_sd = HunyuanDiTPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + components = self.get_dummy_components() + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 3.0 + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_pag_applied_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] + original_attn_procs = pipe.transformer.attn_processors + pag_layers = ["blocks.0", "blocks.1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) + + # blocks.0 + block_0_self_attn = ["blocks.0.attn1.processor"] + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0.attn1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.(0|1)"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert (len(pipe.pag_attn_processors)) == 2 + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0", r"blocks\.1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 2 diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py index be86afe45be0..70b528dede56 100644 --- a/tests/pipelines/pag/test_pag_pixart_sigma.py +++ b/tests/pipelines/pag/test_pag_pixart_sigma.py @@ -127,7 +127,7 @@ def test_pag_disable_enable(self): out = pipe(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 - components["pag_applied_layers"] = [1] + components["pag_applied_layers"] = ["blocks.1"] pipe_pag = self.pipeline_class(**components) pipe_pag = pipe_pag.to(device) pipe_pag.set_progress_bar_config(disable=None) @@ -158,7 +158,7 @@ def test_pag_applied_layers(self): # "attn1" should apply to all self-attention layers. all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] - pag_layers = [0, 1] + pag_layers = ["blocks.0", "blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) @@ -228,7 +228,7 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -282,7 +282,7 @@ def test_save_load_local(self, expected_max_difference=1e-4): pipe.save_pretrained(tmpdir, safe_serialization=False) with CaptureLogger(logger) as cap_logger: - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) for name in pipe_loaded.components.keys(): if name not in pipe_loaded._optional_components: diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index a0930245b375..e9adb3ac447e 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -213,18 +213,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -239,17 +239,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 1 diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py index 5ec3dc5555f1..589573385677 100644 --- a/tests/pipelines/pag/test_pag_sdxl.py +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -225,18 +225,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -251,17 +251,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 8fc76db311a6..838f996117aa 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -26,8 +26,8 @@ CLIPTextModel, CLIPTokenizer, DPTConfig, - DPTFeatureExtractor, DPTForDepthEstimation, + DPTImageProcessor, ) from diffusers import ( @@ -145,9 +145,7 @@ def get_dummy_components(self): backbone_featmap_shape=[1, 384, 24, 24], ) depth_estimator = DPTForDepthEstimation(depth_estimator_config).eval() - feature_extractor = DPTFeatureExtractor.from_pretrained( - "hf-internal-testing/tiny-random-DPTForDepthEstimation" - ) + feature_extractor = DPTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-DPTForDepthEstimation") components = { "unet": unet, diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 352ac5defce2..abdc9fd409db 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -106,7 +106,7 @@ def checkout_commit(repo: Repo, commit_id: str): def clean_code(content: str) -> str: """ Remove docstrings, empty line or comments from some code (used to detect if a diff is real or only concern - comments or docstings). + comments or docstrings). Args: content (`str`): The code to clean @@ -165,7 +165,7 @@ def keep_doc_examples_only(content: str) -> str: def get_all_tests() -> List[str]: """ Walks the `tests` folder to return a list of files/subfolders. This is used to split the tests to run when using - paralellism. The split is: + parallelism. The split is: - folders under `tests`: (`tokenization`, `pipelines`, etc) except the subfolder `models` is excluded. - folders under `tests/models`: `bert`, `gpt2`, etc. @@ -635,7 +635,7 @@ def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Unio Args: module (`str`): The module that will be the root of the subtree we want. - eges (`List[Tuple[str, str]]`): The list of all edges of the tree. + edges (`List[Tuple[str, str]]`): The list of all edges of the tree. Returns: `List[Union[str, List[str]]]`: The tree to print in the following format: [module, [list of edges @@ -663,7 +663,7 @@ def print_tree_deps_of(module, all_edges=None): Args: module (`str`): The module that will be the root of the subtree we want. - all_eges (`List[Tuple[str, str]]`, *optional*): + all_edges (`List[Tuple[str, str]]`, *optional*): The list of all edges of the tree. Will be set to `create_reverse_dependency_tree()` if not passed. """ if all_edges is None: @@ -706,7 +706,7 @@ def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]: for framework in ["flax", "pytorch", "tensorflow"]: test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py")) all_examples.extend(test_files) - # Remove the files at the root of examples/framework since they are not proper examples (they are eith utils + # Remove the files at the root of examples/framework since they are not proper examples (they are either utils # or example test files). examples = [ f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework