diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 6b4ddc1f18d7..248846a8747e 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -32,7 +32,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- pip install -e .
+ pip install -e .[test]
pip install huggingface_hub
- name: Fetch Pipeline Matrix
id: fetch_pipeline_matrix
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index bce17b291478..048638cebf53 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -253,6 +253,8 @@
title: HunyuanDiT2DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
+ - local: api/models/flux_transformer
+ title: FluxTransformer2DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
@@ -320,6 +322,8 @@
title: DiffEdit
- local: api/pipelines/dit
title: DiT
+ - local: api/pipelines/flux
+ title: Flux
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/i2vgenxl
diff --git a/docs/source/en/api/models/flux_transformer.md b/docs/source/en/api/models/flux_transformer.md
new file mode 100644
index 000000000000..381593f1bfe6
--- /dev/null
+++ b/docs/source/en/api/models/flux_transformer.md
@@ -0,0 +1,19 @@
+
+
+# FluxTransformer2DModel
+
+A Transformer model for image-like data from [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).
+
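+A minimal loading sketch (assuming a Diffusers-format checkpoint, for example one produced by `scripts/convert_flux_to_diffusers.py`, with the weights stored in a `transformer` subfolder):
+
+```python
+import torch
+
+from diffusers import FluxTransformer2DModel
+
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+```
+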
+## FluxTransformer2DModel
+
+[[autodoc]] FluxTransformer2DModel
diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
new file mode 100644
index 000000000000..f7d27dfbbf98
--- /dev/null
+++ b/docs/source/en/api/pipelines/flux.md
@@ -0,0 +1,84 @@
+
+
+# Flux
+
+Flux is a series of text-to-image generation models based on diffusion transformers. To learn more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
+
+Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux).
+
+
+
+Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
+
+
+
+Flux comes in two variants:
+
+* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
+* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
+
+Both checkpoints have slightly different usage, which we detail below.
+
+### Timestep-distilled
+
+* `max_sequence_length` cannot be more than 256.
+* `guidance_scale` needs to be 0.
+* As this is a timestep-distilled model, it benefits from fewer sampling steps.
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat holding a sign that says hello world"
+out = pipe(
+ prompt=prompt,
+ guidance_scale=0.,
+ height=768,
+ width=1360,
+ num_inference_steps=4,
+ max_sequence_length=256,
+).images[0]
+out.save("image.png")
+```
+
+### Guidance-distilled
+
+* The guidance-distilled variant requires about 50 sampling steps for good-quality generation.
+* It doesn't have any limitation on `max_sequence_length`.
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+out = pipe(
+ prompt=prompt,
+ guidance_scale=3.5,
+ height=768,
+ width=1360,
+ num_inference_steps=50,
+).images[0]
+out.save("image.png")
+```
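+
+Both examples above already use `enable_model_cpu_offload()`. If memory is still tight, here is a rough sketch of further optimizations you can layer on top (it assumes the same `pipe` object; how much each step helps depends on your hardware):
+
+```python
+# Sliced and tiled VAE decoding trades speed for lower peak memory at decode time.
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+# More aggressive alternative to model offloading: move submodules to the GPU one at a time.
+# pipe.enable_sequential_cpu_offload()
+```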
+
+## FluxPipeline
+
+[[autodoc]] FluxPipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md
index abfeb930d5ba..5b48569f3505 100644
--- a/docs/source/en/api/pipelines/pag.md
+++ b/docs/source/en/api/pipelines/pag.md
@@ -20,6 +20,11 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
+## AnimateDiffPAGPipeline
+[[autodoc]] AnimateDiffPAGPipeline
+ - all
+ - __call__
+
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -49,3 +54,9 @@ The abstract from the paper is:
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
- all
- __call__
+
+
+## PixArtSigmaPAGPipeline
+[[autodoc]] PixArtSigmaPAGPipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md
index f6ca87ef0662..e852aec03fd4 100644
--- a/docs/source/en/using-diffusers/pag.md
+++ b/docs/source/en/using-diffusers/pag.md
@@ -22,7 +22,7 @@ This guide will show you how to use PAG for various tasks and use cases.
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
> [!TIP]
-> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
+> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines and [`PixArtSigmaPAGPipeline`], but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py
new file mode 100644
index 000000000000..05a1da256d33
--- /dev/null
+++ b/scripts/convert_flux_to_diffusers.py
@@ -0,0 +1,303 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers import AutoencoderKL, FluxTransformer2DModel
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# Transformer
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "flux1-schnell.sft" \
+--output_path "flux-schnell" \
+--transformer
+"""
+
+"""
+# VAE
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "ae.sft" \
+--output_path "flux-schnell" \
+--vae
+"""
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--vae", action="store_true")
+parser.add_argument("--transformer", action="store_true")
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`.")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+# The original (SD3) implementation of AdaLayerNormContinuous splits the linear projection output into shift, scale;
+# diffusers splits it into scale, shift. Swap the linear projection weights here so the diffusers implementation can be used.
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(
+ original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
+):
+ converted_state_dict = {}
+
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_in.out_layer.bias"
+ )
+
+ ## time_text_embed.text_embedder <- vector_in
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "vector_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "vector_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "vector_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "vector_in.out_layer.bias"
+ )
+
+ # guidance
+ has_guidance = any("guidance" in k for k in original_state_dict)
+ if has_guidance:
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
+ "guidance_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
+ "guidance_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
+ "guidance_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
+ "guidance_in.out_layer.bias"
+ )
+
+ # context_embedder
+ converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # norms.
+ ## norm1
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.bias"
+ )
+ ## norm1_context
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.bias"
+ )
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+    # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.bias"
+ )
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+ q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.bias"
+ )
+
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+ )
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+ has_guidance = any("guidance" in k for k in original_ckpt)
+
+ if args.transformer:
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+ converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
+ original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
+ )
+ transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ print(
+ f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
+ )
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if args.vae:
+ config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
+ vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
+
+ converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 9bb7be4b0dd6..d58bbdac1867 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -85,6 +85,7 @@
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
+ "FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
@@ -233,6 +234,7 @@
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffControlNetPipeline",
+ "AnimateDiffPAGPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
@@ -248,6 +250,7 @@
"ChatGLMTokenizer",
"CLIPImageProjection",
"CycleDiffusionPipeline",
+ "FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
@@ -292,6 +295,7 @@
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
+ "PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
@@ -526,6 +530,7 @@
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
+ FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
@@ -654,6 +659,7 @@
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffControlNetPipeline,
+ AnimateDiffPAGPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
@@ -667,6 +673,7 @@
ChatGLMTokenizer,
CLIPImageProjection,
CycleDiffusionPipeline,
+ FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
@@ -711,6 +718,7 @@
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
+ PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 76fe2b682a46..53fd3ebd4bbd 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -51,6 +51,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -93,6 +94,7 @@
AuraFlowTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
+ FluxTransformer2DModel,
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 5c5464c37683..855085c0d933 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -121,11 +121,12 @@ def __init__(
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
context_pre_only=None,
+ pre_only=False,
):
super().__init__()
# To prevent circular import.
- from .normalization import FP32LayerNorm
+ from .normalization import FP32LayerNorm, RMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -141,6 +142,7 @@ def __init__(
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -186,6 +188,9 @@ def __init__(
# Lumina applys qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
@@ -228,9 +233,10 @@ def __init__(
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
- self.to_out.append(nn.Dropout(dropout))
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
@@ -239,6 +245,9 @@ def __init__(
if qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
else:
self.norm_added_q = None
self.norm_added_k = None
@@ -1265,6 +1274,179 @@ def __call__(
return hidden_states
+# YiYi to-do: refactor rope related functions/classes
+def apply_rope(xq, xk, freqs_cis):
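+    # View the last dimension of the query/key as (real, imag) pairs and rotate each pair by the
+    # per-position 2x2 rotation matrices packed in `freqs_cis` (built by `rope` in transformer_flux.py).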
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+class FluxSingleAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ query = attn.to_q(hidden_states)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+ # from ..embeddings import apply_rotary_emb
+ # query = apply_rotary_emb(query, image_rotary_emb)
+ # key = apply_rotary_emb(key, image_rotary_emb)
+ query, key = apply_rope(query, key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states
+
+
+class FluxAttnProcessor2_0:
+    """Attention processor typically used for processing SD3-like joint (text + image) self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+ # from ..embeddings import apply_rotary_emb
+ # query = apply_rotary_emb(query, image_rotary_emb)
+ # key = apply_rotary_emb(key, image_rotary_emb)
+ query, key = apply_rope(query, key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 71e301d0d707..2821ce0330fc 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -795,6 +795,30 @@ def forward(self, timestep, pooled_projection):
return conditioning
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = time_guidance_emb + pooled_projections
+
+ return conditioning
+
+
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 4e532f3fc990..8d09999e5c95 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -106,6 +106,38 @@ def forward(
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+class AdaLayerNormZeroSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        norm_type (`str`, defaults to `"layer_norm"`): The type of normalization layer to use.
+        bias (`bool`, defaults to `True`): Whether to include a bias in the modulation linear layer.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa
+
+
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 8d4b8d9d6ecb..d0d351ce88e1 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -13,5 +13,6 @@
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
+ from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 9c8f9b09083f..b1ab0ad2b657 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
import torch
from torch import nn
@@ -19,6 +19,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
+from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -186,6 +187,66 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000000..73ccc03b38c4
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,446 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention import FeedForward
+from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# YiYi to-do: refactor rope related functions/classes
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
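+    # For each position, build dim//2 rotation angles pos * theta**(-2i/dim) and return them as
+    # stacked 2x2 rotation matrices [[cos, -sin], [sin, cos]], later consumed by `apply_rope`.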
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+ return out.float()
+
+
+# YiYi to-do: refactor rope related functions/classes
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
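+        # Each of the last-dim axes of `ids` gets its own rotary table of width `axes_dim[i]`;
+        # the per-axis tables are concatenated so their widths sum to the attention head dimension.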
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+
+ return emb.unsqueeze(1)
+
+
+@maybe_allow_in_graph
+class FluxSingleTransformerBlock(nn.Module):
+ r"""
+    A single DiT-style Transformer block used by Flux, where attention and the MLP are applied in parallel to the same normalized input.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+        mlp_ratio (`float`, defaults to `4.0`): The ratio of the MLP hidden dimension to `dim`.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = FluxSingleAttnProcessor2_0()
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm="rms_norm",
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ ):
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class FluxTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+        qk_norm (`str`, defaults to `"rms_norm"`): The normalization to apply to the query and key projections.
+        eps (`float`, defaults to `1e-6`): The epsilon value used by the normalization layers.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ if hasattr(F, "scaled_dot_product_attention"):
+ processor = FluxAttnProcessor2_0()
+ else:
+ raise ValueError(
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+ )
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ ):
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
+
+
+class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ The Transformer model introduced in Flux.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
+        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
+        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use.
+        pooled_projection_dim (`int`, *optional*, defaults to 768): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ )
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            img_ids (`torch.Tensor`):
+                Positional ids of the image (latent) tokens, used to build the rotary positional embeddings.
+            txt_ids (`torch.Tensor`):
+                Positional ids of the text tokens, used to build the rotary positional embeddings.
+            guidance (`torch.Tensor`, *optional*):
+                Guidance-scale conditioning, only used by the guidance-distilled variant (`guidance_embeds=True`).
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = self.pos_embed(ids)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
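+        # The single-stream blocks process text and image tokens as one concatenated sequence;
+        # the text tokens are sliced off again after the loop.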
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index e1a56efa851d..10f6c4a92054 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -123,6 +123,7 @@
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
]
+ _import_structure["flux"] = ["FluxPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -143,12 +144,14 @@
)
_import_structure["pag"].extend(
[
+ "AnimateDiffPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
+ "PixArtSigmaPAGPipeline",
]
)
_import_structure["controlnet_xs"].extend(
@@ -475,6 +478,7 @@
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
+ from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -527,6 +531,8 @@
)
from .musicldm import MusicLDMPipeline
from .pag import (
+ AnimateDiffPAGPipeline,
+ PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 2df09f62c880..854cfaa47b7a 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -28,6 +28,7 @@
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
+from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -49,6 +50,7 @@
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
+ PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -97,8 +99,10 @@
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
+ ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("kolors", KolorsPipeline),
+ ("flux", FluxPipeline),
]
)
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
new file mode 100644
index 000000000000..d8c3edf0eaca
--- /dev/null
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_flux"] = ["FluxPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_flux import FluxPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
new file mode 100644
index 000000000000..4378f97ffd68
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -0,0 +1,760 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import SD3LoraLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxPipeline
+
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
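+    # Linearly interpolate the timestep-shift parameter `mu` between `base_shift` and `max_shift`
+    # as the image token sequence length grows from `base_seq_len` to `max_seq_len`.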
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`T5EncoderModel`]):
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), the frozen text
+            encoder used to produce the sequence-level prompt embeddings.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`T5TokenizerFast`):
+            Second Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
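+        # Text tokens carry all-zero positional ids; only image tokens get 2D coordinates.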
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=self.text_encoder.dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
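+        # Channel 0 is left at zero; channels 1 and 2 hold the row and column index of each 2x2 latent patch.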
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
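+        # Rearrange (B, C, H, W) latents into a sequence of 2x2 patches with shape (B, (H//2)*(W//2), C*4).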
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
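+        # Inverse of `_pack_latents`: reshape the patch sequence back into (B, C, H, W) latents for VAE decoding.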
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+                will be used instead.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
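+        # Flow-matching sigmas run linearly from 1.0 down to 1/num_inference_steps; `mu` shifts the schedule with resolution.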
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
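+                # Guidance-distilled checkpoints embed the guidance scale via `guidance_embeds`; otherwise no guidance tensor is passed.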
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
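+            # Undo the VAE scaling and shift before decoding the unpacked latents to pixels.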
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py
new file mode 100644
index 000000000000..b5d98fb5bf60
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class FluxPipelineOutput(BaseOutput):
+ """
+    Output class for Flux pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py
index bf14821f3fdb..b80064eb5e9a 100644
--- a/src/diffusers/pipelines/pag/__init__.py
+++ b/src/diffusers/pipelines/pag/__init__.py
@@ -24,7 +24,9 @@
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
+ _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
+ _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -39,7 +41,9 @@
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
+ from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
+ from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py
index 2009024e4e47..7c9bb2d098d2 100644
--- a/src/diffusers/pipelines/pag/pag_utils.py
+++ b/src/diffusers/pipelines/pag/pag_utils.py
@@ -33,7 +33,7 @@ def _check_input_pag_applied_layer(layer):
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
- in the format of "attentions_{j}".
+ in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}"
"""
layer_splits = layer.split(".")
@@ -52,8 +52,11 @@ def _check_input_pag_applied_layer(layer):
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
if len(layer_splits) == 3:
- if not layer_splits[2].startswith("attentions_"):
- raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
+ layer_2 = layer_splits[2]
+ if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
+ raise ValueError(
+ f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
+ )
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
@@ -72,33 +75,46 @@ def is_self_attn(module_name):
def get_block_type(module_name):
r"""
- Get the block type from the module name. can be "down", "mid", "up".
+ Get the block type from the module name. Can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
+ # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]
def get_block_index(module_name):
r"""
- Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
+ Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is ommited from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
- if "attentions" in module_name.split(".")[1]:
+ module_name_splits = module_name.split(".")
+ block_index = module_name_splits[1]
+ if "attentions" in block_index or "motion_modules" in block_index:
return "block_0"
else:
- return f"block_{module_name.split('.')[1]}"
+ return f"block_{block_index}"
def get_attn_index(module_name):
r"""
- Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
+ Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
+ "motion_modules_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
- if "attentions" in module_name.split(".")[2]:
- return f"attentions_{module_name.split('.')[3]}"
- elif "attentions" in module_name.split(".")[1]:
- return f"attentions_{module_name.split('.')[2]}"
+ # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
+ # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
+ module_name_split = module_name.split(".")
+ mid_name = module_name_split[1]
+ down_name = module_name_split[2]
+ if "attentions" in down_name:
+ return f"attentions_{module_name_split[3]}"
+ if "attentions" in mid_name:
+ return f"attentions_{module_name_split[2]}"
+ if "motion_modules" in down_name:
+ return f"motion_modules_{module_name_split[3]}"
+ if "motion_modules" in mid_name:
+ return f"motion_modules_{module_name_split[2]}"
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
@@ -114,7 +130,7 @@ def get_attn_index(module_name):
target_modules.append(module)
elif len(pag_layer_input_splits) == 2:
- # when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
+ # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
@@ -126,7 +142,8 @@ def get_attn_index(module_name):
target_modules.append(module)
elif len(pag_layer_input_splits) == 3:
- # when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
+ # when the layer input contains block_type, block_index and attention_index.
+ # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
@@ -258,3 +275,185 @@ def pag_attn_processors(self):
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
+
+
+class PixArtPAGMixin:
+ @staticmethod
+ def _check_input_pag_applied_layer(layer):
+ r"""
+ Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}.
+ """
+
+ # Check if the layer index is valid (should be int or str of int)
+ if isinstance(layer, int):
+ return # Valid layer index
+
+ if isinstance(layer, str):
+ if layer.isdigit():
+ return # Valid layer index
+
+ # If it is not a valid layer index, raise a ValueError
+ raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.")
+
+ def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
+ r"""
+ Set the attention processor for the PAG layers.
+ """
+ if do_classifier_free_guidance:
+ pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
+ else:
+ pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
+
+ def is_self_attn(module_name):
+ r"""
+ Check if the module is self-attention module based on its name.
+ """
+ return (
+ "attn1" in module_name and len(module_name.split(".")) == 3
+ ) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...
+
+ def get_block_index(module_name):
+ r"""
+            Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
+            mid_block) and the index is omitted from the name, it will be "block_0".
+ """
+ # transformer_blocks.23.attn -> "23"
+ return module_name.split(".")[1]
+
+ for pag_layer_input in pag_applied_layers:
+ # for each PAG layer input, we find corresponding self-attention layers in the transformer model
+ target_modules = []
+
+ block_index = str(pag_layer_input)
+
+ for name, module in self.transformer.named_modules():
+ if is_self_attn(name) and get_block_index(name) == block_index:
+ target_modules.append(module)
+
+ if len(target_modules) == 0:
+ raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
+
+ for module in target_modules:
+ module.processor = pag_attn_proc
+
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
+ def set_pag_applied_layers(self, pag_applied_layers):
+ r"""
+        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
+ """
+
+ if not isinstance(pag_applied_layers, list):
+ pag_applied_layers = [pag_applied_layers]
+
+ for pag_layer in pag_applied_layers:
+ self._check_input_pag_applied_layer(pag_layer)
+
+ self.pag_applied_layers = pag_applied_layers
+
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
+ def _get_pag_scale(self, t):
+ r"""
+ Get the scale factor for the perturbed attention guidance at timestep `t`.
+ """
+
+ if self.do_pag_adaptive_scaling:
+ signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
+ if signal_scale < 0:
+ signal_scale = 0
+ return signal_scale
+ else:
+ return self.pag_scale
+
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
+ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
+ r"""
+ Apply perturbed attention guidance to the noise prediction.
+
+ Args:
+ noise_pred (torch.Tensor): The noise prediction tensor.
+ do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
+ guidance_scale (float): The scale factor for the guidance term.
+ t (int): The current time step.
+
+ Returns:
+ torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
+ """
+ pag_scale = self._get_pag_scale(t)
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ + pag_scale * (noise_pred_text - noise_pred_perturb)
+ )
+ else:
+ noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
+ noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
+ return noise_pred
+
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
+ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
+ """
+ Prepares the perturbed attention guidance for the PAG model.
+
+ Args:
+ cond (torch.Tensor): The conditional input tensor.
+ uncond (torch.Tensor): The unconditional input tensor.
+ do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
+
+ Returns:
+ torch.Tensor: The prepared perturbed attention guidance tensor.
+ """
+
+ cond = torch.cat([cond] * 2, dim=0)
+
+ if do_classifier_free_guidance:
+ cond = torch.cat([uncond, cond], dim=0)
+ return cond
+
+ @property
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
+ def pag_scale(self):
+ """
+ Get the scale factor for the perturbed attention guidance.
+ """
+ return self._pag_scale
+
+ @property
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
+ def pag_adaptive_scale(self):
+ """
+ Get the adaptive scale factor for the perturbed attention guidance.
+ """
+ return self._pag_adaptive_scale
+
+ @property
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
+ def do_pag_adaptive_scaling(self):
+ """
+ Check if the adaptive scaling is enabled for the perturbed attention guidance.
+ """
+ return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+ @property
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
+ def do_perturbed_attention_guidance(self):
+ """
+ Check if the perturbed attention guidance is enabled.
+ """
+ return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+ @property
+ # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
+ def pag_attn_processors(self):
+ r"""
+ Returns:
+            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the model
+            with the key as the name of the layer.
+ """
+
+ processors = {}
+ for name, proc in self.transformer.attn_processors.items():
+ if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
+ processors[name] = proc
+ return processors
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
new file mode 100644
index 000000000000..1188ffe52ed7
--- /dev/null
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -0,0 +1,872 @@
+# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...image_processor import PixArtImageProcessor
+from ...models import AutoencoderKL, PixArtTransformer2DModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ BACKENDS_MAPPING,
+ deprecate,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pixart_alpha.pipeline_pixart_alpha import (
+ ASPECT_RATIO_256_BIN,
+ ASPECT_RATIO_512_BIN,
+ ASPECT_RATIO_1024_BIN,
+)
+from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from .pag_utils import PixArtPAGMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AutoPipelineForText2Image
+
+ >>> pipe = AutoPipelineForText2Image.from_pretrained(
+ ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
+ ... torch_dtype=torch.float16,
+ ... pag_applied_layers=[14],
+ ... enable_pag=True,
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert"
+ >>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
+ r"""
+ [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
+ using PixArt-Sigma.
+ """
+
+ bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + "\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: PixArtTransformer2DModel,
+ scheduler: KarrasDiffusionSchedulers,
+ pag_applied_layers: Union[str, List[str]] = "1", # 1st transformer block
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.set_pag_applied_layers(pag_applied_layers)
+
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For PixArt-Alpha, this should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and does not affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because T5 can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+        # all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+        # normalize quotation marks to a single standard
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+        # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ use_resolution_binning: bool = True,
+ max_sequence_length: int = 300,
+ pag_scale: float = 3.0,
+ pag_adaptive_scale: float = 0.0,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+            num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+            use_resolution_binning (`bool`, defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use with the `prompt`.
+ pag_scale (`float`, *optional*, defaults to 3.0):
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
+ guidance will not be used.
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
+ used.
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated images.
+ """
+        # 1. Default height and width to transformer and check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 256:
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
+ elif self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
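+            # The requested size is snapped to the closest aspect-ratio bin for generation; after decoding,
+            # the image is resized back to (orig_height, orig_width) at the end of the call.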
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scale = pag_adaptive_scale
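+        # When `pag_adaptive_scale` > 0, the effective PAG scale shrinks as denoising progresses
+        # (roughly `pag_scale - pag_adaptive_scale * (1000 - t)`, floored at 0); otherwise `pag_scale`
+        # is applied unchanged at every step. See `pag_utils.PAGMixin` for the exact rule.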
+
+        # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
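+        # With PAG the conditional embeddings are duplicated once more for the perturbed-attention
+        # branch, so the effective batch is laid out as [negative, positive, positive] when
+        # classifier-free guidance is also active, and [positive, positive] otherwise (see `PAGMixin`).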
+ if self.do_perturbed_attention_guidance:
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
+ prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance
+ )
+ prompt_attention_mask = self._prepare_perturbed_attention_guidance(
+ prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance
+ )
+ elif do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
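+        # The PAG attention processors perturb self-attention in the selected layers (roughly, the
+        # attention map is replaced with an identity for the perturbed chunk of the batch), so that
+        # this degraded prediction can be contrasted against the regular conditional prediction.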
+ if self.do_perturbed_attention_guidance:
+ original_attn_proc = self.transformer.attn_processors
+ self._set_pag_attn_processor(
+ pag_applied_layers=self.pag_applied_layers,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
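+        # PixArt-Sigma does not use the resolution / aspect-ratio micro-conditioning of PixArt-Alpha,
+        # so these are left as `None` (presumably kept only for interface compatibility with the transformer).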
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
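+                # (the ratio is 3 when both CFG and PAG are active, 2 when only one of them is, and 1 otherwise)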
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
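+                # `_apply_perturbed_attention_guidance` splits the stacked prediction back into its
+                # unconditional / conditional / perturbed chunks and recombines them, roughly as
+                #   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond) + pag_scale * (eps_cond - eps_perturbed)
+                # (see `PAGMixin` for the exact formulation, including adaptive scaling).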
+ if self.do_perturbed_attention_guidance:
+ noise_pred = self._apply_perturbed_attention_guidance(
+ noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep
+ )
+ elif do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if self.do_perturbed_attention_guidance:
+ self.transformer.set_attn_processor(original_attn_proc)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
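+
+
+# A minimal usage sketch (illustration only, not part of the pipeline). The checkpoint id and the
+# `AutoPipelineForText2Image(..., enable_pag=True)` entry point are assumptions and may differ from
+# the API finally exposed for this pipeline:
+#
+#   import torch
+#   from diffusers import AutoPipelineForText2Image
+#
+#   pipe = AutoPipelineForText2Image.from_pretrained(
+#       "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # assumed checkpoint id
+#       enable_pag=True,
+#       torch_dtype=torch.float16,
+#   ).to("cuda")
+#
+#   image = pipe(
+#       "an astronaut riding a horse on the moon",
+#       guidance_scale=4.5,
+#       pag_scale=3.0,
+#       num_inference_steps=20,
+#   ).images[0]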
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
new file mode 100644
index 000000000000..e37506a60c61
--- /dev/null
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
@@ -0,0 +1,846 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...image_processor import PipelineImageInput
+from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unets.unet_motion_model import MotionAdapter
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..animatediff.pipeline_output import AnimateDiffPipelineOutput
+from ..free_init_utils import FreeInitMixin
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .pag_utils import PAGMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler
+ >>> from diffusers.utils import export_to_gif
+
+ >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+ >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+ >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id)
+ >>> scheduler = DDIMScheduler.from_pretrained(
+ ... model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False
+ ... )
+ >>> pipe = AnimateDiffPAGPipeline.from_pretrained(
+ ... model_id,
+ ... motion_adapter=motion_adapter,
+ ... scheduler=scheduler,
+ ... pag_applied_layers=["mid"],
+ ... torch_dtype=torch.float16,
+ ... ).to("cuda")
+
+ >>> video = pipe(
+ ... prompt="car, futuristic cityscape with neon lights, street, no human",
+ ... negative_prompt="low quality, bad quality",
+ ... num_inference_steps=25,
+ ... guidance_scale=6.0,
+ ... pag_scale=3.0,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_gif(video, "animatediff_pag.gif")
+ ```
+"""
+
+
+class AnimateDiffPAGPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ FreeInitMixin,
+ PAGMixin,
+):
+ r"""
+ Pipeline for text-to-video generation using
+ [AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention
+ Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
+ motion_adapter: MotionAdapter,
+ scheduler: KarrasDiffusionSchedulers,
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"]
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
+
+ self.set_pag_applied_layers(pag_applied_layers)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
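+                # IP-Adapter "plus"-style projection layers consume the penultimate hidden states of the
+                # image encoder, while the plain `ImageProjection` layer consumes the pooled image embeds.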
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.pia.pipeline_pia.PIAPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ pag_scale: float = 3.0,
+ pag_adaptive_scale: float = 0.0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+ amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with length equal to the number of
+                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ pag_scale (`float`, *optional*, defaults to 3.0):
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
+ guidance will not be used.
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
+ used.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scale = pag_adaptive_scale
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
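+        # When PAG is enabled, the conditional embeddings are additionally duplicated for the
+        # perturbed-attention branch, giving a [negative, positive, positive] layout under CFG
+        # and [positive, positive] without it.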
+ if self.do_perturbed_attention_guidance:
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
+ )
+ elif self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ for i, image_embeds in enumerate(ip_adapter_image_embeds):
+ negative_image_embeds = None
+ if self.do_classifier_free_guidance:
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
+ if self.do_perturbed_attention_guidance:
+ image_embeds = self._prepare_perturbed_attention_guidance(
+ image_embeds, negative_image_embeds, self.do_classifier_free_guidance
+ )
+ elif self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+ image_embeds = image_embeds.to(device)
+ ip_adapter_image_embeds[i] = image_embeds
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": ip_adapter_image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ if self.do_perturbed_attention_guidance:
+ original_attn_proc = self.unet.attn_processors
+ self._set_pag_attn_processor(
+ pag_applied_layers=self.pag_applied_layers,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+
+ self._num_timesteps = len(timesteps)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8. Denoising loop
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
+ for i, t in enumerate(timesteps):
+                    # expand the latents to match the prompt embeddings batch size
+                    # (classifier-free guidance and/or perturbed attention guidance)
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
+ if self.do_perturbed_attention_guidance:
+ noise_pred = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ )
+ elif self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 9. Post processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if self.do_perturbed_attention_guidance:
+ self.unet.set_attn_processor(original_attn_proc)
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index f21d11ba918d..69a55f605bfd 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -54,22 +54,21 @@
Examples:
```py
>>> import torch
- >>> from diffusers import (
- ... EulerDiscreteScheduler,
- ... MotionAdapter,
- ... PIAPipeline,
- ... )
+ >>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline
>>> from diffusers.utils import export_to_gif, load_image
- >>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers")
- >>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter)
+ >>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
+ >>> pipe = PIAPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
+ ... )
+
>>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
>>> image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
... )
>>> image = image.resize((512, 512))
>>> prompt = "cat in a hat"
- >>> negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+ >>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted"
>>> generator = torch.Generator("cpu").manual_seed(0)
>>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
>>> frames = output.frames[0]
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 779e691f0c27..937cae2e47f5 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -20,7 +21,6 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
-from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
@@ -66,12 +66,19 @@ def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
+ use_dynamic_shifting=False,
+ base_shift: Optional[float] = 0.5,
+ max_shift: Optional[float] = 1.15,
+ base_image_seq_len: Optional[int] = 256,
+ max_image_seq_len: Optional[int] = 4096,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
@@ -158,11 +165,15 @@ def scale_noise(
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
+ mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -174,6 +185,9 @@ def set_timesteps(
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
+ if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`.")
+
if sigmas is None:
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
@@ -181,6 +195,10 @@ def set_timesteps(
)
sigmas = timesteps / self.config.num_train_timesteps
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas)
+ else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
@@ -274,32 +292,10 @@ def step(
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
- gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
-
- noise = randn_tensor(
- model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
- )
-
- eps = noise * s_noise
- sigma_hat = sigma * (gamma + 1)
-
- if gamma > 0:
- sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
-
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
- # NOTE: "original_sample" should not be an expected prediction_type but is left in for
- # backwards compatibility
-
- # if self.config.prediction_type == "vector_field":
-
- denoised = sample - model_output * sigma
- # 2. Convert to an ODE derivative
- derivative = (sample - denoised) / sigma_hat
-
- dt = self.sigmas[self.step_index + 1] - sigma_hat
+ prev_sample = sample + (sigma_next - sigma) * model_output
- prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 230b0b29b2c2..34c6c20f7f2c 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class FluxTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index ead85a2a498e..3e9a33503906 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class AnimateDiffPAGPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AnimateDiffPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -287,6 +302,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class FluxPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -947,6 +977,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class PixArtSigmaPAGPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class PixArtSigmaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 020411dc7883..1cdc02e87328 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
_check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
- return pretrained_model_name_or_path, sharded_metadata
+ return shards_path, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
"required according to the checkpoint index."
)
- try:
- # Load from URL
- cached_folder = snapshot_download(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- allow_patterns=allow_patterns,
- ignore_patterns=ignore_patterns,
- user_agent=user_agent,
- )
- if subfolder is not None:
- cached_folder = os.path.join(cached_folder, subfolder)
+ try:
+ # Load from URL
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
+ )
+ if subfolder is not None:
+ cached_folder = os.path.join(cached_folder, subfolder)
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
- " again after checking your internet connection."
- ) from e
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+ except HTTPError as e:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+ " again after checking your internet connection."
+ ) from e
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
- if local_files_only:
+ elif local_files_only:
_check_if_shards_exist_locally(
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
)
+ if subfolder is not None:
+ cached_folder = os.path.join(cached_folder, subfolder)
return cached_folder, sharded_metadata
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index a84968e613b5..1c688c9e9c8a 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1068,6 +1068,17 @@ def test_load_sharded_checkpoint_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
+ @require_torch_gpu
+ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
+ loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
+ loaded_model = loaded_model.to(torch_device)
+ new_output = loaded_model(**inputs_dict)
+
+ assert loaded_model
+ assert new_output.sample.shape == (4, 4, 16, 16)
+
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1077,6 +1088,17 @@ def test_load_sharded_checkpoint_device_map_from_hub(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
+ @require_torch_gpu
+ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ loaded_model = self.model_class.from_pretrained(
+ "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
+ )
+ new_output = loaded_model(**inputs_dict)
+
+ assert loaded_model
+ assert new_output.sample.shape == (4, 4, 16, 16)
+
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1087,6 +1109,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
+ @require_torch_gpu
+ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
+ loaded_model = self.model_class.from_pretrained(
+ ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
+ )
+ new_output = loaded_model(**inputs_dict)
+
+ assert loaded_model
+ assert new_output.sample.shape == (4, 4, 16, 16)
+
@require_peft_backend
def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
diff --git a/tests/pipelines/flux/__init__.py b/tests/pipelines/flux/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
new file mode 100644
index 000000000000..0dc13911c55b
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -0,0 +1,281 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers.utils.testing_utils import (
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
+
+
+@unittest.skip("Tests need to be revisited.")
+class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxPipeline
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ sample_size=32,
+ patch_size=1,
+ in_channels=4,
+ num_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=4,
+ caption_projection_dim=32,
+ joint_attention_dim=32,
+ pooled_projection_dim=64,
+ out_channels=4,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ assert max_diff > 1e-2
+
+ def test_flux_different_negative_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt_2"] = "deformed"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ assert max_diff > 1e-2
+
+ def test_flux_prompt_embeds(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_with_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = inputs.pop("prompt")
+
+ do_classifier_free_guidance = inputs["guidance_scale"] > 1
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ text_ids,
+ ) = pipe.encode_prompt(
+ prompt,
+ prompt_2=None,
+ prompt_3=None,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ device=torch_device,
+ )
+ output_with_embeds = pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ **inputs,
+ ).images[0]
+
+ max_diff = np.abs(output_with_prompt - output_with_embeds).max()
+ assert max_diff < 1e-4
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+
+@slow
+@require_torch_gpu
+class FluxPipelineSlowTests(unittest.TestCase):
+ pipeline_class = FluxPipeline
+ repo_id = "black-forest-labs/FLUX.1-schnell"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ return {
+ "prompt": "A photo of a cat",
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+    # TODO (Dhruv): Move large model tests to a dedicated runner.
+ @unittest.skip("We cannot run inference on this model with the current CI hardware")
+ def test_flux_inference(self):
+ pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+ expected_slice = np.array(
+ [
+ [0.36132812, 0.30004883, 0.25830078],
+ [0.36669922, 0.31103516, 0.23754883],
+ [0.34814453, 0.29248047, 0.23583984],
+ [0.35791016, 0.30981445, 0.23999023],
+ [0.36328125, 0.31274414, 0.2607422],
+ [0.37304688, 0.32177734, 0.26171875],
+ [0.3671875, 0.31933594, 0.25756836],
+ [0.36035156, 0.31103516, 0.2578125],
+ [0.3857422, 0.33789062, 0.27563477],
+ [0.3701172, 0.31982422, 0.265625],
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
new file mode 100644
index 000000000000..8f637b991056
--- /dev/null
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -0,0 +1,492 @@
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AnimateDiffPAGPipeline,
+ AnimateDiffPipeline,
+ AutoencoderKL,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LCMScheduler,
+ MotionAdapter,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.utils import is_xformers_available
+from diffusers.utils.testing_utils import torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+)
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffPAGPipelineFastTests(
+ IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
+ pipeline_class = AnimateDiffPAGPipeline
+ params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ def get_dummy_components(self):
+ cross_attention_dim = 8
+ block_out_channels = (8, 8)
+
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=block_out_channels,
+ layers_per_block=2,
+ sample_size=8,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=block_out_channels,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=cross_attention_dim,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ motion_adapter = MotionAdapter(
+ block_out_channels=block_out_channels,
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ )
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 7.5,
+ "pag_scale": 3.0,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_from_pipe_consistent_config(self):
+ assert self.original_pipeline_class == StableDiffusionPipeline
+ original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
+ original_kwargs = {"requires_safety_checker": False}
+
+ # create original_pipeline_class(sd)
+ pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+ # original_pipeline_class(sd) -> pipeline_class
+ pipe_components = self.get_dummy_components()
+ pipe_additional_components = {}
+ for name, component in pipe_components.items():
+ if name not in pipe_original.components:
+ pipe_additional_components[name] = component
+
+ pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+ # pipeline_class -> original_pipeline_class(sd)
+ original_pipe_additional_components = {}
+ for name, component in pipe_original.components.items():
+ if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+ original_pipe_additional_components[name] = component
+
+ pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+ # compare the config
+ original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+ original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+ assert original_config_2 == original_config
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_ip_adapter_single(self):
+ expected_pipe_slice = None
+
+ if torch_device == "cpu":
+ expected_pipe_slice = np.array(
+ [
+ 0.5068,
+ 0.5294,
+ 0.4926,
+ 0.4810,
+ 0.4188,
+ 0.5935,
+ 0.5295,
+ 0.3947,
+ 0.5300,
+ 0.4706,
+ 0.3950,
+ 0.4737,
+ 0.4072,
+ 0.3227,
+ 0.5481,
+ 0.4864,
+ 0.4518,
+ 0.5315,
+ 0.5979,
+ 0.5374,
+ 0.3503,
+ 0.5275,
+ 0.6067,
+ 0.4914,
+ 0.5440,
+ 0.4775,
+ 0.5538,
+ ]
+ )
+ return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
+ def test_dict_tuple_outputs_equivalent(self):
+ expected_slice = None
+ if torch_device == "cpu":
+ expected_slice = np.array([0.5295, 0.3947, 0.5300, 0.4864, 0.4518, 0.5315, 0.5440, 0.4775, 0.5538])
+ return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
+ pipe(**inputs)
+
+ def test_free_init(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ pipe.enable_free_init(
+ num_iters=2,
+ use_fast_sampling=True,
+ method="butterworth",
+ order=4,
+ spatial_stop_frequency=0.25,
+ temporal_stop_frequency=0.25,
+ )
+ inputs_enable_free_init = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
+
+ pipe.disable_free_init()
+ inputs_disable_free_init = self.get_dummy_inputs(torch_device)
+ frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
+
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+ max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
+ self.assertGreater(
+ sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
+ )
+ self.assertLess(
+ max_diff_disabled,
+ 1e-3,
+ "Disabling of FreeInit should lead to results similar to the default pipeline results",
+ )
+
+ def test_free_init_with_schedulers(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ schedulers_to_test = [
+ DPMSolverMultistepScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ algorithm_type="dpmsolver++",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ LCMScheduler.from_config(
+ components["scheduler"].config,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ steps_offset=1,
+ clip_sample=False,
+ ),
+ ]
+ components.pop("scheduler")
+
+ for scheduler in schedulers_to_test:
+ components["scheduler"] = scheduler
+ pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs).frames[0]
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+
+ self.assertGreater(
+ sum_enabled,
+ 1e1,
+ "Enabling of FreeInit should lead to results different from the default pipeline results",
+ )
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs).frames[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs).frames[0]
+ output_with_offload = (
+            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
+ )
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+ def test_vae_slicing(self):
+ return super().test_vae_slicing(image_count=2)
+
+ def test_pag_disable_enable(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline (expect same output when pag is disabled)
+ components.pop("pag_applied_layers", None)
+ pipe_sd = AnimateDiffPipeline(**components)
+ pipe_sd = pipe_sd.to(device)
+ pipe_sd.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ del inputs["pag_scale"]
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
+
+ components = self.get_dummy_components()
+
+ # pag disabled with pag_scale=0.0
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["pag_scale"] = 0.0
+ out_pag_disabled = pipe_pag(**inputs).frames[0, -3:, -3:, -1]
+
+ # pag enabled
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ out_pag_enabled = pipe_pag(**inputs).frames[0, -3:, -3:, -1]
+
+ assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
+ assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
+
+ def test_pag_applied_layers(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline
+ components.pop("pag_applied_layers", None)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
+ all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
+ original_attn_procs = pipe.unet.attn_processors
+ pag_layers = [
+ "down",
+ "mid",
+ "up",
+ ]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
+
+ # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e.
+ # mid_block.motion_modules.0.transformer_blocks.0.attn1.processor
+ # mid_block.attentions.0.transformer_blocks.0.attn1.processor
+ all_self_attn_mid_layers = [
+ "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
+ "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
+ ]
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid.block_0"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid.block_0.attentions_1"]
+ with self.assertRaises(ValueError):
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+
+ # pag_applied_layers = "down" should apply to all self-attention layers in down_blocks
+ # down_blocks.1.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor
+ # down_blocks.1.(attentions|motion_modules).0.transformer_blocks.1.attn1.processor
+ # down_blocks.1.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 6
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down.block_0"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert (len(pipe.pag_attn_processors)) == 4
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down.block_1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 2
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down.block_1.motion_modules_2"]
+ with self.assertRaises(ValueError):
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
new file mode 100644
index 000000000000..be86afe45be0
--- /dev/null
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -0,0 +1,423 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ PixArtSigmaPAGPipeline,
+ PixArtSigmaPipeline,
+ PixArtTransformer2DModel,
+)
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ enable_full_determinism,
+ print_tensor_test,
+ torch_device,
+)
+
+from ..pipeline_params import (
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference, to_np
+
+
+enable_full_determinism()
+
+
+class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = PixArtSigmaPAGPipeline
+ params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
+ params = set(params)
+ params.remove("cross_attention_kwargs")
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = PipelineTesterMixin.required_optional_params
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = PixArtTransformer2DModel(
+ sample_size=8,
+ num_layers=2,
+ patch_size=2,
+ attention_head_dim=8,
+ num_attention_heads=3,
+ caption_channels=32,
+ in_channels=4,
+ cross_attention_dim=24,
+ out_channels=8,
+ attention_bias=True,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_single",
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL()
+
+ scheduler = DDIMScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "pag_scale": 3.0,
+ "use_resolution_binning": False,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_pag_disable_enable(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline (expect same output when pag is disabled)
+ pipe = PixArtSigmaPipeline(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ del inputs["pag_scale"]
+ assert (
+ "pag_scale" not in inspect.signature(pipe.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ out = pipe(**inputs).images[0, -3:, -3:, -1]
+
+ # pag disabled with pag_scale=0.0
+ components["pag_applied_layers"] = [1]
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["pag_scale"] = 0.0
+ out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
+
+ # pag enabled
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
+
+ assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
+ assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
+
+ def test_pag_applied_layers(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # "attn1" should apply to all self-attention layers.
+ all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
+ pag_layers = [0, 1]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
+
+ def test_pag_inference(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe_pag(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+ print_tensor_test(image_slice)
+
+ assert image.shape == (
+ 1,
+ 8,
+ 8,
+ 3,
+ ), f"the shape of the output image should be (1, 8, 8, 3) but got {image.shape}"
+ expected_slice = np.array([0.6499, 0.3250, 0.3572, 0.6780, 0.4453, 0.4582, 0.2770, 0.5168, 0.4594])
+
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ # Copied from tests.pipelines.pixart_sigma.test_pixart.PixArtSigmaPipelineFastTests.test_save_load_optional_components
+ def test_save_load_optional_components(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = pipe.encode_prompt(prompt)
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "prompt_attention_mask": prompt_attention_mask,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "negative_prompt_attention_mask": negative_prompt_attention_mask,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "use_resolution_binning": False,
+ }
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "prompt_attention_mask": prompt_attention_mask,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "negative_prompt_attention_mask": negative_prompt_attention_mask,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "use_resolution_binning": False,
+ }
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, 1e-4)
+
+    # Overridden because PAG PixArt Sigma has `pag_applied_layers`.
+ # Also, we shouldn't be doing `set_default_attn_processor()` after loading
+ # the pipeline with `pag_applied_layers`.
+ def test_save_load_local(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(diffusers.logging.INFO)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+
+ with CaptureLogger(logger) as cap_logger:
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+
+ for name in pipe_loaded.components.keys():
+ if name not in pipe_loaded._optional_components:
+ assert name in str(cap_logger)
+
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+    # We shouldn't be calling `set_default_attn_processor` here.
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ if test_mean_pixel_difference:
+ assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
+ assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
+
+    # Because we have `pag_applied_layers` we cannot directly apply
+ # `set_default_attn_processor`
+ def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ if expected_slice is None:
+ output = pipe(**self.get_dummy_inputs(generator_device))[0]
+ else:
+ output = expected_slice
+
+ output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
+
+ if expected_slice is None:
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
+ else:
+ if output_tuple.ndim != 5:
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)[0, -3:, -3:, -1].flatten()).max()
+ else:
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)[0, -3:, -3:, -1, -1].flatten()).max()
+
+ self.assertLess(max_diff, expected_max_difference)
+
+ # Same reason as above
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+    # Because we're also passing `pag_applied_layers` (a list) in the components.
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float, list))}
+
+ pipe = self.pipeline_class(**init_components)
+
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))